Skip to content

Commit

Permalink
Move apply(int) from MulticlassSVM to MulticlassMachine
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Apr 21, 2012
1 parent 17ff86c commit bde47e7
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 74 deletions.
47 changes: 46 additions & 1 deletion src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -334,8 +334,53 @@ int32_t CMulticlassMachine::maxvote_one_vs_one(const SGVector<float64_t> &predic
return winner;
}


/** Classify a single example.
 *
 * Prepares the sub-machines via init_machines_for_apply() and then
 * dispatches to the classifier matching m_multiclass_strategy.
 *
 * @param num index of the example to classify
 * @return result of classify_example_one_vs_rest() or
 *         classify_example_one_vs_one(), depending on the strategy;
 *         0 if the strategy is unknown and SG_ERROR returns control
 */
float64_t CMulticlassMachine::apply(int32_t num)
{
	// NOTE(review): NULL is passed as the feature argument and the return
	// value is ignored -- presumably NULL means "keep the features that are
	// already assigned"; confirm against the subclass overrides.
	init_machines_for_apply(NULL);

	if (m_multiclass_strategy==ONE_VS_REST_STRATEGY)
		return classify_example_one_vs_rest(num);
	else if (m_multiclass_strategy==ONE_VS_ONE_STRATEGY)
		return classify_example_one_vs_one(num);
	else
		SG_ERROR("unknown multiclass strategy\n");

	// Only reached when the strategy is unknown and SG_ERROR returned.
	return 0;
}

/** Classify a single example with the one-vs-rest strategy.
 *
 * Applies every sub-machine to the example, collects the outputs and lets
 * maxvote_one_vs_rest() pick the result.
 *
 * @param num index of the example to classify
 * @return value produced by maxvote_one_vs_rest() for the collected outputs
 */
float64_t CMulticlassMachine::classify_example_one_vs_rest(int32_t num)
{
	ASSERT(m_machines->get_num_elements()>0);

	const int32_t machine_count=m_machines->get_num_elements();
	SGVector<float64_t> scores(machine_count);

	int32_t idx=0;
	while (idx<machine_count)
	{
		CMachine* sub=get_machine(idx);
		scores[idx]=sub->apply(num);
		SG_UNREF(sub);
		++idx;
	}

	const float64_t result=maxvote_one_vs_rest(scores);

	// The per-machine outputs are not needed once the vote has been taken.
	scores.destroy_vector();
	return result;
}

/** Classify a single example with the one-vs-one strategy.
 *
 * Requires one sub-machine per unordered pair of classes, i.e.
 * num_classes*(num_classes-1)/2 machines. Every sub-machine is applied to
 * the example and maxvote_one_vs_one() picks the result from the outputs.
 *
 * @param num index of the example to classify
 * @return value produced by maxvote_one_vs_one() for the collected outputs
 */
float64_t CMulticlassMachine::classify_example_one_vs_one(int32_t num)
{
	int32_t num_classes=m_labels->get_num_classes();
	ASSERT(m_machines->get_num_elements()>0);
	ASSERT(m_machines->get_num_elements()==num_classes*(num_classes-1)/2);

	const int32_t pair_count=m_machines->get_num_elements();
	SGVector<float64_t> votes(pair_count);

	for (int32_t k=0; k!=pair_count; ++k)
	{
		CMachine* pairwise=get_machine(k);
		votes[k]=pairwise->apply(num);
		SG_UNREF(pairwise);
	}

	const float64_t result=maxvote_one_vs_one(votes, num_classes);

	// The per-machine outputs are not needed once the vote has been taken.
	votes.destroy_vector();
	return result;
}
5 changes: 5 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -198,6 +198,11 @@ class CMulticlassMachine : public CMachine
return true;
}

/** classify one example using the one-vs-one strategy
 *
 * @param num index of the example to classify
 * @return resulting classification
 */
virtual float64_t classify_example_one_vs_one(int32_t num);
/** classify one example using the one-vs-rest strategy
 *
 * @param num index of the example to classify
 * @return resulting classification
 */
virtual float64_t classify_example_one_vs_rest(int32_t num);

private:

/** register parameters */
Expand Down
52 changes: 0 additions & 52 deletions src/shogun/multiclass/MulticlassSVM.cpp
Expand Up @@ -110,58 +110,6 @@ bool CMulticlassSVM::init_machines_for_apply(CFeatures* data)
return true;
}

/** Classify a single example.
 *
 * Dispatches to the classifier that matches m_multiclass_strategy.
 *
 * @param num index of the example to classify
 * @return result of the strategy-specific classifier; 0 if the strategy is
 *         unknown and SG_ERROR returns control
 */
float64_t CMulticlassSVM::apply(int32_t num)
{
	if (m_multiclass_strategy==ONE_VS_ONE_STRATEGY)
		return classify_example_one_vs_one(num);

	if (m_multiclass_strategy==ONE_VS_REST_STRATEGY)
		return classify_example_one_vs_rest(num);

	// Neither known strategy matched.
	SG_ERROR("unknown multiclass strategy\n");
	return 0;
}

/** Classify a single example with the one-vs-rest strategy.
 *
 * Each sub-SVM is given m_kernel, applied to the example, and its output
 * recorded; maxvote_one_vs_rest() then picks the result.
 *
 * @param num index of the example to classify
 * @return value produced by maxvote_one_vs_rest() for the collected outputs
 */
float64_t CMulticlassSVM::classify_example_one_vs_rest(int32_t num)
{
	ASSERT(m_machines->get_num_elements()>0);

	const int32_t svm_count=m_machines->get_num_elements();
	SGVector<float64_t> scores(svm_count);

	for (int32_t idx=0; idx<svm_count; idx++)
	{
		CSVM* current=get_svm(idx);
		// The sub-SVM is handed the kernel held by this machine before use.
		current->set_kernel(m_kernel);
		scores[idx]=current->apply(num);
		SG_UNREF(current);
	}

	const float64_t result=maxvote_one_vs_rest(scores);

	// The per-SVM outputs are not needed once the vote has been taken.
	scores.destroy_vector();
	return result;
}

/** Classify a single example with the one-vs-one strategy.
 *
 * Requires num_classes*(num_classes-1)/2 sub-SVMs. Each one is given
 * m_kernel and applied to the example; maxvote_one_vs_one() then picks the
 * result from the collected outputs.
 *
 * @param num index of the example to classify
 * @return value produced by maxvote_one_vs_one() for the collected outputs
 */
float64_t CMulticlassSVM::classify_example_one_vs_one(int32_t num)
{
	int32_t num_classes=m_labels->get_num_classes();
	ASSERT(m_machines->get_num_elements()>0);
	ASSERT(m_machines->get_num_elements()==num_classes*(num_classes-1)/2);

	const int32_t pair_count=m_machines->get_num_elements();
	SGVector<float64_t> votes(pair_count);

	int32_t k=0;
	while (k<pair_count)
	{
		CSVM* pairwise=get_svm(k);
		// The sub-SVM is handed the kernel held by this machine before use.
		pairwise->set_kernel(m_kernel);
		votes[k]=pairwise->apply(num);
		SG_UNREF(pairwise);
		++k;
	}

	const float64_t result=maxvote_one_vs_one(votes, num_classes);

	// The per-SVM outputs are not needed once the vote has been taken.
	votes.destroy_vector();
	return result;
}

bool CMulticlassSVM::load(FILE* modelfl)
{
bool result=true;
Expand Down
21 changes: 0 additions & 21 deletions src/shogun/multiclass/MulticlassSVM.h
Expand Up @@ -71,27 +71,6 @@ class CMulticlassSVM : public CKernelMulticlassMachine
return dynamic_cast<CSVM *>(m_machines->get_element_safe(num));
}

/** classify one example
 *
 * @param num index of the example to classify
 * @return resulting classification
 */
virtual float64_t apply(int32_t num);

/** classify one example using the one-vs-rest strategy
 *
 * @param num index of the example to classify
 * @return resulting classification
 */
virtual float64_t classify_example_one_vs_rest(int32_t num);

/** classify one example using the one-vs-one strategy
 *
 * @param num index of the example to classify
 * @return resulting classification
 */
float64_t classify_example_one_vs_one(int32_t num);

/** load a Multiclass SVM from file
* @param svm_file the file handle
*/
Expand Down

0 comments on commit bde47e7

Please sign in to comment.