Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'nearest_centroid' of git://github.com/PhilippeTillet/shogun
- Loading branch information
Showing 3 changed files with 271 additions and 0 deletions.
There are no files selected for viewing
54 changes: 54 additions & 0 deletions
54
examples/undocumented/libshogun/classifier_nearest_centroid.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include <shogun/features/Labels.h> | ||
#include <shogun/features/SimpleFeatures.h> | ||
#include <shogun/distance/EuclidianDistance.h> | ||
#include <shogun/classifier/NearestCentroid.h> | ||
#include <shogun/base/init.h> | ||
|
||
using namespace shogun; | ||
|
||
/**
 * Output callback handed to init_shogun(): writes a message verbatim.
 *
 * The text goes through a fixed "%s" format, so '%' characters inside
 * the message are printed literally instead of being interpreted.
 *
 * @param target stream the message is written to
 * @param str    zero-terminated message text
 */
void print_message(FILE* target, const char* str)
{
	fprintf(target, "%s", str);
}
|
||
int main(){ | ||
init_shogun(&print_message); | ||
index_t num_vec=7; | ||
index_t num_feat=2; | ||
index_t num_class=2; | ||
|
||
// create some data | ||
SGMatrix<float64_t> matrix(num_feat, num_vec); | ||
CMath::range_fill_vector(matrix.matrix, num_feat*num_vec); | ||
|
||
// Create features ; shogun will now own the matrix created | ||
CSimpleFeatures<float64_t>* features=new CSimpleFeatures<float64_t>(matrix); | ||
|
||
CMath::display_matrix(matrix.matrix,num_feat,num_vec); | ||
|
||
//Create labels | ||
CLabels* labels=new CLabels(num_vec); | ||
for (index_t i=0; i<num_vec; ++i) | ||
labels->set_label(i, i%num_class); | ||
|
||
//Create Euclidian Distance | ||
CEuclidianDistance* distance = new CEuclidianDistance(features,features); | ||
|
||
//Create Nearest Centroid | ||
CNearestCentroid* nearest_centroid = new CNearestCentroid(distance, labels); | ||
nearest_centroid->train(); | ||
|
||
// classify on training examples | ||
CLabels* output=nearest_centroid->apply(); | ||
CMath::display_vector(output->get_labels().vector, output->get_num_labels(), | ||
"batch output"); | ||
|
||
SG_UNREF(output); | ||
|
||
// free up memory | ||
SG_UNREF(nearest_centroid); | ||
|
||
exit_shogun(); | ||
return 0; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Philippe Tillet | ||
*/ | ||
|
||
#include <shogun/classifier/NearestCentroid.h> | ||
|
||
namespace shogun{ | ||
|
||
/**
 * Default constructor: no distance and no labels are attached; members
 * are brought into their initial state by init().
 */
CNearestCentroid::CNearestCentroid()
: CDistanceMachine()
{
	init();
}
|
||
/**
 * Construct a classifier from a distance and training labels.
 *
 * @param d        distance to work with, must not be NULL
 * @param trainlab labels of the training vectors, must not be NULL
 */
CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab) : CDistanceMachine()
{
	// Validate the arguments before init() allocates the centroid features,
	// so a failing assertion does not leave that allocation behind.
	ASSERT(d);
	ASSERT(trainlab);

	init();
	set_distance(d);
	set_labels(trainlab);
}
|
||
CNearestCentroid::~CNearestCentroid() | ||
{ | ||
if(m_is_trained) | ||
distance->remove_lhs(); | ||
else | ||
delete m_centroids; | ||
} | ||
|
||
void CNearestCentroid::init() | ||
{ | ||
m_shrinking=0; | ||
m_is_trained=false; | ||
set_store_model_features(true); | ||
m_centroids = new CSimpleFeatures<float64_t>(); | ||
} | ||
|
||
/**
 * Intentionally a no-op.
 *
 * NOTE(review): presumably nothing has to be copied here because
 * train_machine() already keeps the model (the class centroids) in
 * m_centroids, separate from the training features - confirm against the
 * CDistanceMachine contract.
 */
void CNearestCentroid::store_model_features()
{
}
|
||
/**
 * Train the classifier: compute one centroid (per-feature mean) for every
 * class and install the centroids as left-hand side of the distance, while
 * the distance keeps its current right-hand side.
 *
 * Labels are used as class indices and therefore have to lie in
 * [0, number of classes). A class without any training vector keeps an
 * all-zero centroid.
 *
 * NOTE(review): m_shrinking is not applied anywhere in this routine.
 *
 * @param data training features; if NULL, the current left-hand side of
 *             the distance is used instead
 * @return true when training has finished
 */
bool CNearestCentroid::train_machine(CFeatures* data)
{
	ASSERT(m_labels);
	ASSERT(distance);

	// True when data was fetched through get_lhs(): that getter is expected
	// to hand out an extra reference, which has to be released again below.
	bool release_data=false;

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
			SG_ERROR("Number of training vectors does not match number of labels\n");
		distance->init(data, data);
	}
	else
	{
		data=distance->get_lhs();
		ASSERT(data);
		release_data=true;
	}

	// assumes the features are dense float64 features - TODO confirm/check type
	CSimpleFeatures<float64_t>* dense=(CSimpleFeatures<float64_t>*)data;

	const int32_t num_vectors=data->get_num_vectors();
	const int32_t num_classes=m_labels->get_num_classes();
	const int32_t num_feats=dense->get_num_features();

	// Labels index the centroid buffer directly, so reject anything outside
	// of [0, num_classes) (e.g. -1/+1 labels) before memory is touched.
	for (int32_t idx=0; idx<num_vectors; idx++)
	{
		const int32_t label=m_labels->get_label(idx);
		if (label<0 || label>=num_classes)
		{
			if (release_data)
			{
				SG_UNREF(data);
			}
			SG_ERROR("Label %d of vector %d is outside of [0,%d)\n",
					label, idx, num_classes);
		}
	}

	// num_feats x num_classes accumulator, one column per class; SG_CALLOC
	// has calloc semantics, so the sums start out at zero.
	float64_t* centroids=SG_CALLOC(float64_t, num_feats*num_classes);
	m_centroids->set_num_features(num_feats);
	m_centroids->set_num_vectors(num_classes);

	// number of training vectors seen per class, value-initialised to zero
	int64_t* num_per_class=new int64_t[num_classes]();

	// sum up the feature vectors of every class
	for (int32_t idx=0; idx<num_vectors; idx++)
	{
		int32_t current_len;
		bool current_free;
		const int32_t current_class=m_labels->get_label(idx);
		float64_t* target=centroids + num_feats*current_class;
		float64_t* current=dense->get_feature_vector(idx, current_len, current_free);
		CMath::add(target, 1.0, target, 1.0, current, current_len);
		num_per_class[current_class]++;
		dense->free_feature_vector(current, current_len, current_free);
	}

	// turn the sums into means: a centroid is sum/count (not sum/(count-1))
	for (int32_t i=0; i<num_classes; i++)
	{
		const int64_t total=num_per_class[i];
		if (total<=0)
			continue; // empty class: keep the zero centroid instead of NaN

		const float64_t scale=1.0/(float64_t)total;
		CMath::scale_vector(scale, centroids + num_feats*i, num_feats);
	}

	delete [] num_per_class;

	// hand the centroid buffer over to the centroid features
	m_centroids->free_feature_matrix();
	m_centroids->set_feature_matrix(centroids, num_feats, num_classes);

	// Centroids become the new lhs, the rhs stays as it is. The reference
	// returned by get_rhs() keeps the rhs alive across init() and is
	// released once the distance holds its own reference again.
	CFeatures* rhs=distance->get_rhs();
	distance->init(m_centroids, rhs);
	SG_UNREF(rhs);

	if (release_data)
	{
		SG_UNREF(data);
	}

	m_num_classes=num_classes;
	m_is_trained=true;

	return true;
}
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Philippe Tillet | ||
*/ | ||
|
||
#ifndef _NEAREST_CENTROID_H__ | ||
#define _NEAREST_CENTROID_H__ | ||
|
||
#include <stdio.h> | ||
#include <shogun/lib/common.h> | ||
#include <shogun/io/SGIO.h> | ||
#include <shogun/features/Features.h> | ||
#include <shogun/features/SimpleFeatures.h> | ||
#include <shogun/distance/Distance.h> | ||
#include <shogun/lib/CoverTree.h> | ||
#include <shogun/machine/DistanceMachine.h> | ||
|
||
namespace shogun | ||
{ | ||
|
||
class CDistanceMachine; | ||
|
||
/** @brief Class NearestCentroid, an implementation of Nearest Shrunk Centroid classifier | ||
* | ||
* To define how close examples are | ||
* NearestCentroid requires a CDistance object to work with (e.g., CEuclideanDistance ). | ||
*/ | ||
|
||
class CNearestCentroid : public CDistanceMachine{ | ||
|
||
public: | ||
/** | ||
* Default constructor | ||
*/ | ||
CNearestCentroid(); | ||
|
||
/** constructor | ||
* | ||
* @param d distance | ||
* @param trainlab labels for training | ||
*/ | ||
CNearestCentroid(CDistance* distance, CLabels* trainlab); | ||
|
||
/** Destructor | ||
*/ | ||
virtual ~CNearestCentroid(); | ||
|
||
/** Set shrinking constant | ||
* | ||
* @param shrinking to be set | ||
*/ | ||
void set_shrinking(float64_t shrinking) { | ||
m_shrinking = shrinking ; | ||
} | ||
|
||
/** Get shrinking constant | ||
* | ||
* @return value of the shrinking constant | ||
*/ | ||
float64_t get_shrinking() const{ | ||
return m_shrinking; | ||
} | ||
|
||
CSimpleFeatures<float64_t>* get_centroids() const{ | ||
return m_centroids; | ||
} | ||
|
||
protected: | ||
/** train Nearest Centroid classifier | ||
* | ||
* @param data training data (parameter can be avoided if distance or | ||
* kernel-based classifiers are used and distance/kernels are | ||
* initialized with train data) | ||
* | ||
* @return whether training was successful | ||
*/ | ||
virtual bool train_machine(CFeatures* data=NULL); | ||
|
||
virtual void store_model_features(); | ||
|
||
private: | ||
void init(); | ||
|
||
protected: | ||
int32_t m_num_classes; | ||
float64_t m_shrinking; | ||
CSimpleFeatures<float64_t>* m_centroids; | ||
bool m_is_trained; | ||
}; | ||
|
||
} | ||
|
||
#endif |