Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix up clustering evaluation example
- seperate minimal doc from code - turn it into a proper function to enable integration tests
- Loading branch information
Soeren Sonnenburg
committed
Aug 29, 2012
1 parent
c35b5ed
commit 5398c6e
Showing
3 changed files
with
72 additions
and
68 deletions.
There are no files selected for viewing
Submodule data
updated
1 files
+1,797 −0 | toy/optdigits.tes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Example on how to evaluate the clustering performance (given ground-truth) |
137 changes: 70 additions & 67 deletions
137
examples/undocumented/python_modular/evaluation_clustering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,78 @@ | ||
##!/usr/bin/env python | ||
# Example on how to evaluate the clustering performance (given ground-truth) | ||
|
||
from shogun.Distance import EuclideanDistance | ||
from shogun.Features import RealFeatures | ||
from shogun.Features import MulticlassLabels | ||
from shogun.Evaluation import ClusteringAccuracy | ||
from shogun.Evaluation import ClusteringMutualInformation | ||
|
||
def get_dataset(): | ||
from os.path import exists | ||
try: | ||
from urllib2 import urlopen | ||
except ImportError: | ||
from urllib.request import urlopen | ||
|
||
filename = "../data/optdigits.tes" | ||
if exists(filename): | ||
return open(filename) | ||
else: | ||
print("Retrieving data...") | ||
return urlopen("http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tes") | ||
from os.path import exists | ||
filename = "../data/optdigits.tes" | ||
if exists(filename): | ||
return open(filename) | ||
else: | ||
#print("Retrieving data...") | ||
try: | ||
from urllib2 import urlopen | ||
except ImportError: | ||
from urllib.request import urlopen | ||
return urlopen("http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tes") | ||
|
||
def prepare_data(): | ||
from numpy import loadtxt | ||
|
||
stream = get_dataset() | ||
print("Loading data...") | ||
data = loadtxt(stream, delimiter=',') | ||
fea = data[:, :-1] | ||
gnd = data[:, -1] | ||
return (fea.T, gnd) | ||
|
||
def run_clustering(data, k): | ||
from shogun.Clustering import KMeans | ||
from shogun.Mathematics import Math_init_random | ||
|
||
Math_init_random(42) | ||
fea = RealFeatures(data) | ||
distance = EuclideanDistance(fea, fea) | ||
kmeans=KMeans(k, distance) | ||
from numpy import loadtxt | ||
stream = get_dataset() | ||
#print("Loading data...") | ||
data = loadtxt(stream, delimiter=',') | ||
fea = data[:, :-1] | ||
gnd = data[:, -1] | ||
return (fea.T, gnd) | ||
|
||
print("Running clustering...") | ||
kmeans.train() | ||
(fea, gnd_raw) = prepare_data() | ||
parameter_list = [[fea, gnd_raw, 10]] | ||
|
||
return kmeans.get_cluster_centers() | ||
|
||
def assign_labels(data, centroids): | ||
from shogun.Classifier import KNN | ||
from numpy import arange | ||
|
||
labels = MulticlassLabels(arange(0.,10.)) | ||
fea = RealFeatures(data) | ||
fea_centroids = RealFeatures(centroids) | ||
distance = EuclideanDistance(fea_centroids, fea_centroids) | ||
knn = KNN(1, distance, labels) | ||
knn.train() | ||
return knn.apply(fea) | ||
def run_clustering(data, k): | ||
from shogun.Clustering import KMeans | ||
from shogun.Mathematics import Math_init_random | ||
from shogun.Distance import EuclideanDistance | ||
from shogun.Features import RealFeatures | ||
|
||
Math_init_random(42) | ||
fea = RealFeatures(data) | ||
distance = EuclideanDistance(fea, fea) | ||
kmeans=KMeans(k, distance) | ||
|
||
#print("Running clustering...") | ||
kmeans.train() | ||
|
||
return kmeans.get_cluster_centers() | ||
|
||
def assign_labels(data, centroids, ncenters): | ||
from shogun.Distance import EuclideanDistance | ||
from shogun.Features import RealFeatures, MulticlassLabels | ||
from shogun.Classifier import KNN | ||
from numpy import arange | ||
|
||
labels = MulticlassLabels(arange(0.,ncenters)) | ||
fea = RealFeatures(data) | ||
fea_centroids = RealFeatures(centroids) | ||
distance = EuclideanDistance(fea_centroids, fea_centroids) | ||
knn = KNN(1, distance, labels) | ||
knn.train() | ||
return knn.apply(fea) | ||
|
||
def evaluation_clustering(features=fea, ground_truth=gnd_raw, ncenters=10): | ||
from shogun.Evaluation import ClusteringAccuracy, ClusteringMutualInformation | ||
from shogun.Features import MulticlassLabels | ||
centroids = run_clustering(features, ncenters) | ||
gnd_hat = assign_labels(features, centroids, ncenters) | ||
gnd = MulticlassLabels(ground_truth) | ||
|
||
AccuracyEval = ClusteringAccuracy() | ||
AccuracyEval.best_map(gnd_hat, gnd) | ||
|
||
accuracy = AccuracyEval.evaluate(gnd_hat, gnd) | ||
#print(('Clustering accuracy = %.4f' % accuracy)) | ||
|
||
MIEval = ClusteringMutualInformation() | ||
mutual_info = MIEval.evaluate(gnd_hat, gnd) | ||
#print(('Clustering mutual information = %.4f' % mutual_info)) | ||
|
||
return gnd, accuracy, mutual_info | ||
|
||
if __name__ == '__main__': | ||
(fea, gnd_raw) = prepare_data() | ||
centroids = run_clustering(fea, 10) | ||
gnd_hat = assign_labels(fea, centroids) | ||
gnd = MulticlassLabels(gnd_raw) | ||
|
||
AccuracyEval = ClusteringAccuracy() | ||
AccuracyEval.best_map(gnd_hat, gnd) | ||
|
||
accuracy = AccuracyEval.evaluate(gnd_hat, gnd) | ||
print(('Clustering accuracy = %.4f' % accuracy)) | ||
|
||
MIEval = ClusteringMutualInformation() | ||
mutual_info = MIEval.evaluate(gnd_hat, gnd) | ||
print(('Clustering mutual information = %.4f' % mutual_info)) | ||
|
||
print('Evaluation Clustering') | ||
evaluation_clustering(*parameter_list[0]) |