Skip to content

Commit

Permalink
Store number of classes in MulticlassStrategy
Browse files (browse the repository at this point in the history)
  • Loading branch information
pluskid committed Apr 28, 2012
1 parent b4fbbdb commit 7830726
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 18 deletions.
18 changes: 17 additions & 1 deletion src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -35,6 +35,9 @@ CMulticlassMachine::CMulticlassMachine(
SG_REF(machine);
m_machine = machine;
register_parameters();

if (labs)
init_strategy();
}

CMulticlassMachine::~CMulticlassMachine()
Expand All @@ -44,13 +47,26 @@ CMulticlassMachine::~CMulticlassMachine()
SG_UNREF(m_machines);
}

/** Install new labels and keep the multiclass strategy in sync with them.
 *
 * The labels are always handed on to CMachine::set_labels(); the strategy is
 * only re-initialised when a non-NULL label object was supplied.
 *
 * @param lab labels to use (may be NULL)
 */
void CMulticlassMachine::set_labels(CLabels* lab)
{
	CMachine::set_labels(lab);

	// without labels there is no class count to hand to the strategy
	if (!lab)
		return;

	init_strategy();
}

/** Registers the three object-pointer members of this class through SG_ADD.
 *
 * Each call passes the member's address cast to CSGObject**, a name string,
 * a description string and MS_NOT_AVAILABLE.
 * NOTE(review): SG_ADD and MS_NOT_AVAILABLE are defined outside this view;
 * presumably they feed shogun's parameter framework and mark the entry as
 * not available for model selection -- confirm.
 */
void CMulticlassMachine::register_parameters()
{
// registered as "m_multiclass_type" although the member is m_multiclass_strategy
SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
// registered as "machines" (no m_ prefix), unlike the two entries above
SG_ADD((CSGObject**)&m_machines, "machines", "Machines that jointly make up the multi-class machine.", MS_NOT_AVAILABLE);
}

/** Push the class count reported by the current labels into the multiclass
 * strategy.
 *
 * Neither m_labels nor m_multiclass_strategy is checked for NULL here; both
 * are dereferenced unconditionally.
 */
void CMulticlassMachine::init_strategy()
{
	m_multiclass_strategy->set_num_classes(m_labels->get_num_classes());
}

CLabels* CMulticlassMachine::apply(CFeatures* features)
{
init_machines_for_apply(features);
Expand Down Expand Up @@ -96,7 +112,7 @@ CLabels* CMulticlassMachine::apply()
SG_UNREF(outputs[i]);

SG_FREE(outputs);

return result;
}
else
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -39,6 +39,12 @@ class CMulticlassMachine : public CMachine
/** destructor */
virtual ~CMulticlassMachine();

/** set labels
*
* Forwards the labels to CMachine::set_labels() and, when lab is non-NULL,
* calls init_strategy() afterwards.
*
* @param lab labels
*/
virtual void set_labels(CLabels* lab);

/** set machine
*
* @param num index of machine
Expand Down Expand Up @@ -110,6 +116,8 @@ class CMulticlassMachine : public CMachine
}

protected:
/** init strategy
*
* Reads the number of classes from m_labels and passes it to
* m_multiclass_strategy via set_num_classes(). Both pointers are
* dereferenced without a NULL check.
*/
void init_strategy();

/** clear machines */
void clear_machines();
Expand Down
3 changes: 1 addition & 2 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.cpp
Expand Up @@ -13,14 +13,13 @@
using namespace shogun;

CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
:CMulticlassStrategy(), m_num_machines(0), m_num_classes(0)
:CMulticlassStrategy(), m_num_machines(0)
{
}

void CMulticlassOneVsOneStrategy::train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
m_num_classes = m_orig_labels->get_num_classes();
m_num_machines=m_num_classes*(m_num_classes-1)/2;

m_train_pair_idx_1 = 0;
Expand Down
6 changes: 2 additions & 4 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.h
Expand Up @@ -40,11 +40,10 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)
virtual int32_t get_num_machines()
{
return num_classes*(num_classes-1)/2;
return m_num_classes*(m_num_classes-1)/2;
}

/** get name */
Expand All @@ -55,7 +54,6 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy

protected:
int32_t m_num_machines;
int32_t m_num_classes;
int32_t m_train_pair_idx_1;
int32_t m_train_pair_idx_2;
};
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.cpp
Expand Up @@ -13,12 +13,12 @@
using namespace shogun;

CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
:CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(NULL)
:CMulticlassStrategy(), m_rejection_strategy(NULL)
{
}

CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(CRejectionStrategy *rejection_strategy)
:CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(rejection_strategy)
:CMulticlassStrategy(), m_rejection_strategy(rejection_strategy)
{
SG_REF(m_rejection_strategy);
}
Expand Down
9 changes: 3 additions & 6 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.h
Expand Up @@ -47,13 +47,12 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
virtual void train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
m_num_machines=m_orig_labels->get_num_classes();
}

/** has more training phase */
virtual bool train_has_more()
{
return m_train_iter < m_num_machines;
return m_train_iter < m_num_classes;
}

/** prepare for the next training phase.
Expand All @@ -68,11 +67,10 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)
virtual int32_t get_num_machines()
{
return num_classes;
return m_num_classes;
}

/** get name */
Expand All @@ -82,7 +80,6 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
};

protected:
int32_t m_num_machines;
CRejectionStrategy *m_rejection_strategy;
};

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassSVM.cpp
Expand Up @@ -46,7 +46,7 @@ bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
{
if (num_classes>0)
{
int32_t num_svms=m_multiclass_strategy->get_num_machines(num_classes);
int32_t num_svms=m_multiclass_strategy->get_num_machines();

m_machines->clear_array();
for (index_t i=0; i<num_svms; ++i)
Expand Down
16 changes: 14 additions & 2 deletions src/shogun/multiclass/MulticlassStrategy.h
Expand Up @@ -33,6 +33,18 @@ class CMulticlassStrategy: public CSGObject
return "MulticlassStrategy";
};

/** set number of classes
*
* @param num_classes value stored in m_num_classes; no validation is performed
*/
void set_num_classes(int32_t num_classes)
{
m_num_classes = num_classes;
}

/** get number of classes
*
* @return the current value of m_num_classes
*/
int32_t get_num_classes() const
{
return m_num_classes;
}

/** start training */
virtual void train_start(CLabels *orig_labels, CLabels *train_labels);

Expand All @@ -54,14 +66,14 @@ class CMulticlassStrategy: public CSGObject
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)=0;

/** get number of machines used in this strategy.
* @param num_classes number of classes in this problem
*/
virtual int32_t get_num_machines(int32_t num_classes)=0;
virtual int32_t get_num_machines()=0;

protected:
CLabels *m_train_labels;
CLabels *m_orig_labels;
int32_t m_train_iter;
int32_t m_num_classes;
};

} // namespace shogun
Expand Down

0 comments on commit 7830726

Please sign in to comment.