Skip to content

Commit

Permalink
Added optional bias term support for Crammer-Singer LibLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Mar 8, 2012
1 parent cdd0218 commit db96f7c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
13 changes: 8 additions & 5 deletions src/shogun/classifier/svm/MulticlassLibLinear.cpp
Expand Up @@ -10,6 +10,7 @@


#include <shogun/classifier/svm/MulticlassLibLinear.h>
#include <shogun/classifier/svm/SVM_linear.h>
#include <shogun/mathematics/Math.h>

using namespace shogun;
Expand All @@ -24,13 +25,13 @@ bool CMulticlassLibLinear::train_machine(CFeatures* data)

problem mc_problem;
mc_problem.l = num_vectors;
mc_problem.n = m_features->get_dim_feature_space();
mc_problem.n = m_features->get_dim_feature_space()+1;
mc_problem.y = SG_MALLOC(int32_t, mc_problem.l);
for (int32_t i=0; i<num_vectors; i++)
mc_problem.y[i] = labels->get_int_label(i);

mc_problem.x = m_features;
mc_problem.use_bias = true;
mc_problem.use_bias = m_use_bias;

float64_t* w = SG_MALLOC(float64_t, mc_problem.n*num_classes);
float64_t* C = SG_MALLOC(float64_t, num_vectors);
Expand All @@ -45,10 +46,12 @@ bool CMulticlassLibLinear::train_machine(CFeatures* data)
for (int32_t i=0; i<num_classes; i++)
{
CLinearMachine* machine = new CLinearMachine();
SGVector<float64_t> cw(mc_problem.n);
for (int32_t j=0; j<mc_problem.n; j++)
float64_t* cw = SG_MALLOC(float64_t, mc_problem.n);
for (int32_t j=0; j<mc_problem.n-1; j++)
cw[j] = w[j*num_classes+i];
machine->set_w(cw);
machine->set_w(SGVector<float64_t>(cw,mc_problem.n-1));
CMath::display_vector(cw,mc_problem.n);
machine->set_bias(w[(mc_problem.n-1)*num_classes+i]);

m_machines[i] = machine;
}
Expand Down
20 changes: 19 additions & 1 deletion src/shogun/classifier/svm/MulticlassLibLinear.h
Expand Up @@ -13,7 +13,6 @@

#include <shogun/lib/common.h>
#include <shogun/features/DotFeatures.h>
#include <shogun/classifier/svm/SVM_linear.h>
#include <shogun/machine/multiclass/LinearMulticlassMachine.h>

namespace shogun
Expand All @@ -38,6 +37,7 @@ class CMulticlassLibLinear : public CLinearMulticlassMachine
{
set_epsilon(1e-2);
set_max_iter(10000);
set_use_bias(false);
}

/** destructor */
Expand Down Expand Up @@ -77,6 +77,21 @@ class CMulticlassLibLinear : public CLinearMulticlassMachine
*/
inline float64_t get_epsilon() const { return m_epsilon; }

/** enable or disable the bias term used during training
 * (the constructor initializes this to false)
 * @param use_bias whether training should include a bias term
 */
inline void set_use_bias(bool use_bias) { m_use_bias = use_bias; }
/** query whether a bias term is used during training
 * @return current value of the use-bias flag
 */
inline bool get_use_bias() const { return m_use_bias; }

/** set max iter
* @param max_iter max iter value
*/
Expand Down Expand Up @@ -106,6 +121,9 @@ class CMulticlassLibLinear : public CLinearMulticlassMachine

/** max number of iterations */
int32_t m_max_iter;

/** use bias */
bool m_use_bias;
};
}
#endif
20 changes: 20 additions & 0 deletions src/shogun/classifier/svm/SVM_linear.cpp
Expand Up @@ -429,6 +429,8 @@ void Solver_MCSVM_CS::Solve(double *w)
alpha_index[i*nr_class+m] = m;

QD[i] = prob->x->dot(i, prob->x,i);
if (prob->use_bias)
QD[i] += 1.0;

active_size_i[i] = nr_class;
y_index[i] = prob->y[i];
Expand Down Expand Up @@ -468,6 +470,15 @@ void Solver_MCSVM_CS::Solve(double *w)
G[m] += w_i[alpha_index_i[m]]*(feature_value);
got_feature = prob->x->get_next_feature(feature_index,feature_value,feature_iter);
}
// experimental
// ***
if (prob->use_bias)
{
double *w_i = &w[(w_size-1)*nr_class];
for(m=0;m<active_size_i[i];m++)
G[m] += w_i[alpha_index_i[m]];
}
// ***
prob->x->free_feature_iterator(feature_iter);

double minG = CMath::INFTY;
Expand Down Expand Up @@ -547,6 +558,15 @@ void Solver_MCSVM_CS::Solve(double *w)
w_i[d_ind[m]] += d_val[m]*feature_value;
got_feature = prob->x->get_next_feature(feature_index,feature_value,feature_iter);
}
// experimental
// ***
if (prob->use_bias)
{
double *w_i = &w[(w_size-1)*nr_class];
for(m=0;m<nz_d;m++)
w_i[d_ind[m]] += d_val[m];
}
// ***
prob->x->free_feature_iterator(feature_iter);
}
}
Expand Down

0 comments on commit db96f7c

Please sign in to comment.