Skip to content

Commit

Permalink
Some fixes for multiclass machine
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed May 22, 2012
1 parent d59c2b2 commit 31c9307
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
12 changes: 6 additions & 6 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -71,11 +71,11 @@ void CMulticlassMachine::init_strategy()
m_multiclass_strategy->set_num_classes(num_classes);
}

// Diff rendering of CMulticlassMachine::get_submachine_outputs. This commit
// makes 6 additions and 6 deletions in MulticlassMachine.cpp. The capture has
// lost the +/- markers. Following GitHub's unified-diff order, in each changed
// pair below the first line is the pre-commit (removed) version and the
// second is the post-commit (added) version.
//
// Purpose: run the i-th submachine and return its outputs.
// The commit changes the return type from CRegressionLabels* to
// CBinaryLabels*. It also replaces an unchecked C-style cast of apply() with
// the typed apply_binary() call.
//
// @param i index of the submachine in m_machines
// @return labels produced by the submachine. Callers appear to receive a
//         reference they must release with SG_UNREF (the domain-adaptation
//         override does so) -- confirm against the refcounting convention.
// removed:
CRegressionLabels* CMulticlassMachine::get_submachine_outputs(int32_t i)
// added:
CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i)
{
// get_element presumably returns a referenced pointer; it is released below with SG_UNREF
CMachine *machine = (CMachine*)m_machines->get_element(i);
ASSERT(machine);
// removed: untyped apply() result force-cast to regression labels
CRegressionLabels* output = (CRegressionLabels*)machine->apply();
// added: apply_binary() already returns CBinaryLabels*, so no cast is needed
CBinaryLabels* output = machine->apply_binary();
SG_UNREF(machine);
return output;
}
Expand Down Expand Up @@ -108,16 +108,16 @@ CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
SG_ERROR("num_machines = %d, did you train your machine?", num_machines);

CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
CRegressionLabels** outputs=SG_MALLOC(CRegressionLabels*, num_machines);
CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);

for (int32_t i=0; i < num_machines; ++i)
outputs[i] = (CRegressionLabels*) get_submachine_outputs(i);
outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);

SGVector<float64_t> output_for_i(num_machines);
for (int32_t i=0; i<num_vectors; i++)
{
for (int32_t j=0; j<num_machines; j++)
output_for_i[j] = outputs[j]->get_label(i);
output_for_i[j] = outputs[j]->get_confidence(i);

result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
}
Expand Down Expand Up @@ -146,7 +146,7 @@ bool CMulticlassMachine::train_machine(CFeatures* data)
init_machine_for_train(data);

m_machines->clear_array();
CMulticlassLabels* train_labels = new CMulticlassLabels(get_num_rhs_vectors());
CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
SG_REF(train_labels);
m_machine->set_labels(train_labels);

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/machine/MulticlassMachine.h
Expand Up @@ -85,7 +85,7 @@ class CMulticlassMachine : public CMachine
* @param i number of submachine
* @return outputs
*/
virtual CRegressionLabels* get_submachine_outputs(int32_t i);
virtual CBinaryLabels* get_submachine_outputs(int32_t i);

/** get output of i-th submachine for num-th vector
* @param i number of submachine
Expand Down
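10 changes: 5 additions & 5 deletions (file header missing from this capture; the hunk below belongs to the CDomainAdaptationMulticlassLibLinear implementation file)
Expand Up @@ -105,19 +105,19 @@ SGMatrix<float64_t> CDomainAdaptationMulticlassLibLinear::obtain_regularizer_mat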
return w0;
}

// Diff rendering of CDomainAdaptationMulticlassLibLinear::get_submachine_outputs
// (the +/- markers were lost). In each changed pair or group below, the
// removed (pre-commit) line(s) come first and the added (post-commit) line(s)
// follow.
//
// Purpose: blend the i-th submachine outputs of this (target) machine with
// those of m_source_machine:
//     result[j] = (1 - m_source_bias) * target[j] + m_source_bias * source[j]
// Both output sets must contain the same number of entries (ASSERTed).
// The commit switches the types from CRegressionLabels to CBinaryLabels. It
// also reads get_confidence() instead of get_label(). Presumably this is
// because binary labels hold the thresholded decision, while the confidence
// holds the real-valued score -- confirm against the CBinaryLabels docs.
//
// @param i index of the submachine to combine
// @return newly allocated labels holding the blended values; both
//         intermediate outputs are released with SG_UNREF before returning
// removed:
CRegressionLabels* CDomainAdaptationMulticlassLibLinear::get_submachine_outputs(int32_t i)
// added:
CBinaryLabels* CDomainAdaptationMulticlassLibLinear::get_submachine_outputs(int32_t i)
{
// removed (two lines):
CRegressionLabels* target_outputs = CMulticlassMachine::get_submachine_outputs(i);
CRegressionLabels* source_outputs = m_source_machine->get_submachine_outputs(i);
// added (two lines):
CBinaryLabels* target_outputs = CMulticlassMachine::get_submachine_outputs(i);
CBinaryLabels* source_outputs = m_source_machine->get_submachine_outputs(i);
int32_t n_target_outputs = target_outputs->get_num_labels();
// target and source machines must produce outputs for the same vectors
ASSERT(n_target_outputs==source_outputs->get_num_labels());
SGVector<float64_t> result(n_target_outputs);
// In the real (post-commit) file, only the added line below is the loop body.
for (int32_t j=0; j<result.vlen; j++)
// removed: blended the label values
result[j] = (1-m_source_bias)*target_outputs->get_label(j) + m_source_bias*source_outputs->get_label(j);
// added: blends the confidence values
result[j] = (1-m_source_bias)*target_outputs->get_confidence(j) + m_source_bias*source_outputs->get_confidence(j);

SG_UNREF(target_outputs);
SG_UNREF(source_outputs);

// removed:
return new CRegressionLabels(result);
// added:
// NOTE(review): result holds real-valued blended scores, and apply_multiclass
// reads them back with get_confidence(). Confirm that the
// CBinaryLabels(SGVector<float64_t>) constructor stores these values as
// confidences rather than only as thresholded +/-1 labels.
return new CBinaryLabels(result);
}
#endif /* HAVE_LAPACK */
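2 changes: 1 addition & 1 deletion (file header missing from this capture; the hunk below belongs to the CDomainAdaptationMulticlassLibLinear class declaration header)
Expand Up @@ -40,7 +40,7 @@ class CDomainAdaptationMulticlassLibLinear : public CMulticlassLibLinear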
virtual ~CDomainAdaptationMulticlassLibLinear();

/** get submachine outputs */
virtual CRegressionLabels* get_submachine_outputs(int32_t);
virtual CBinaryLabels* get_submachine_outputs(int32_t);

/** get name */
virtual const char* get_name() const
Expand Down

0 comments on commit 31c9307

Please sign in to comment.