Merge pull request #458 from pluskid/multiclass-refactor2
Refactoring of multiclass
Soeren Sonnenburg committed Apr 20, 2012
2 parents 39c82e6 + 681c0d5 commit 48b3ffd
Showing 32 changed files with 906 additions and 1,074 deletions.
@@ -8,7 +8,6 @@
parameter_list = [[traindat,testdat,label_traindat,0.9,1,2000],[traindat,testdat,label_traindat,3,1,5000]]

def classifier_larank_modular (fm_train_real=traindat,fm_test_real=testdat,label_train_multiclass=label_traindat,C=0.9,num_threads=1,num_iter=5):

from shogun.Features import RealFeatures, Labels
from shogun.Kernel import GaussianKernel
from shogun.Classifier import LaRank
@@ -36,5 +35,4 @@ def classifier_larank_modular (fm_train_real=traindat,fm_test_real=testdat,label

if __name__=='__main__':
print('LaRank')
classifier_larank_modular(*parameter_list[0])

[predictions, svm, labels] = classifier_larank_modular(*parameter_list[0])
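The change above makes the __main__ driver keep the example's return values instead of discarding them. A minimal sketch of how a caller might consume the returned triple, assuming the example file is importable as classifier_larank_modular and that the returned objects are the trained LaRank machine and shogun Labels objects as the names suggest (the function body is not part of this hunk):

# Hypothetical driver, not part of the commit: run the refactored example
# and inspect what it now returns.
from classifier_larank_modular import classifier_larank_modular, parameter_list

predictions, svm, labels = classifier_larank_modular(*parameter_list[0])
print(svm.get_name())   # get_name() is available on shogun SGObjects
print(predictions)      # multiclass predictions on the test features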
16 changes: 8 additions & 8 deletions src/interfaces/modular/Classifier.i
@@ -28,14 +28,14 @@
%rename(ScatterSVM) CScatterSVM;
%rename(LibSVM) CLibSVM;
%rename(LaRank) CLaRank;
%rename(LibSVMMultiClass) CLibSVMMultiClass;
%rename(LibSVMMultiClass) CLibSVMMulticlass;
%rename(LibSVMOneClass) CLibSVMOneClass;
%rename(LinearMachine) CLinearMachine;
%rename(OnlineLinearMachine) COnlineLinearMachine;
%rename(LPBoost) CLPBoost;
%rename(LPM) CLPM;
%rename(MPDSVM) CMPDSVM;
%rename(MultiClassSVM) CMultiClassSVM;
%rename(MulticlassSVM) CMulticlassSVM;
%rename(OnlineSVMSGD) COnlineSVMSGD;
%rename(OnlineLibLinear) COnlineLibLinear;
%rename(Perceptron) CPerceptron;
@@ -55,7 +55,7 @@
%rename(MKL) CMKL;
%rename(MKLClassification) CMKLClassification;
%rename(MKLOneClass) CMKLOneClass;
%rename(MKLMultiClass) CMKLMultiClass;
%rename(MKLMulticlass) CMKLMulticlass;
%rename(VowpalWabbit) CVowpalWabbit;
%rename(ConjugateIndex) CConjugateIndex;
#ifdef USE_SVMLIGHT
@@ -79,7 +79,9 @@
%include <shogun/machine/KernelMachine.h>
%include <shogun/machine/DistanceMachine.h>
%include <shogun/classifier/svm/SVM.h>
%include <shogun/classifier/svm/MultiClassSVM.h>
%include <shogun/machine/MulticlassMachine.h>
%include <shogun/machine/KernelMulticlassMachine.h>
%include <shogun/multiclass/MulticlassSVM.h>
%include <shogun/machine/LinearMachine.h>
%include <shogun/machine/OnlineLinearMachine.h>
%include <shogun/classifier/GaussianNaiveBayes.h>
@@ -93,7 +95,7 @@
%include <shogun/classifier/svm/ScatterSVM.h>
%include <shogun/classifier/svm/LibSVM.h>
%include <shogun/classifier/svm/LaRank.h>
%include <shogun/classifier/svm/LibSVMMultiClass.h>
%include <shogun/classifier/svm/LibSVMMulticlass.h>
%include <shogun/classifier/svm/LibSVMOneClass.h>
%include <shogun/classifier/LPBoost.h>
%include <shogun/classifier/LPM.h>
@@ -113,12 +115,10 @@
%include <shogun/classifier/mkl/MKL.h>
%include <shogun/classifier/mkl/MKLClassification.h>
%include <shogun/classifier/mkl/MKLOneClass.h>
%include <shogun/classifier/mkl/MKLMultiClass.h>
%include <shogun/classifier/mkl/MKLMulticlass.h>
%include <shogun/classifier/vw/VowpalWabbit.h>
%include <shogun/classifier/svm/DomainAdaptationSVMLinear.h>
%include <shogun/classifier/ConjugateIndex.h>
%include <shogun/machine/MulticlassMachine.h>
%include <shogun/machine/KernelMulticlassMachine.h>
%include <shogun/machine/LinearMulticlassMachine.h>
%include <shogun/classifier/svm/NewtonSVM.h>

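The %rename directives in Classifier.i determine the class names seen from the modular bindings, so this hunk changes the user-facing spelling from MultiClass to Multiclass. A minimal import sketch, assuming the python_modular interface is built; before this commit the same classes were imported as LibSVMMultiClass, MultiClassSVM and MKLMultiClass:

# Names exposed by the %rename lines above (post-rename spelling).
from shogun.Classifier import LibSVMMulticlass, MulticlassSVM, MKLMulticlass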
6 changes: 3 additions & 3 deletions src/interfaces/modular/Classifier_includes.i
@@ -11,14 +11,14 @@
#include <shogun/classifier/svm/ScatterSVM.h>
#include <shogun/classifier/svm/LibSVM.h>
#include <shogun/classifier/svm/LaRank.h>
#include <shogun/classifier/svm/LibSVMMultiClass.h>
#include <shogun/classifier/svm/LibSVMMulticlass.h>
#include <shogun/classifier/svm/LibSVMOneClass.h>
#include <shogun/machine/LinearMachine.h>
#include <shogun/machine/OnlineLinearMachine.h>
#include <shogun/classifier/LPBoost.h>
#include <shogun/classifier/LPM.h>
#include <shogun/classifier/svm/MPDSVM.h>
#include <shogun/classifier/svm/MultiClassSVM.h>
#include <shogun/multiclass/MulticlassSVM.h>
#include <shogun/classifier/svm/OnlineSVMSGD.h>
#include <shogun/classifier/svm/OnlineLibLinear.h>
#include <shogun/classifier/Perceptron.h>
@@ -37,7 +37,7 @@
#include <shogun/classifier/mkl/MKL.h>
#include <shogun/classifier/mkl/MKLClassification.h>
#include <shogun/classifier/mkl/MKLOneClass.h>
#include <shogun/classifier/mkl/MKLMultiClass.h>
#include <shogun/classifier/mkl/MKLMulticlass.h>
#include <shogun/classifier/vw/VowpalWabbit.h>
#include <shogun/classifier/ConjugateIndex.h>
#include <shogun/machine/MulticlassMachine.h>
@@ -8,14 +8,14 @@
* Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
*/

#include <shogun/classifier/mkl/MKLMultiClass.h>
#include <shogun/classifier/mkl/MKLMulticlass.h>
#include <shogun/io/SGIO.h>

using namespace shogun;


CMKLMultiClass::CMKLMultiClass()
: CMultiClassSVM(ONE_VS_REST)
CMKLMulticlass::CMKLMulticlass()
: CMulticlassSVM(ONE_VS_REST_STRATEGY)
{
svm=NULL;
lpw=NULL;
@@ -25,8 +25,8 @@ CMKLMultiClass::CMKLMultiClass()
pnorm=1;
}

CMKLMultiClass::CMKLMultiClass(float64_t C, CKernel* k, CLabels* lab)
: CMultiClassSVM(ONE_VS_REST, C, k, lab)
CMKLMulticlass::CMKLMulticlass(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(ONE_VS_REST_STRATEGY, C, k, lab)
{
svm=NULL;
lpw=NULL;
@@ -37,71 +37,71 @@ CMKLMultiClass::CMKLMultiClass(float64_t C, CKernel* k, CLabels* lab)
}


CMKLMultiClass::~CMKLMultiClass()
CMKLMulticlass::~CMKLMulticlass()
{
SG_UNREF(svm);
svm=NULL;
delete lpw;
lpw=NULL;
}

CMKLMultiClass::CMKLMultiClass( const CMKLMultiClass & cm)
: CMultiClassSVM(ONE_VS_REST)
CMKLMulticlass::CMKLMulticlass( const CMKLMulticlass & cm)
: CMulticlassSVM(ONE_VS_REST_STRATEGY)
{
svm=NULL;
lpw=NULL;
SG_ERROR(
" CMKLMultiClass::CMKLMultiClass(const CMKLMultiClass & cm): must "
" CMKLMulticlass::CMKLMulticlass(const CMKLMulticlass & cm): must "
"not be called, glpk structure is currently not copyable");
}

CMKLMultiClass CMKLMultiClass::operator=( const CMKLMultiClass & cm)
CMKLMulticlass CMKLMulticlass::operator=( const CMKLMulticlass & cm)
{
SG_ERROR(
" CMKLMultiClass CMKLMultiClass::operator=(...): must "
" CMKLMulticlass CMKLMulticlass::operator=(...): must "
"not be called, glpk structure is currently not copyable");
return (*this);
}


void CMKLMultiClass::initsvm()
void CMKLMulticlass::initsvm()
{
if (!m_labels)
{
SG_ERROR("CMKLMultiClass::initsvm(): the set labels is NULL\n");
SG_ERROR("CMKLMulticlass::initsvm(): the set labels is NULL\n");
}

SG_UNREF(svm);
svm=new CGMNPSVM;
SG_REF(svm);

svm->set_C(get_C1(),get_C2());
svm->set_epsilon(epsilon);
svm->set_epsilon(get_epsilon());

if (m_labels->get_num_labels()<=0)
{
SG_ERROR("CMKLMultiClass::initsvm(): the number of labels is "
SG_ERROR("CMKLMulticlass::initsvm(): the number of labels is "
"nonpositive, do not know how to handle this!\n");
}

svm->set_labels(m_labels);
}

void CMKLMultiClass::initlpsolver()
void CMKLMulticlass::initlpsolver()
{
if (!kernel)
if (!m_kernel)
{
SG_ERROR("CMKLMultiClass::initlpsolver(): the set kernel is NULL\n");
SG_ERROR("CMKLMulticlass::initlpsolver(): the set kernel is NULL\n");
}

if (kernel->get_kernel_type()!=K_COMBINED)
if (m_kernel->get_kernel_type()!=K_COMBINED)
{
SG_ERROR("CMKLMultiClass::initlpsolver(): given kernel is not of type"
SG_ERROR("CMKLMulticlass::initlpsolver(): given kernel is not of type"
" K_COMBINED %d required by Multiclass Mkl \n",
kernel->get_kernel_type());
m_kernel->get_kernel_type());
}

int numker=dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
int numker=dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();

ASSERT(numker>0);
/*
@@ -111,22 +111,22 @@ void CMKLMultiClass::initlpsolver()
}
*/

//lpw=new MKLMultiClassGLPK;
//lpw=new MKLMulticlassGLPK;
if(pnorm>1)
{
lpw=new MKLMultiClassGradient;
lpw=new MKLMulticlassGradient;
lpw->set_mkl_norm(pnorm);
}
else
{
lpw=new MKLMultiClassGLPK;
lpw=new MKLMulticlassGLPK;
}
lpw->setup(numker);

}


bool CMKLMultiClass::evaluatefinishcriterion(const int32_t
bool CMKLMulticlass::evaluatefinishcriterion(const int32_t
numberofsilpiterations)
{
if ( (max_num_mkl_iters>0) && (numberofsilpiterations>=max_num_mkl_iters) )
@@ -202,7 +202,7 @@ bool CMKLMultiClass::evaluatefinishcriterion(const int32_t
return(false);
}

void CMKLMultiClass::addingweightsstep( const std::vector<float64_t> &
void CMKLMulticlass::addingweightsstep( const std::vector<float64_t> &
curweights)
{

@@ -215,18 +215,18 @@ void CMKLMultiClass::addingweightsstep( const std::vector<float64_t> &
weights=SG_MALLOC(float64_t, curweights.size());
std::copy(curweights.begin(),curweights.end(),weights);

kernel->set_subkernel_weights(SGVector<float64_t>(weights, curweights.size()));
m_kernel->set_subkernel_weights(SGVector<float64_t>(weights, curweights.size()));
SG_FREE(weights);
weights=NULL;

initsvm();

svm->set_kernel(kernel);
svm->set_kernel(m_kernel);
svm->train();

float64_t sumofsignfreealphas=getsumofsignfreealphas();
int32_t numkernels=
dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();


normweightssquared.resize(numkernels);
@@ -238,7 +238,7 @@ void CMKLMultiClass::addingweightsstep( const std::vector<float64_t> &
lpw->addconstraint(normweightssquared,sumofsignfreealphas);
}

float64_t CMKLMultiClass::getsumofsignfreealphas()
float64_t CMKLMulticlass::getsumofsignfreealphas()
{

std::vector<int> trainlabels2(m_labels->get_num_labels());
@@ -286,10 +286,10 @@ float64_t CMKLMultiClass::getsumofsignfreealphas()
return(sum);
}

float64_t CMKLMultiClass::getsquarenormofprimalcoefficients(
float64_t CMKLMulticlass::getsquarenormofprimalcoefficients(
const int32_t ind)
{
CKernel * ker=dynamic_cast<CCombinedKernel *>(kernel)->get_kernel(ind);
CKernel * ker=dynamic_cast<CCombinedKernel *>(m_kernel)->get_kernel(ind);

float64_t tmp=0;

@@ -322,26 +322,26 @@ float64_t CMKLMultiClass::getsquarenormofprimalcoefficients(
}


bool CMKLMultiClass::train_machine(CFeatures* data)
bool CMKLMulticlass::train_machine(CFeatures* data)
{
int numcl=m_labels->get_num_classes();
ASSERT(kernel);
ASSERT(m_kernel);
ASSERT(m_labels && m_labels->get_num_labels());

if (data)
{
if (m_labels->get_num_labels() != data->get_num_vectors())
SG_ERROR("Number of training vectors does not match number of "
"labels\n");
kernel->init(data, data);
m_kernel->init(data, data);
}

initlpsolver();

weightshistory.clear();

int32_t numkernels=
dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();

::std::vector<float64_t> curweights(numkernels,1.0/numkernels);
weightshistory.push_back(curweights);
@@ -400,7 +400,7 @@ bool CMKLMultiClass::train_machine(CFeatures* data)



float64_t* CMKLMultiClass::getsubkernelweights(int32_t & numweights)
float64_t* CMKLMulticlass::getsubkernelweights(int32_t & numweights)
{
if ( weightshistory.empty() )
{
@@ -416,19 +416,19 @@ float64_t* CMKLMultiClass::getsubkernelweights(int32_t & numweights)
return res;
}

void CMKLMultiClass::set_mkl_epsilon(float64_t eps )
void CMKLMulticlass::set_mkl_epsilon(float64_t eps )
{
mkl_eps=eps;
}

void CMKLMultiClass::set_max_num_mkliters(int32_t maxnum)
void CMKLMulticlass::set_max_num_mkliters(int32_t maxnum)
{
max_num_mkl_iters=maxnum;
}

void CMKLMultiClass::set_mkl_norm(float64_t norm)
void CMKLMulticlass::set_mkl_norm(float64_t norm)
{
pnorm=norm;
if(pnorm<1 )
SG_ERROR("CMKLMultiClass::set_mkl_norm(float64_t norm) : parameter pnorm<1");
SG_ERROR("CMKLMulticlass::set_mkl_norm(float64_t norm) : parameter pnorm<1");
}
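Taken together, the renamed CMKLMulticlass keeps its old contract: the constructor takes (C, kernel, labels), initlpsolver() rejects any kernel that is not K_COMBINED, and set_mkl_norm() selects the gradient solver for pnorm > 1 and GLPK for pnorm == 1. A minimal end-to-end sketch from the modular Python interface on synthetic data; the feature/label construction and the apply() call follow the usual pattern of the 2012-era modular examples and are assumptions, while the MKLMulticlass calls mirror the methods defined above:

# Sketch, not part of the commit: drive the renamed MKLMulticlass from Python.
import numpy
from shogun.Features import RealFeatures, CombinedFeatures, Labels
from shogun.Kernel import CombinedKernel, GaussianKernel
from shogun.Classifier import MKLMulticlass

# Synthetic 2-D data with 3 classes; examples are stored column-wise.
traindat = numpy.random.randn(2, 30)
testdat = numpy.random.randn(2, 15)
trainlab = numpy.random.randint(0, 3, 30).astype(numpy.float64)

# initlpsolver() above requires a CombinedKernel, so build one sub-kernel
# (and one matching feature block) per Gaussian width.
feats_train = CombinedFeatures()
feats_test = CombinedFeatures()
kernel = CombinedKernel()
for width in (0.5, 1.0, 2.0):
    feats_train.append_feature_obj(RealFeatures(traindat))
    feats_test.append_feature_obj(RealFeatures(testdat))
    kernel.append_kernel(GaussianKernel(10, width))
kernel.init(feats_train, feats_train)

mkl = MKLMulticlass(1.2, kernel, Labels(trainlab))  # CMKLMulticlass(C, k, lab)
mkl.set_mkl_epsilon(0.001)  # MKL stopping threshold (mkl_eps)
mkl.set_mkl_norm(1.5)       # pnorm > 1 -> MKLMulticlassGradient, == 1 -> GLPK
mkl.train()

kernel.init(feats_train, feats_test)
predictions = mkl.apply().get_labels()
print(predictions)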
