Skip to content

Commit

Permalink
Proper rejection handling
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed Mar 10, 2012
1 parent 47abf3f commit d123ca6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
10 changes: 8 additions & 2 deletions src/shogun/evaluation/MulticlassAccuracy.cpp
Expand Up @@ -26,18 +26,24 @@ float64_t CMulticlassAccuracy::evaluate(CLabels* predicted, CLabels* ground_trut
if (predicted->get_int_label(i)==ground_truth->get_int_label(i))
correct++;
}
return ((float64_t)correct)/length;
}
else
{
int32_t total = length;
for (int32_t i=0; i<length; i++)
{
int32_t predicted_label = predicted->get_int_label(i);

if (predicted_label!=predicted->REJECTION_LABEL && predicted_label==ground_truth->get_int_label(i))
if (predicted_label==predicted->REJECTION_LABEL)
total--;
else
correct++;

return ((float64_t)correct)/total;
}
}
return ((float64_t)correct)/length;
return 0.0;
}

SGMatrix<int32_t> CMulticlassAccuracy::get_confusion_matrix(CLabels* predicted, CLabels* ground_truth)
Expand Down
9 changes: 6 additions & 3 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -33,6 +33,7 @@ CMulticlassMachine::CMulticlassMachine(

CMulticlassMachine::~CMulticlassMachine()
{
SG_UNREF(m_rejection_strategy);
SG_UNREF(m_machine);

clear_machines();
Expand Down Expand Up @@ -146,17 +147,19 @@ CLabels* CMulticlassMachine::classify_one_vs_rest()
for (int32_t i=0; i<num_vectors; i++)
{
int32_t winner = 0;
float64_t max_out = outputs[0]->get_label(i);

for (int32_t j=0; j<num_machines; j++)
outputs_for_i[j] = outputs[j]->get_label(i);

if (m_rejection_strategy && m_rejection_strategy->reject(outputs_for_i))
if (m_rejection_strategy)
{
winner=result->REJECTION_LABEL;
if (m_rejection_strategy->reject(outputs_for_i))
winner=result->REJECTION_LABEL;
}
else
{
float64_t max_out = outputs[0]->get_label(i);

for (int32_t j=1; j<num_machines; j++)
{
if (outputs_for_i[j]>max_out)
Expand Down
2 changes: 2 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -113,12 +113,14 @@ class CMulticlassMachine : public CMachine
/** get rejection strategy
 *
 * The returned pointer has had its reference count incremented
 * (via SG_REF), so the caller owns one reference and is expected
 * to release it with SG_UNREF when done.
 *
 * @return current rejection strategy (possibly NULL if none was set)
 */
inline CRejectionStrategy* get_rejection_strategy() const
{
	CRejectionStrategy* strategy = m_rejection_strategy;
	// hand out a new reference to the caller
	SG_REF(strategy);
	return strategy;
}
/** set rejection strategy
 *
 * Stores a reference to the given strategy and releases the reference
 * held on the previously stored one.
 *
 * The new reference is acquired BEFORE the old one is released. That
 * order makes re-setting the currently stored strategy safe: releasing
 * first could free the object when its refcount is 1, leaving a
 * dangling pointer to be ref'd and stored.
 *
 * @param rejection_strategy strategy to use; NULL presumably clears it
 *        (assumes SG_REF/SG_UNREF are NULL-safe — confirm)
 */
inline void set_rejection_strategy(CRejectionStrategy* rejection_strategy)
{
	SG_REF(rejection_strategy);
	SG_UNREF(m_rejection_strategy);
	m_rejection_strategy = rejection_strategy;
}

Expand Down

0 comments on commit d123ca6

Please sign in to comment.