Skip to content

Commit

Permalink
Refactored multiclass machines to support C parameter selection
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 22, 2012
1 parent 2f999e4 commit 90d4424
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/shogun/classifier/mkl/MKLMulticlass.cpp
Expand Up @@ -77,7 +77,7 @@ void CMKLMulticlass::initsvm()
svm=new CGMNPSVM;
SG_REF(svm);

svm->set_C(get_C1(),get_C2());
svm->set_C(get_C());
svm->set_epsilon(get_epsilon());

if (m_labels->get_num_labels()<=0)
Expand Down
1 change: 1 addition & 0 deletions src/shogun/machine/KernelMulticlassMachine.h
Expand Up @@ -65,6 +65,7 @@ class CKernelMulticlassMachine : public CMulticlassMachine
*/
void set_kernel(CKernel* k)
{
((CKernelMachine*)m_machine)->set_kernel(k);
SG_REF(k);
SG_UNREF(m_kernel);
m_kernel=k;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/GMNPSVM.cpp
Expand Up @@ -80,7 +80,7 @@ bool CGMNPSVM::train_machine(CFeatures* data)

}

float64_t C = get_C1();
float64_t C = get_C();
int32_t tmax = 1000000000;
float64_t tolabs = 0;
float64_t tolrel = get_epsilon();
Expand Down
12 changes: 6 additions & 6 deletions src/shogun/multiclass/LaRank.cpp
Expand Up @@ -632,7 +632,7 @@ bool CLaRank::train_machine(CFeatures* data)
float64_t gap = DBL_MAX;

SG_INFO("Training on %d examples\n", nb_train);
while (gap > get_C1() && (!CSignal::cancel_computations())) // stopping criteria
while (gap > get_C() && (!CSignal::cancel_computations())) // stopping criteria
{
float64_t tr_err = 0;
int32_t ind = step;
Expand All @@ -653,7 +653,7 @@ bool CLaRank::train_machine(CFeatures* data)
SG_DEBUG("End of iteration %d\n", n_it++);
SG_DEBUG("Train error (online): %f%%\n", (tr_err / nb_train) * 100);
gap = computeGap ();
SG_ABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(DBL_MAX), -CMath::log10(get_C1()), 6);
SG_ABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(DBL_MAX), -CMath::log10(get_C()), 6);

if (!batch_mode) // skip stopping criteria if online mode
gap = 0;
Expand Down Expand Up @@ -825,7 +825,7 @@ float64_t CLaRank::computeGap ()
}
sum_sl += CMath::max (0.0, gi - gmin);
}
return CMath::max (0.0, computeW2 () + get_C1() * sum_sl - sum_bi);
return CMath::max (0.0, computeW2 () + get_C() * sum_sl - sum_bi);
}

// Number of classes so far
Expand Down Expand Up @@ -937,7 +937,7 @@ CLaRank::process_return_t CLaRank::process (const LaRankPattern & pattern, proce
bool support = ptype == processOptimize || output->isSupportVector (pattern.x_id);
bool goodclass = current.output == pattern.y;
if ((!support && goodclass) ||
(support && output->getBeta (pattern.x_id) < (goodclass ? get_C1() : 0)))
(support && output->getBeta (pattern.x_id) < (goodclass ? get_C() : 0)))
{
ygp = current;
outp = output;
Expand Down Expand Up @@ -986,12 +986,12 @@ CLaRank::process_return_t CLaRank::process (const LaRankPattern & pattern, proce
{
float64_t beta = outp->getBeta (pattern.x_id);
if (ygp.output == pattern.y)
lambda = CMath::min (lambda, get_C1() - beta);
lambda = CMath::min (lambda, get_C() - beta);
else
lambda = CMath::min (lambda, fabs (beta));
}
else
lambda = CMath::min (lambda, get_C1());
lambda = CMath::min (lambda, get_C());

/*
** update the solution
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassLibSVM.cpp
Expand Up @@ -79,7 +79,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
param.kernel=m_kernel;
param.cache_size = m_kernel->get_cache_size();
param.max_train_time = m_max_train_time;
param.C = get_C1();
param.C = get_C();
param.eps = get_epsilon();
param.p = 0.1;
param.shrinking = 1;
Expand Down
7 changes: 4 additions & 3 deletions src/shogun/multiclass/MulticlassSVM.cpp
Expand Up @@ -16,20 +16,20 @@
using namespace shogun;

CMulticlassSVM::CMulticlassSVM()
:CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL)
:CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL), m_C(0)
{
init();
}

CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy)
:CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL)
:CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL), m_C(0)
{
init();
}

CMulticlassSVM::CMulticlassSVM(
CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
:CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab)
: CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab), m_C(C)
{
init();
}
Expand All @@ -40,6 +40,7 @@ CMulticlassSVM::~CMulticlassSVM()

void CMulticlassSVM::init()
{
SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
}

bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
Expand Down
19 changes: 10 additions & 9 deletions src/shogun/multiclass/MulticlassSVM.h
Expand Up @@ -107,15 +107,10 @@ class CMulticlassSVM : public CKernelMulticlassMachine
*/
float64_t get_nu() { return svm_proto()->get_nu(); }
// TODO remove if unnecessary here
/** get C1 of base SVM
* @return C1 of base SVM
/** get C of base SVM
* @return C of base SVM
*/
float64_t get_C1() { return svm_proto()->get_C1(); }
// TODO remove if unnecessary here
/** get C2 of base SVM
* @return C2 of base SVM
*/
float64_t get_C2() { return svm_proto()->get_C2(); }
float64_t get_C() { return m_C; }
// TODO remove if unnecessary here
/** get qpsize of base SVM
* @return qpsize of base SVM
Expand Down Expand Up @@ -163,7 +158,7 @@ class CMulticlassSVM : public CKernelMulticlassMachine
* @param c_neg C for negatives
* @param c_pos C for positives
*/
void set_C(float64_t c_neg, float64_t c_pos) { svm_proto()->set_C(c_neg, c_pos); }
void set_C(float64_t C) { svm_proto()->set_C(C,C); m_C = C; }
// TODO remove in unnecessary here
/** set epsilon value
* @param eps epsilon value
Expand Down Expand Up @@ -236,7 +231,13 @@ class CMulticlassSVM : public CKernelMulticlassMachine
}

private:

void init();

protected:

/** C regularization constant */
float64_t m_C;
};
}
#endif
6 changes: 3 additions & 3 deletions src/shogun/multiclass/ScatterSVM.cpp
Expand Up @@ -130,7 +130,7 @@ bool CScatterSVM::train_no_bias_libsvm()
}

int32_t weights_label[2]={-1,+1};
float64_t weights[2]={1.0,get_C2()/get_C1()};
float64_t weights[2]={1.0,get_C()/get_C()};

ASSERT(m_kernel && m_kernel->has_features());
ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
Expand Down Expand Up @@ -224,7 +224,7 @@ bool CScatterSVM::train_no_bias_svmlight()
m_kernel->set_normalizer(n);
m_kernel->init_normalizer();

CSVMLightOneClass* light=new CSVMLightOneClass(get_C1(), m_kernel);
CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel);
light->set_linadd_enabled(false);
light->train();

Expand Down Expand Up @@ -264,7 +264,7 @@ bool CScatterSVM::train_testrule12()
}

int32_t weights_label[2]={-1,+1};
float64_t weights[2]={1.0,get_C2()/get_C1()};
float64_t weights[2]={1.0,get_C()/get_C()};

ASSERT(m_kernel && m_kernel->has_features());
ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/ui/GUIClassifier.cpp
Expand Up @@ -468,7 +468,7 @@ bool CGUIClassifier::train_mkl_multiclass()
mkl->set_max_train_time(max_train_time);
mkl->set_tube_epsilon(svm_tube_epsilon);
mkl->set_nu(svm_nu);
mkl->set_C(svm_C1, svm_C2);
mkl->set_C(svm_C1);
mkl->set_qpsize(svm_qpsize);
mkl->set_shrinking_enabled(svm_use_shrinking);
mkl->set_linadd_enabled(svm_use_linadd);
Expand Down Expand Up @@ -589,7 +589,7 @@ bool CGUIClassifier::train_svm()
svm->set_max_train_time(max_train_time);
svm->set_tube_epsilon(svm_tube_epsilon);
svm->set_nu(svm_nu);
svm->set_C(svm_C1, svm_C2);
svm->set_C(svm_C1);
svm->set_qpsize(svm_qpsize);
svm->set_shrinking_enabled(svm_use_shrinking);
svm->set_linadd_enabled(svm_use_linadd);
Expand Down

0 comments on commit 90d4424

Please sign in to comment.