Skip to content

Commit

Permalink
Merge branch 'nearest_centroid' of git://github.com/PhilippeTillet/sh…
Browse files Browse the repository at this point in the history
…ogun
  • Loading branch information
lisitsyn committed Apr 10, 2012
2 parents 7b3b7bc + 04677a7 commit a690dd7
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
54 changes: 54 additions & 0 deletions examples/undocumented/libshogun/classifier_nearest_centroid.cpp
@@ -0,0 +1,54 @@
#include <shogun/features/Labels.h>
#include <shogun/features/SimpleFeatures.h>
#include <shogun/distance/EuclidianDistance.h>
#include <shogun/classifier/NearestCentroid.h>
#include <shogun/base/init.h>

using namespace shogun;

/** Message callback handed to init_shogun(): writes the text verbatim
 * (no format expansion) to the given stream.
 *
 * @param target stream to write to
 * @param str    NUL-terminated message text
 */
void print_message(FILE* target, const char* str)
{
	fputs(str, target);
}

// Example driver: trains a CNearestCentroid classifier on a tiny synthetic
// data set (7 vectors, 2 features, 2 classes) and prints the labels it
// predicts for those same training vectors.
int main(){
init_shogun(&print_message);
index_t num_vec=7;
index_t num_feat=2;
index_t num_class=2;

// create some data: a num_feat x num_vec matrix filled by
// CMath::range_fill_vector (presumably consecutive values -- confirm)
SGMatrix<float64_t> matrix(num_feat, num_vec);
CMath::range_fill_vector(matrix.matrix, num_feat*num_vec);

// Create features ; shogun will now own the matrix created
// NOTE(review): the ownership transfer is taken from the original comment;
// confirm against the CSimpleFeatures constructor.
CSimpleFeatures<float64_t>* features=new CSimpleFeatures<float64_t>(matrix);

// Show the raw training data
CMath::display_matrix(matrix.matrix,num_feat,num_vec);

// Create labels: vector i is assigned class i%num_class
CLabels* labels=new CLabels(num_vec);
for (index_t i=0; i<num_vec; ++i)
labels->set_label(i, i%num_class);

// Create the Euclidean distance (CEuclidianDistance) with the training
// features on both sides
CEuclidianDistance* distance = new CEuclidianDistance(features,features);

// Create the nearest centroid classifier and train it on the features
// already attached to the distance (train() is called without data)
CNearestCentroid* nearest_centroid = new CNearestCentroid(distance, labels);
nearest_centroid->train();

// classify on training examples and print the predicted labels
CLabels* output=nearest_centroid->apply();
CMath::display_vector(output->get_labels().vector, output->get_num_labels(),
"batch output");

SG_UNREF(output);

// free up memory
// NOTE(review): only the classifier is released explicitly; features, labels
// and distance are presumably released through it -- confirm.
SG_UNREF(nearest_centroid);

exit_shogun();
return 0;

}
120 changes: 120 additions & 0 deletions src/shogun/classifier/NearestCentroid.cpp
@@ -0,0 +1,120 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Philippe Tillet
*/

#include <shogun/classifier/NearestCentroid.h>

namespace shogun{

/** Default constructor: constructs the CDistanceMachine base and then
 * applies the defaults set up by init().
 */
CNearestCentroid::CNearestCentroid()
	: CDistanceMachine()
{
	init();
}

/** Constructor
 *
 * @param d        distance to work with; must not be NULL
 * @param trainlab labels for training; must not be NULL
 */
CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab) : CDistanceMachine()
{
	// Validate the arguments before init() allocates the centroid features:
	// if an assertion aborts construction the destructor does not run, so
	// anything allocated beforehand would be leaked.
	ASSERT(d);
	ASSERT(trainlab);

	init();
	set_distance(d);
	set_labels(trainlab);
}

/** Destructor: releases the centroid features.
 *
 * Before training the centroids are held only by this object and are deleted
 * directly. After training they have been installed as the left-hand side of
 * the distance (see train_machine()), so they are released by detaching them
 * from the distance instead.
 */
CNearestCentroid::~CNearestCentroid()
{
	if (!m_is_trained)
	{
		delete m_centroids;
		return;
	}

	distance->remove_lhs();
}

void CNearestCentroid::init()
{
m_shrinking=0;
m_is_trained=false;
set_store_model_features(true);
m_centroids = new CSimpleFeatures<float64_t>();
}

/** Hook for storing model features; intentionally a no-op here.
 *
 * NOTE(review): presumably nothing needs copying because train_machine()
 * builds the centroids in freshly allocated memory -- confirm against the
 * CDistanceMachine contract.
 */
void CNearestCentroid::store_model_features()
{
	// nothing to do
}

/** Train the nearest centroid classifier.
 *
 * Computes, for every class, the mean of all training vectors carrying that
 * class label and installs the resulting centroids as the left-hand side of
 * the distance.
 *
 * @param data training features; if NULL the left-hand side features of the
 *             distance are used instead
 *
 * @return true once the centroids have been computed (errors are reported
 *         through ASSERT / SG_ERROR)
 */
bool CNearestCentroid::train_machine(CFeatures* data)
{
	ASSERT(m_labels);
	ASSERT(distance);

	// Resolve the training features first so that the checks below apply
	// no matter where the features come from.
	// NOTE(review): if CDistance::get_lhs()/get_rhs() return an extra
	// reference, it is not released in this function -- confirm.
	const bool features_given = (data!=NULL);
	if (!features_given)
		data = distance->get_lhs();
	ASSERT(data);

	if (m_labels->get_num_labels() != data->get_num_vectors())
		SG_ERROR("Number of training vectors does not match number of labels\n");

	if (features_given)
		distance->init(data, data);

	// NOTE(review): assumes the features really are CSimpleFeatures<float64_t>;
	// the cast is unchecked.
	CSimpleFeatures<float64_t>* simple_data = (CSimpleFeatures<float64_t>*)data;

	const int32_t num_vectors = data->get_num_vectors();
	const int32_t num_classes = m_labels->get_num_classes();
	const int32_t num_feats = simple_data->get_num_features();

	// Labels are used as offsets into the centroid buffer, so reject
	// out-of-range values before anything is allocated (nothing can leak
	// when the error is raised).
	for (int32_t idx=0 ; idx<num_vectors ; idx++)
	{
		const int32_t label = m_labels->get_label(idx);
		if (label<0 || label>=num_classes)
			SG_ERROR("Labels must be integers in the range [0, number of classes)\n");
	}

	// SG_CALLOC returns zeroed memory, so no explicit clearing is needed
	float64_t* centroids = SG_CALLOC(float64_t, num_feats*num_classes);
	int64_t* num_per_class = new int64_t[num_classes]();

	// Accumulate the per-class sums and member counts
	for (int32_t idx=0 ; idx<num_vectors ; idx++)
	{
		int32_t current_len;
		bool current_free;
		const int32_t current_class = m_labels->get_label(idx);
		float64_t* target = centroids + num_feats*current_class;
		float64_t* current = simple_data->get_feature_vector(idx,current_len,current_free);
		CMath::add(target,1.0,target,1.0,current,current_len);
		num_per_class[current_class]++;
		// second argument is the vector index, not the vector length
		simple_data->free_feature_vector(current, idx, current_free);
	}

	// Turn the sums into means. A centroid is the arithmetic mean of its
	// class, so divide by the member count (not count-1). Classes without
	// members keep an all-zero centroid instead of becoming NaN.
	for (int32_t i=0 ; i<num_classes ; i++)
	{
		const int64_t total = num_per_class[i];
		if (total>0)
			CMath::scale_vector(1.0/(float64_t)total, centroids + num_feats*i, num_feats);
	}

	delete [] num_per_class;

	m_num_classes = num_classes;

	// Hand the centroid matrix over to m_centroids
	m_centroids->set_num_features(num_feats);
	m_centroids->set_num_vectors(num_classes);
	m_centroids->free_feature_matrix();
	m_centroids->set_feature_matrix(centroids,num_feats,num_classes);

	// From now on distances are measured from the centroids (left-hand side)
	distance->init(m_centroids,distance->get_rhs());

	m_is_trained=true;

	return true;
}

}
97 changes: 97 additions & 0 deletions src/shogun/classifier/NearestCentroid.h
@@ -0,0 +1,97 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Philippe Tillet
*/

#ifndef _NEAREST_CENTROID_H__
#define _NEAREST_CENTROID_H__

#include <stdio.h>
#include <shogun/lib/common.h>
#include <shogun/io/SGIO.h>
#include <shogun/features/Features.h>
#include <shogun/features/SimpleFeatures.h>
#include <shogun/distance/Distance.h>
#include <shogun/lib/CoverTree.h>
#include <shogun/machine/DistanceMachine.h>

namespace shogun
{

class CDistanceMachine;

/** @brief Class NearestCentroid, an implementation of Nearest Shrunk Centroid classifier
 *
 * Training computes one centroid per class from the training vectors.
 * To define how close examples are, NearestCentroid requires a CDistance
 * object to work with (e.g., CEuclidianDistance).
 *
 * NOTE: the shrinking constant is stored and exposed through
 * set_shrinking()/get_shrinking(), but the visible train_machine()
 * implementation does not apply it.
 */
class CNearestCentroid : public CDistanceMachine
{
	public:
		/**
		 * Default constructor
		 */
		CNearestCentroid();

		/** constructor
		 *
		 * @param d distance
		 * @param trainlab labels for training
		 */
		CNearestCentroid(CDistance* d, CLabels* trainlab);

		/** Destructor
		 */
		virtual ~CNearestCentroid();

		/** Set shrinking constant
		 *
		 * @param shrinking to be set
		 */
		void set_shrinking(float64_t shrinking)
		{
			m_shrinking = shrinking;
		}

		/** Get shrinking constant
		 *
		 * @return value of the shrinking constant
		 */
		float64_t get_shrinking() const
		{
			return m_shrinking;
		}

		/** Get centroids
		 *
		 * Returns the internal pointer as is; no copy is made and no
		 * reference is added on behalf of the caller.
		 *
		 * @return features object holding the class centroids
		 */
		CSimpleFeatures<float64_t>* get_centroids() const
		{
			return m_centroids;
		}

	protected:
		/** train Nearest Centroid classifier
		 *
		 * @param data training data (parameter can be avoided if distance or
		 * kernel-based classifiers are used and distance/kernels are
		 * initialized with train data)
		 *
		 * @return whether training was successful
		 */
		virtual bool train_machine(CFeatures* data=NULL);

		/** Store model features; implemented as a no-op for this classifier */
		virtual void store_model_features();

	private:
		/** set members to their defaults and allocate the centroid features */
		void init();

	protected:
		/** number of classes */
		int32_t m_num_classes;

		/** shrinking constant */
		float64_t m_shrinking;

		/** centroids of the different classes (one feature vector per class) */
		CSimpleFeatures<float64_t>* m_centroids;

		/** whether train_machine() has completed */
		bool m_is_trained;
};

}

#endif

0 comments on commit a690dd7

Please sign in to comment.