Skip to content

Commit

Permalink
A bunch of fixes for multiclass machine
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed Apr 23, 2012
1 parent 1ccb9e4 commit eeaa056
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
Expand Up @@ -9,17 +9,17 @@

def classifier_multiclasslinearmachine_modular (fm_train_real=traindat,fm_test_real=testdat,label_train_multiclass=label_traindat,width=2.1,C=1,epsilon=1e-5):
from shogun.Features import RealFeatures, Labels
from shogun.Classifier import LibLinear, L2R_L2LOSS_SVC, LinearMulticlassMachine, ONE_VS_REST_STRATEGY, ONE_VS_ONE_STRATEGY
from shogun.Classifier import LibLinear, L2R_L2LOSS_SVC, LinearMulticlassMachine, MulticlassOneVsOneStrategy, MulticlassOneVsRestStrategy

feats_train = RealFeatures(fm_train_real)
feats_test = RealFeatures(fm_test_real)

labels = Labels(label_train_multiclass)

classifier = LibLinear(L2R_L2LOSS_SVC)
classifier.set_epsilon(epsilon)
classifier.set_bias_enabled(True)
mc_classifier = LinearMulticlassMachine(ONE_VS_ONE_STRATEGY, feats_train, classifier, labels)
mc_classifier = LinearMulticlassMachine(MulticlassOneVsOneStrategy(), feats_train, classifier, labels)

mc_classifier.train()
out = mc_classifier.apply().get_labels()
Expand Down
Expand Up @@ -10,7 +10,7 @@
def classifier_multiclassmachine_modular (fm_train_real=traindat,fm_test_real=testdat,label_train_multiclass=label_traindat,width=2.1,C=1,epsilon=1e-5):
from shogun.Features import RealFeatures, Labels
from shogun.Kernel import GaussianKernel
from shogun.Classifier import LibSVM, KernelMulticlassMachine, ONE_VS_REST_STRATEGY
from shogun.Classifier import LibSVM, KernelMulticlassMachine, MulticlassOneVsRestStrategy

feats_train=RealFeatures(fm_train_real)
feats_test=RealFeatures(fm_test_real)
Expand All @@ -20,7 +20,8 @@ def classifier_multiclassmachine_modular (fm_train_real=traindat,fm_test_real=te

classifier = LibSVM(C, kernel, labels)
classifier.set_epsilon(epsilon)
mc_classifier = KernelMulticlassMachine(ONE_VS_REST_STRATEGY,kernel,classifier,labels)
print labels.get_labels()
mc_classifier = KernelMulticlassMachine(MulticlassOneVsRestStrategy(),kernel,classifier,labels)
mc_classifier.train()

kernel.init(feats_train, feats_test)
Expand Down
21 changes: 14 additions & 7 deletions src/interfaces/modular/Classifier.i
Expand Up @@ -31,7 +31,6 @@
%rename(LPBoost) CLPBoost;
%rename(LPM) CLPM;
%rename(MPDSVM) CMPDSVM;
%rename(MulticlassSVM) CMulticlassSVM;
%rename(OnlineSVMSGD) COnlineSVMSGD;
%rename(OnlineLibLinear) COnlineLibLinear;
%rename(Perceptron) CPerceptron;
Expand All @@ -51,7 +50,6 @@
%rename(MKL) CMKL;
%rename(MKLClassification) CMKLClassification;
%rename(MKLOneClass) CMKLOneClass;
%rename(MKLMulticlass) CMKLMulticlass;
%rename(VowpalWabbit) CVowpalWabbit;
%rename(ConjugateIndex) CConjugateIndex;
#ifdef USE_SVMLIGHT
Expand All @@ -60,8 +58,13 @@
%rename(DomainAdaptationSVMLinear) CDomainAdaptationSVMLinear;
#endif //USE_SVMLIGHT

%rename(MulticlassStrategy) CMulticlassStrategy;
%rename(MulticlassOneVsRestStrategy) CMulticlassOneVsRestStrategy;
%rename(MulticlassOneVsOneStrategy) CMulticlassOneVsOneStrategy;
%rename(KernelMulticlassMachine) CKernelMulticlassMachine;
%rename(LinearMulticlassMachine) CLinearMulticlassMachine;
%rename(MulticlassSVM) CMulticlassSVM;
%rename(MKLMulticlass) CMKLMulticlass;

/* These functions return new Objects */
%newobject apply();
Expand All @@ -75,9 +78,6 @@
%include <shogun/machine/KernelMachine.h>
%include <shogun/machine/DistanceMachine.h>
%include <shogun/classifier/svm/SVM.h>
%include <shogun/machine/MulticlassMachine.h>
%include <shogun/machine/KernelMulticlassMachine.h>
%include <shogun/multiclass/MulticlassSVM.h>
%include <shogun/machine/LinearMachine.h>
%include <shogun/machine/OnlineLinearMachine.h>
%include <shogun/classifier/GaussianNaiveBayes.h>
Expand Down Expand Up @@ -107,13 +107,20 @@
%include <shogun/classifier/mkl/MKL.h>
%include <shogun/classifier/mkl/MKLClassification.h>
%include <shogun/classifier/mkl/MKLOneClass.h>
%include <shogun/classifier/mkl/MKLMulticlass.h>
%include <shogun/classifier/vw/VowpalWabbit.h>
%include <shogun/classifier/svm/DomainAdaptationSVMLinear.h>
%include <shogun/classifier/ConjugateIndex.h>
%include <shogun/machine/LinearMulticlassMachine.h>
%include <shogun/classifier/svm/NewtonSVM.h>

%include <shogun/multiclass/MulticlassStrategy.h>
%include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
%include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
%include <shogun/machine/MulticlassMachine.h>
%include <shogun/machine/LinearMulticlassMachine.h>
%include <shogun/machine/KernelMulticlassMachine.h>
%include <shogun/multiclass/MulticlassSVM.h>
%include <shogun/classifier/mkl/MKLMulticlass.h>

#ifdef USE_SVMLIGHT

%ignore VERSION;
Expand Down
10 changes: 8 additions & 2 deletions src/interfaces/modular/Classifier_includes.i
Expand Up @@ -32,13 +32,19 @@
#include <shogun/classifier/mkl/MKL.h>
#include <shogun/classifier/mkl/MKLClassification.h>
#include <shogun/classifier/mkl/MKLOneClass.h>
#include <shogun/classifier/mkl/MKLMulticlass.h>
#include <shogun/classifier/vw/VowpalWabbit.h>
#include <shogun/classifier/ConjugateIndex.h>
#include <shogun/classifier/svm/NewtonSVM.h>

#include <shogun/multiclass/MulticlassStrategy.h>
#include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
#include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
#include <shogun/machine/MulticlassMachine.h>
#include <shogun/machine/KernelMulticlassMachine.h>
#include <shogun/machine/LinearMulticlassMachine.h>
#include <shogun/classifier/svm/NewtonSVM.h>
#include <shogun/multiclass/MulticlassSVM.h>
#include <shogun/classifier/mkl/MKLMulticlass.h>

#ifdef USE_SVMLIGHT
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/classifier/svm/SVMLightOneClass.h>
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -20,6 +20,7 @@ CMulticlassMachine::CMulticlassMachine()
// Default constructor: the strategy member is set to a newly allocated
// CMulticlassOneVsRestStrategy, m_machine starts out NULL, and m_machines
// is a freshly allocated CDynamicObjectArray.
: CMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
m_machine(NULL), m_machines(new CDynamicObjectArray())
{
// Applies SG_REF to the strategy allocated above; presumably this takes a
// counted reference that is released elsewhere (e.g. in the destructor,
// which is not visible in this hunk) -- TODO confirm.
SG_REF(m_multiclass_strategy);
// register_parameters() is defined outside this hunk.
register_parameters();
}

Expand All @@ -29,6 +30,7 @@ CMulticlassMachine::CMulticlassMachine(
: CMachine(), m_multiclass_strategy(strategy),
m_machines(new CDynamicObjectArray())
{
SG_REF(strategy);
set_labels(labs);
SG_REF(machine);
m_machine = machine;
Expand Down Expand Up @@ -105,6 +107,8 @@ CLabels* CMulticlassMachine::apply()

bool CMulticlassMachine::train_machine(CFeatures* data)
{
ASSERT(m_multiclass_strategy);

if ( !data && !is_ready() )
SG_ERROR("Please provide training data.\n");
else
Expand Down

0 comments on commit eeaa056

Please sign in to comment.