Skip to content

Commit

Permalink
Merge branch 'multiclass-ecoc' of git://github.com/pluskid/shogun
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Apr 28, 2012
2 parents 1927645 + 1981a90 commit 3be395f
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 72 deletions.
23 changes: 19 additions & 4 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -35,6 +35,9 @@ CMulticlassMachine::CMulticlassMachine(
SG_REF(machine);
m_machine = machine;
register_parameters();

if (labs)
init_strategy();
}

CMulticlassMachine::~CMulticlassMachine()
Expand All @@ -44,13 +47,26 @@ CMulticlassMachine::~CMulticlassMachine()
SG_UNREF(m_machines);
}

/** Store new labels on the machine and, when they are non-NULL,
 * re-synchronise the multiclass strategy's class count with them.
 *
 * @param lab labels to use (may be NULL; the strategy is then left untouched)
 */
void CMulticlassMachine::set_labels(CLabels* lab)
{
	CMachine::set_labels(lab);

	// Without labels there is no class count to propagate.
	if (lab == NULL)
		return;

	init_strategy();
}

/** Register this machine's members with the parameter framework (SG_ADD):
 * the multiclass strategy, the base machine, and the array of sub-machines
 * that jointly make up the multiclass machine.
 */
void CMulticlassMachine::register_parameters()
{
// NOTE(review): the strategy member is registered under the name
// "m_multiclass_type" rather than its member name; presumably kept for
// compatibility with previously serialized objects -- confirm before renaming.
// All three entries use MS_NOT_AVAILABLE (presumably: not exposed to model
// selection -- verify against the SG_ADD definition).
SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**)&m_machines, "machines", "Machines that jointly make up the multi-class machine.", MS_NOT_AVAILABLE);
}

void CMulticlassMachine::init_strategy()
{
int32_t num_classes = m_labels->get_num_classes();
m_multiclass_strategy->set_num_classes(num_classes);
}

CLabels* CMulticlassMachine::apply(CFeatures* features)
{
init_machines_for_apply(features);
Expand All @@ -64,7 +80,6 @@ CLabels* CMulticlassMachine::apply()

if (is_ready())
{
int32_t num_classes=m_labels->get_num_classes();
int32_t num_vectors=get_num_rhs_vectors();
int32_t num_machines=m_machines->get_num_elements();
if (num_machines <= 0)
Expand All @@ -88,15 +103,15 @@ CLabels* CMulticlassMachine::apply()
for (int32_t j=0; j<num_machines; j++)
output_for_i[j] = outputs[j]->get_label(i);

result->set_label(i, m_multiclass_strategy->decide_label(output_for_i, num_classes));
result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
}

output_for_i.destroy_vector();
for (int32_t i=0; i < num_machines; ++i)
SG_UNREF(outputs[i]);

SG_FREE(outputs);

return result;
}
else
Expand Down Expand Up @@ -160,7 +175,7 @@ float64_t CMulticlassMachine::apply(int32_t num)
SG_UNREF(machine);
}

float64_t result=m_multiclass_strategy->decide_label(outputs, m_labels->get_num_classes());
float64_t result=m_multiclass_strategy->decide_label(outputs);
outputs.destroy_vector();

return result;
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -39,6 +39,12 @@ class CMulticlassMachine : public CMachine
/** destructor */
virtual ~CMulticlassMachine();

/** set labels
*
* @param lab labels
*/
virtual void set_labels(CLabels* lab);

/** set machine
*
* @param num index of machine
Expand Down Expand Up @@ -110,6 +116,8 @@ class CMulticlassMachine : public CMachine
}

protected:
/** init strategy */
void init_strategy();

/** clear machines */
void clear_machines();
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/GMNPSVM.cpp
Expand Up @@ -65,7 +65,7 @@ bool CGMNPSVM::train_machine(CFeatures* data)
}

int32_t num_data = m_labels->get_num_labels();
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_virtual_data= num_data*(num_classes-1);

SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassLibSVM.cpp
Expand Up @@ -35,7 +35,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
problem = svm_problem();

ASSERT(m_labels && m_labels->get_num_labels());
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
problem.l=m_labels->get_num_labels();
SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassOCAS.cpp
Expand Up @@ -69,7 +69,7 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
set_features((CDotFeatures*)data);

int32_t num_vectors = m_features->get_num_vectors();
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_features = m_features->get_dim_feature_space();

float64_t C = m_C;
Expand Down
11 changes: 5 additions & 6 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.cpp
Expand Up @@ -13,14 +13,13 @@
using namespace shogun;

CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
:CMulticlassStrategy(), m_num_machines(0), m_num_classes(0)
:CMulticlassStrategy(), m_num_machines(0)
{
}

void CMulticlassOneVsOneStrategy::train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
m_num_classes = m_orig_labels->get_num_classes();
m_num_machines=m_num_classes*(m_num_classes-1)/2;

m_train_pair_idx_1 = 0;
Expand Down Expand Up @@ -64,15 +63,15 @@ SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
return SGVector<int32_t>(subset.vector, tot);
}

int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs)
{
int32_t s=0;
SGVector<int32_t> votes(num_classes);
SGVector<int32_t> votes(m_num_classes);
votes.zero();

for (int32_t i=0; i<num_classes; i++)
for (int32_t i=0; i<m_num_classes; i++)
{
for (int32_t j=i+1; j<num_classes; j++)
for (int32_t j=i+1; j<m_num_classes; j++)
{
if (outputs[s++]>0)
votes[i]++;
Expand Down
17 changes: 4 additions & 13 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.h
Expand Up @@ -12,7 +12,7 @@

namespace shogun
{

class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
{
public:
Expand All @@ -35,22 +35,14 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
virtual int32_t decide_label(const SGVector<float64_t> &outputs);

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)
{
return num_classes*(num_classes-1)/2;
}

/** get strategy type */
virtual EMulticlassStrategy get_strategy_type()
virtual int32_t get_num_machines()
{
return ONE_VS_ONE_STRATEGY;
return m_num_classes*(m_num_classes-1)/2;
}

/** get name */
Expand All @@ -61,7 +53,6 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy

protected:
int32_t m_num_machines;
int32_t m_num_classes;
int32_t m_train_pair_idx_1;
int32_t m_train_pair_idx_2;
};
Expand Down
6 changes: 3 additions & 3 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.cpp
Expand Up @@ -13,12 +13,12 @@
using namespace shogun;

CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
:CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(NULL)
:CMulticlassStrategy(), m_rejection_strategy(NULL)
{
}

CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(CRejectionStrategy *rejection_strategy)
:CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(rejection_strategy)
:CMulticlassStrategy(), m_rejection_strategy(rejection_strategy)
{
SG_REF(m_rejection_strategy);
}
Expand All @@ -39,7 +39,7 @@ SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
return SGVector<int32_t>();
}

int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs)
{
if (m_rejection_strategy && m_rejection_strategy->reject(outputs))
return CLabels::REJECTION_LABEL;
Expand Down
20 changes: 5 additions & 15 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.h
Expand Up @@ -47,38 +47,29 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
virtual void train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
m_num_machines=m_orig_labels->get_num_classes();
}

/** has more training phase */
virtual bool train_has_more()
{
return m_train_iter < m_num_machines;
return m_train_iter < m_num_classes;
}

/** prepare for the next training phase.
* @return NULL, since no subset is needed in one-vs-rest strategy
*/
*/
virtual SGVector<int32_t> train_prepare_next();

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
virtual int32_t decide_label(const SGVector<float64_t> &outputs);

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)
{
return num_classes;
}

/** get strategy type */
virtual EMulticlassStrategy get_strategy_type()
virtual int32_t get_num_machines()
{
return ONE_VS_REST_STRATEGY;
return m_num_classes;
}

/** get name */
Expand All @@ -88,7 +79,6 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
};

protected:
int32_t m_num_machines;
CRejectionStrategy *m_rejection_strategy;
};

Expand Down
15 changes: 2 additions & 13 deletions src/shogun/multiclass/MulticlassSVM.cpp
Expand Up @@ -46,7 +46,7 @@ bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
{
if (num_classes>0)
{
int32_t num_svms=m_multiclass_strategy->get_num_machines(num_classes);
int32_t num_svms=m_multiclass_strategy->get_num_machines();

m_machines->clear_array();
for (index_t i=0; i<num_svms; ++i)
Expand Down Expand Up @@ -127,16 +127,6 @@ bool CMulticlassSVM::load(FILE* modelfl)
line_number++;
}

int_buffer=0;
if (fscanf(modelfl," multiclass_strategy=%d; \n", &int_buffer) != 1)
SG_ERROR( "error in svm file, line nr:%d\n", line_number);

if (!feof(modelfl))
line_number++;

if (int_buffer != m_multiclass_strategy->get_strategy_type())
SG_ERROR("multiclass strategy does not match %ld vs. %ld\n", int_buffer, m_multiclass_strategy->get_strategy_type());

int_buffer=0;
if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
SG_ERROR( "error in svm file, line nr:%d\n", line_number);
Expand Down Expand Up @@ -275,8 +265,7 @@ bool CMulticlassSVM::save(FILE* modelfl)

SG_INFO( "Writing model file...");
fprintf(modelfl,"%%MultiClassSVM\n");
fprintf(modelfl,"multiclass_strategy=%d;\n", m_multiclass_strategy->get_strategy_type());
fprintf(modelfl,"num_classes=%d;\n", m_labels->get_num_classes());
fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());

Expand Down
1 change: 1 addition & 0 deletions src/shogun/multiclass/MulticlassStrategy.cpp
Expand Up @@ -17,6 +17,7 @@ using namespace shogun;
/** Construct a strategy with no labels attached, the training iterator at
 * zero and a class count of zero.
 *
 * m_num_classes must get a defined value here: get_num_classes() returns it
 * unchecked, and nothing else sets it before set_num_classes() is called.
 * Leaving it uninitialised would make any earlier read indeterminate, and
 * SG_ADD would register a garbage value. It is listed last so the initialiser
 * order matches the member declaration order.
 */
CMulticlassStrategy::CMulticlassStrategy()
	:m_train_labels(NULL), m_orig_labels(NULL), m_train_iter(0), m_num_classes(0)
{
	SG_ADD(&m_num_classes, "num_classes", "Number of classes", MS_NOT_AVAILABLE);
}

void CMulticlassStrategy::train_start(CLabels *orig_labels, CLabels *train_labels)
Expand Down
28 changes: 14 additions & 14 deletions src/shogun/multiclass/MulticlassStrategy.h
Expand Up @@ -18,14 +18,6 @@
namespace shogun
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
enum EMulticlassStrategy
{
ONE_VS_REST_STRATEGY,
ONE_VS_ONE_STRATEGY,
};
#endif

class CMulticlassStrategy: public CSGObject
{
public:
Expand All @@ -41,8 +33,17 @@ class CMulticlassStrategy: public CSGObject
return "MulticlassStrategy";
};

/** get strategy type */
virtual EMulticlassStrategy get_strategy_type()=0;
/** set the number of classes this strategy operates on
 *
 * @param num_classes number of classes in the problem
 */
void set_num_classes(int32_t num_classes) { m_num_classes=num_classes; }

/** get the number of classes this strategy operates on
 *
 * @return the stored number of classes
 */
int32_t get_num_classes() const { return m_num_classes; }

/** start training */
virtual void train_start(CLabels *orig_labels, CLabels *train_labels);
Expand All @@ -60,19 +61,18 @@ class CMulticlassStrategy: public CSGObject

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)=0;
virtual int32_t decide_label(const SGVector<float64_t> &outputs)=0;

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)=0;
virtual int32_t get_num_machines()=0;

protected:
CLabels *m_train_labels;
CLabels *m_orig_labels;
int32_t m_train_iter;
int32_t m_num_classes;
};

} // namespace shogun
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/ScatterSVM.cpp
Expand Up @@ -48,7 +48,7 @@ CScatterSVM::~CScatterSVM()
bool CScatterSVM::train_machine(CFeatures* data)
{
ASSERT(m_labels && m_labels->get_num_labels());
m_num_classes = m_labels->get_num_classes();
m_num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_vectors = m_labels->get_num_labels();

if (data)
Expand Down

0 comments on commit 3be395f

Please sign in to comment.