Skip to content

Commit

Permalink
Proper multiclass labels handling in KNN
Browse files: browse the repository at this point in the history
  • Loading branch information
lisitsyn committed May 22, 2012
1 parent de154b3 commit d204fcd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/shogun/multiclass/KNN.cpp
Expand Up @@ -108,7 +108,7 @@ bool CKNN::train_machine(CFeatures* data)
return true;
}

CLabels* CKNN::apply(CFeatures* data)
CMulticlassLabels* CKNN::apply_multiclass(CFeatures* data)
{
if (data)
init_distance(data);
Expand All @@ -124,7 +124,7 @@ CLabels* CKNN::apply(CFeatures* data)
int32_t num_lab=distance->get_num_vec_rhs();
ASSERT(m_k<=distance->get_num_vec_lhs());

CRegressionLabels* output=new CRegressionLabels(num_lab);
CMulticlassLabels* output=new CMulticlassLabels(num_lab);

float64_t* dists = NULL;
int32_t* train_lab = NULL;
Expand Down Expand Up @@ -273,15 +273,15 @@ CLabels* CKNN::apply(CFeatures* data)
return output;
}

CLabels* CKNN::classify_NN()
CMulticlassLabels* CKNN::classify_NN()
{
ASSERT(distance);
ASSERT(m_num_classes>0);

int32_t num_lab = distance->get_num_vec_rhs();
ASSERT(num_lab);

CRegressionLabels* output = new CRegressionLabels(num_lab);
CMulticlassLabels* output = new CMulticlassLabels(num_lab);
float64_t* distances = SG_MALLOC(float64_t, m_train_labels.vlen);

SG_INFO("%d test examples\n", num_lab);
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/multiclass/KNN.h
Expand Up @@ -81,7 +81,7 @@ class CKNN : public CDistanceMachine
* @param data (test) data to be classified
* @return classified labels
*/
virtual CLabels* apply(CFeatures* data=NULL);
virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

/// get output for example "vec_idx"
virtual float64_t apply_one(int32_t vec_idx)
Expand Down Expand Up @@ -168,7 +168,7 @@ class CKNN : public CDistanceMachine
/** classify all examples with nearest neighbor (k=1)
* @return classified labels
*/
virtual CLabels* classify_NN();
virtual CMulticlassLabels* classify_NN();

/** init distances to test examples
* @param data test examples
Expand Down

0 comments on commit d204fcd

Please sign in to comment.