Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Merge pull request #386 from iglesias/qda
QDA
- Loading branch information
Showing
13 changed files
with
960 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Fernando José Iglesias García | ||
* Copyright (C) 2012 Fernando José Iglesias García | ||
*/ | ||
|
||
#include <shogun/base/init.h> | ||
#include <shogun/classifier/QDA.h> | ||
#include <shogun/features/SimpleFeatures.h> | ||
#include <shogun/io/SGIO.h> | ||
#include <shogun/lib/common.h> | ||
#include <shogun/mathematics/Math.h> | ||
|
||
using namespace shogun; | ||
|
||
#define NUM 100 | ||
#define DIMS 2 | ||
#define DIST 0.5 | ||
|
||
/**
 * Fill lab and feat with a synthetic two-class problem of NUM samples.
 *
 * The first NUM/2 samples get label 0.0 and features drawn with
 * CMath::random(0.0,1.0) shifted by +DIST; the remaining samples get
 * label 1.0 and features shifted by -DIST. The DIMS features of sample i
 * are written to the consecutive entries feat[i*DIMS] ... feat[i*DIMS + DIMS-1].
 */
void gen_rand_data(SGVector< float64_t > lab, SGMatrix< float64_t > feat)
{
	const int32_t half = NUM/2;

	for (int32_t i = 0; i < NUM; i++)
	{
		const bool first_class = (i < half);
		const float64_t shift = first_class ? DIST : -DIST;

		lab[i] = first_class ? 0.0 : 1.0;

		// One random draw per feature, in the same order as the samples
		for (int32_t j = 0; j < DIMS; j++)
			feat[i*DIMS + j] = CMath::random(0.0,1.0) + shift;
	}
}
|
||
int main(int argc, char ** argv) | ||
{ | ||
const int32_t feature_cache = 0; | ||
|
||
init_shogun_with_defaults(); | ||
|
||
SGVector< float64_t > lab(NUM); | ||
SGMatrix< float64_t > feat(NUM, DIMS); | ||
|
||
gen_rand_data(lab, feat); | ||
|
||
// Create train labels | ||
CLabels* labels = new CLabels(lab); | ||
|
||
// Create train features | ||
CSimpleFeatures< float64_t >* features = new CSimpleFeatures< float64_t >(feature_cache); | ||
features->set_feature_matrix(feat.matrix, DIMS, NUM); | ||
|
||
// Create QDA classifier | ||
CQDA* qda = new CQDA(features, labels); | ||
SG_REF(qda); | ||
qda->train(); | ||
|
||
// Classify and display output | ||
CLabels* out_labels = qda->apply(); | ||
SG_REF(out_labels); | ||
|
||
// Free memory | ||
SG_UNREF(out_labels); | ||
SG_UNREF(qda); | ||
|
||
exit_shogun(); | ||
|
||
return 0; | ||
} |
117 changes: 117 additions & 0 deletions
117
examples/undocumented/python_modular/graphical/multiclass_qda.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
""" | ||
Shogun demo | ||
Fernando J. Iglesias Garcia | ||
""" | ||
|
||
import numpy as np | ||
import matplotlib as mpl | ||
import pylab | ||
import util | ||
|
||
from scipy import linalg | ||
from shogun.Classifier import QDA | ||
from shogun.Features import RealFeatures, Labels | ||
|
||
# colormap: the red, green and blue channels all share the same two anchor
# points, (0, 1, 1) and (1, .7, .7)
cmap = mpl.colors.LinearSegmentedColormap(
    'color_classes',
    {channel: [(0, 1, 1), (1, .7, .7)]
     for channel in ('red', 'green', 'blue')})
pylab.cm.register_cmap(cmap = cmap)
|
||
# Generate data from Gaussian distributions | ||
def gen_data():
    """Generate the three-class toy data set used by the demo.

    The NumPy random generator is seeded with 0, so repeated calls return
    the same data. For each class, ``N`` standard-normal samples of
    dimension ``dim`` are multiplied by a fixed 2x2 matrix and shifted by a
    fixed offset.

    Returns a tuple ``(X, Y)``: ``X`` stacks the samples of the three
    classes row-wise (class 0 first, then 1, then 2) and ``Y`` holds the
    matching labels 0, 1 and 2, each repeated ``N`` times.
    """
    np.random.seed(0)

    # Linear maps applied to the standard-normal draws, one per class
    transforms = np.array([[[0., -1. ], [2.5, .7]],
                           [[3., -1.5], [1.2, .3]],
                           [[ 2,  0  ], [ .0, 1.5 ]]])
    # Offsets added to the transformed draws, one per class
    offsets = ([-4, 3], [-1, -5], [3, 4])

    clusters = []
    for transform, offset in zip(transforms, offsets):
        draws = np.random.randn(N, dim)
        clusters.append(np.dot(draws, transform) + np.array(offset))

    X = np.r_[tuple(clusters)]
    Y = np.hstack((np.zeros(N), np.ones(N), 2*np.ones(N)))
    return X, Y
|
||
def plot_data(qda, X, y, y_pred, ax):
    """Plot the samples of classes 0, 1 and 2 together with the class means.

    For every class ``k`` the samples whose predicted label equals the true
    label are drawn as circles and the remaining ones as squares, both in
    colour ``cols[k]``. The mean returned by ``qda.get_mean(k)`` is drawn
    on top as a larger black circle.

    ``X`` holds one sample per row, ``y`` the true labels and ``y_pred`` the
    predicted labels. All drawing goes through ``pylab``; ``ax`` is accepted
    but not used.
    """
    # True where the prediction matches the ground truth
    hit = (y == y_pred)

    for k in range(3):
        members = (y == k)
        points = X[members]
        correct = hit[members]

        well_classified = points[correct]
        misclassified = points[~correct]

        # Class k data
        pylab.plot(well_classified[:, 0], well_classified[:, 1], 'o',
                   color = cols[k])
        pylab.plot(misclassified[:, 0], misclassified[:, 1], 's',
                   color = cols[k])
        mean = qda.get_mean(k)
        pylab.plot(mean[0], mean[1], 'o', color = 'black', markersize = 8)
|
||
def plot_cov(plot, mean, cov, color):
    """Draw a semi-transparent ellipse visualising a 2-D covariance matrix.

    ``plot`` is the axes the ellipse is added to, ``mean`` the centre of the
    ellipse, ``cov`` the symmetric 2x2 covariance matrix and ``color`` the
    fill colour.
    """
    # eigh returns the eigenvalues in v and the matching eigenvectors as the
    # COLUMNS of w, so the direction belonging to v[0] is w[:, 0]
    v, w = linalg.eigh(cov)
    u = w[:, 0]
    # arctan2 also handles a vertical axis (u[0] == 0) without dividing by
    # zero; the ellipse is unchanged by a 180 degree turn, so the wider
    # output range does not alter the drawing
    angle = np.degrees(np.arctan2(u[1], u[0]))
    # Width and height are full axis lengths: each axis spans two standard
    # deviations in total, i.e. one on either side of the mean
    ell = mpl.patches.Ellipse(mean, 2*v[0]**0.5, 2*v[1]**0.5,
                              angle = 180 + angle, color = color)
    ell.set_clip_box(plot.bbox)
    ell.set_alpha(0.5)
    plot.add_artist(ell)
|
||
def plot_regions(qda):
    """Colour the current axes by predicted class and outline the borders.

    A 500 x 500 grid covering the current pylab x/y limits is classified
    with ``qda``; the predicted labels are shown with ``pcolormesh`` and
    the boundaries between them with thick black contour lines.
    """
    resolution = 500
    x_lo, x_hi = pylab.xlim()
    y_lo, y_hi = pylab.ylim()

    xx, yy = np.meshgrid(np.linspace(x_lo, x_hi, resolution),
                         np.linspace(y_lo, y_hi, resolution))

    # One grid point per column: first row x coordinates, second row y
    grid_points = np.array((np.ravel(xx), np.ravel(yy)))
    predicted = qda.apply(RealFeatures(grid_points)).get_labels()

    Z = predicted.reshape(xx.shape)
    pylab.pcolormesh(xx, yy, Z)
    pylab.contour(xx, yy, Z, linewidths = 3, colors = 'k')
|
||
# Number of classes
M = 3
# Number of samples of each class
N = 300
# Dimension of the data
dim = 2

# Plot colour of each class, indexed by class label
cols = ['blue', 'green', 'red']

fig = pylab.figure()
ax = fig.add_subplot(111)
pylab.title('Quadratic Discriminant Analysis')

X, y = gen_data()

# gen_data returns one sample per row; RealFeatures is given the transpose
labels = Labels(y)
features = RealFeatures(X.T)
# NOTE(review): the last two arguments are presumably the tolerance and a
# flag to keep the covariances for get_cov() below -- confirm against CQDA
qda = QDA(features, labels, 1e-4, True)
qda.train()
ypred = qda.apply().get_labels()

plot_data(qda, X, y, ypred, ax)
for i in range(M):
    plot_cov(ax, qda.get_mean(i), qda.get_cov(i), cols[i])
plot_regions(qda)

pylab.connect('key_press_event', util.quit)
pylab.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from shogun.Features import RealFeatures | ||
from shogun.Features import Labels | ||
from shogun.Classifier import QDA | ||
from pylab import pcolor, contour, colorbar, connect, show, plot, axis | ||
|
||
import numpy as np | ||
import util | ||
|
||
# Number of samples drawn for each of the two classes
N = 500
# Number of grid points per axis used to render the decision regions
size = 100

# positive examples
mean_pos = [-1, 4]
cov_pos = [[1, 40], [50, -2]]

# Draw N samples (not a hard-coded 500) so the data stays in sync with the
# labels built from N below
x_pos, y_pos = np.random.multivariate_normal(mean_pos, cov_pos, N).T
plot(x_pos, y_pos, 'bo')

# negative examples
mean_neg = [0, -3]
cov_neg = [[100, 50], [20, 3]]

x_neg, y_neg = np.random.multivariate_normal(mean_neg, cov_neg, N).T
plot(x_neg, y_neg, 'ro')

# train qda: the first N columns are the "positive" samples (label 0), the
# last N columns the "negative" samples (label 1)
labels = Labels(np.concatenate([np.zeros(N), np.ones(N)]))
pos = np.array([x_pos, y_pos])
neg = np.array([x_neg, y_neg])
features = RealFeatures(np.concatenate([pos, neg], 1))

qda = QDA()
qda.set_labels(labels)
qda.train(features)

# compute output plot iso-lines
xs = np.concatenate([x_pos, x_neg])
ys = np.concatenate([y_pos, y_neg])

# Plot limits: the extreme sample coordinates scaled by 1.2
x1_max = (1.2*xs).max()
x1_min = (1.2*xs).min()
x2_max = (1.2*ys).max()
x2_min = (1.2*ys).min()

x1 = np.linspace(x1_min, x1_max, size)
x2 = np.linspace(x2_min, x2_max, size)

x, y = np.meshgrid(x1, x2)

# Classify every grid point (one point per column)
dense = RealFeatures(np.array((np.ravel(x), np.ravel(y))))
dense_labels = qda.apply(dense).get_labels()

z = dense_labels.reshape((size, size))

# NOTE(review): shading = 'interp' is kept from the original; verify it is
# accepted by the matplotlib version in use
pcolor(x, y, z, shading = 'interp')
# The former hold = True argument is dropped: drawing on top of the existing
# plot is the default, and the keyword no longer exists in current matplotlib
contour(x, y, z, linewidths = 1, colors = 'black')

axis([x1_min, x1_max, x2_min, x2_max])

connect('key_press_event', util.quit)

show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.