Skip to content

Commit

Permalink
Merge pull request #730 from iglesias/so
Browse files Browse the repository at this point in the history
Fixes in Multiclass SO
  • Loading branch information
Soeren Sonnenburg committed Aug 17, 2012
2 parents 2fecaab + 2bf574c commit 7636b27
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
16 changes: 11 additions & 5 deletions examples/undocumented/libshogun/so_multiclass.cpp
Expand Up @@ -53,6 +53,8 @@ void gen_rand_data(SGVector< float64_t > labs, SGMatrix< float64_t > feats)
fprintf(pfile, "\n");
}
}

fclose(pfile);
}

int main(int argc, char ** argv)
Expand Down Expand Up @@ -83,7 +85,6 @@ int main(int argc, char ** argv)

sosvm->train();
CStructuredLabels* out = CStructuredLabels::obtain_from_generic(sosvm->apply());
SG_REF(out);

// Create liblinear svm classifier with L2-regularized L2-loss
CLibLinear* svm = new CLibLinear(L2R_L2LOSS_SVC);
Expand All @@ -101,14 +102,17 @@ int main(int argc, char ** argv)
// Train the multiclass machine using the data passed in the constructor
mc_svm->train();
CMulticlassLabels* mout = CMulticlassLabels::obtain_from_generic(mc_svm->apply());
SG_REF(mout);

int32_t sosvm_ncorrect = 0, mc_ncorrect = 0;
SGVector< float64_t > slacks = sosvm->get_slacks();
for ( int i = 0 ; i < out->get_num_labels() ; ++i )
{
sosvm_ncorrect += mlabels->get_label(i) == ( (CRealNumber*) out->get_label(i) )->value;
CRealNumber* ypred = (CRealNumber*) out->get_label(i);

sosvm_ncorrect += mlabels->get_label(i) == ypred->value;
mc_ncorrect += mlabels->get_label(i) == mout->get_label(i);

SG_UNREF(ypred); // because of CStructuredLabels::get_label()
}

SGVector< float64_t > w = sosvm->get_w();
Expand All @@ -118,10 +122,12 @@ int main(int argc, char ** argv)

for ( int32_t i = 0 ; i < NUM_CLASSES ; ++i )
{
SGVector< float64_t > mw =
((CLinearMachine*) mc_svm->get_machine(i))->get_w();
CLinearMachine* lm = (CLinearMachine*) mc_svm->get_machine(i);
SGVector< float64_t > mw = lm->get_w();
for ( int32_t j = 0 ; j < mw.vlen ; ++j )
SG_SPRINT("%10f ", mw[j]);

SG_UNREF(lm); // because of CLinearMulticlassMachine::get_machine()
}
SG_SPRINT("\n");

Expand Down
3 changes: 2 additions & 1 deletion src/shogun/mathematics/Mosek.cpp
Expand Up @@ -238,7 +238,8 @@ MSKrescodee CMosek::wrapper_putaveclist(
for ( index_t i = 0 ; i < A.num_rows-1 ; ++i )
ptre[i] = ptrb[i+1];

ptre[A.num_rows-1] = nnza;
if ( A.num_rows > 0 )
ptre[A.num_rows-1] = nnza;

MSKrescodee ret = MSK_putaveclist(task, MSK_ACC_CON, A.num_rows, sub.vector,
ptrb.vector, ptre.vector,
Expand Down

0 comments on commit 7636b27

Please sign in to comment.