Skip to content

Commit

Permalink
Use the remembered num_classes in a MulticlassStrategy
Browse files (browse the repository at this point in the history)
  • Loading branch information
pluskid committed Apr 28, 2012
1 parent 7830726 commit 022c18d
Show file tree
Hide file tree
Showing 11 changed files with 15 additions and 19 deletions.
5 changes: 2 additions & 3 deletions src/shogun/machine/MulticlassMachine.cpp
@@ -80,7 +80,6 @@ CLabels* CMulticlassMachine::apply()

if (is_ready())
{
int32_t num_classes=m_labels->get_num_classes();
int32_t num_vectors=get_num_rhs_vectors();
int32_t num_machines=m_machines->get_num_elements();
if (num_machines <= 0)
@@ -104,7 +103,7 @@ CLabels* CMulticlassMachine::apply()
for (int32_t j=0; j<num_machines; j++)
output_for_i[j] = outputs[j]->get_label(i);

result->set_label(i, m_multiclass_strategy->decide_label(output_for_i, num_classes));
result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
}

output_for_i.destroy_vector();
@@ -176,7 +175,7 @@ float64_t CMulticlassMachine::apply(int32_t num)
SG_UNREF(machine);
}

float64_t result=m_multiclass_strategy->decide_label(outputs, m_labels->get_num_classes());
float64_t result=m_multiclass_strategy->decide_label(outputs);
outputs.destroy_vector();

return result;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/GMNPSVM.cpp
@@ -65,7 +65,7 @@ bool CGMNPSVM::train_machine(CFeatures* data)
}

int32_t num_data = m_labels->get_num_labels();
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_virtual_data= num_data*(num_classes-1);

SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassLibSVM.cpp
@@ -35,7 +35,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
problem = svm_problem();

ASSERT(m_labels && m_labels->get_num_labels());
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
problem.l=m_labels->get_num_labels();
SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassOCAS.cpp
@@ -69,7 +69,7 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
set_features((CDotFeatures*)data);

int32_t num_vectors = m_features->get_num_vectors();
int32_t num_classes = m_labels->get_num_classes();
int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_features = m_features->get_dim_feature_space();

float64_t C = m_C;
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.cpp
@@ -63,15 +63,15 @@ SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
return SGVector<int32_t>(subset.vector, tot);
}

int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs)
{
int32_t s=0;
SGVector<int32_t> votes(num_classes);
SGVector<int32_t> votes(m_num_classes);
votes.zero();

for (int32_t i=0; i<num_classes; i++)
for (int32_t i=0; i<m_num_classes; i++)
{
for (int32_t j=i+1; j<num_classes; j++)
for (int32_t j=i+1; j<m_num_classes; j++)
{
if (outputs[s++]>0)
votes[i]++;
Expand Down
3 changes: 1 addition & 2 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.h
@@ -35,9 +35,8 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
virtual int32_t decide_label(const SGVector<float64_t> &outputs);

/** get number of machines used in this strategy.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassOneVsRestStrategy.cpp
@@ -39,7 +39,7 @@ SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
return SGVector<int32_t>();
}

int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs)
{
if (m_rejection_strategy && m_rejection_strategy->reject(outputs))
return CLabels::REJECTION_LABEL;
Expand Down
3 changes: 1 addition & 2 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.h
@@ -62,9 +62,8 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
virtual int32_t decide_label(const SGVector<float64_t> &outputs);

/** get number of machines used in this strategy.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassSVM.cpp
@@ -265,7 +265,7 @@ bool CMulticlassSVM::save(FILE* modelfl)

SG_INFO( "Writing model file...");
fprintf(modelfl,"%%MultiClassSVM\n");
fprintf(modelfl,"num_classes=%d;\n", m_labels->get_num_classes());
fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());

Expand Down
3 changes: 1 addition & 2 deletions src/shogun/multiclass/MulticlassStrategy.h
@@ -61,9 +61,8 @@ class CMulticlassStrategy: public CSGObject

/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
* @param num_classes number of classes
*/
virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)=0;
virtual int32_t decide_label(const SGVector<float64_t> &outputs)=0;

/** get number of machines used in this strategy.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/ScatterSVM.cpp
@@ -48,7 +48,7 @@ CScatterSVM::~CScatterSVM()
bool CScatterSVM::train_machine(CFeatures* data)
{
ASSERT(m_labels && m_labels->get_num_labels());
m_num_classes = m_labels->get_num_classes();
m_num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_vectors = m_labels->get_num_labels();

if (data)
Expand Down

0 comments on commit 022c18d

Please sign in to comment.