Skip to content

Commit

Permalink
* fix MulticlassModel argmax: the loss must be included in training to find the MMV
Browse files (browse the repository at this point in the history)
  • Loading branch information
iglesias committed Aug 22, 2012
1 parent 50ba31c commit 1a87bf0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
23 changes: 21 additions & 2 deletions src/shogun/structure/MulticlassModel.cpp
Expand Up @@ -81,11 +81,13 @@ CResultSet* CMulticlassModel::argmax(
// Find the class that gives the maximum score

float64_t score = 0, ypred = 0;
float64_t max_score = df->dense_dot(feat_idx, w.vector, feats_dim);
float64_t max_score = -CMath::INFTY;

for ( int32_t c = 1 ; c < m_num_classes ; ++c )
{
score = df->dense_dot(feat_idx, w.vector+c*feats_dim, feats_dim);
if ( training )
score += delta_loss(feat_idx, c);

if ( score > max_score )
{
Expand Down Expand Up @@ -119,7 +121,24 @@ float64_t CMulticlassModel::delta_loss(CStructuredData* y1, CStructuredData* y2)
ASSERT(rn1 != NULL);
ASSERT(rn2 != NULL);

return ( rn1->value == rn2->value ) ? 0 : 1;
return delta_loss(rn1->value, rn2->value);
}

/* Zero-one loss between the label stored at index y1_idx and the value y2.
 *
 * @param y1_idx index into m_labels; must lie in [0, num_labels-1]
 * @param y2     label value to compare against
 * @return 0 if the stored label value equals y2, 1 otherwise
 */
float64_t CMulticlassModel::delta_loss(int32_t y1_idx, float64_t y2)
{
	// Both bounds must hold simultaneously, hence '&&' (with '||' the
	// condition is true for every index and the check never fires).
	REQUIRE(y1_idx >= 0 && y1_idx < m_labels->get_num_labels(),
			"The label index must be inside [0, num_labels-1]\n");

	CRealNumber* rn1 = CRealNumber::obtain_from_generic(m_labels->get_label(y1_idx));
	// Same guard as the CStructuredData* overload applies after obtain_from_generic.
	ASSERT(rn1 != NULL);

	float64_t ret = delta_loss(rn1->value, y2);
	// Release the label obtained through get_label once its value is read.
	SG_UNREF(rn1);

	return ret;
}

/* Zero-one loss on raw label values: 0 when the two values compare equal,
 * 1 otherwise (including when either value is NaN).
 */
float64_t CMulticlassModel::delta_loss(float64_t y1, float64_t y2)
{
	if ( y1 != y2 )
		return 1;

	return 0;
}

void CMulticlassModel::init_opt(
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/structure/MulticlassModel.h
Expand Up @@ -133,6 +133,10 @@ class CMulticlassModel : public CStructuredModel
private:
void init();

/** Different flavours of the delta_loss that become handy */
float64_t delta_loss(float64_t y1, float64_t y2);
float64_t delta_loss(int32_t y1_idx, float64_t y2);

private:
/** number of classes */
int32_t m_num_classes;
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/structure/StructuredModel.cpp
Expand Up @@ -98,8 +98,8 @@ SGVector< float64_t > CStructuredModel::get_joint_feature_vector(

float64_t CStructuredModel::delta_loss(int32_t ytrue_idx, CStructuredData* ypred)
{
if ( ytrue_idx < 0 || ytrue_idx >= m_labels->get_num_labels() )
SG_ERROR("The label index must be inside [0, num_labels-1]\n");
REQUIRE(ytrue_idx >= 0 || ytrue_idx < m_labels->get_num_labels(),
"The label index must be inside [0, num_labels-1]\n");

CStructuredData* ytrue = m_labels->get_label(ytrue_idx);
float64_t ret = delta_loss(ytrue, ypred);
Expand Down

0 comments on commit 1a87bf0

Please sign in to comment.