Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
273 additions
and
1 deletion.
There are no files selected for viewing
84 changes: 84 additions & 0 deletions
84
examples/undocumented/python_modular/evaluation_cross_validation_multiclass_storage.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# | ||
# This program is free software; you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation; either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# Written (W) 2012 Heiko Strathmann | ||
# Copyright (C) 2012 Berlin Institute of Technology and Max-Planck-Society | ||
# | ||
|
||
from numpy.random import randn | ||
from numpy import * | ||
|
||
# Two overlapping point clouds in 2D, one column per point: the first
# num_vectors columns are standard-normal samples shifted by -vec_distance,
# the remaining num_vectors columns are shifted by +vec_distance.
num_vectors = 100
vec_distance = 1
traindat = hstack((randn(2, num_vectors) - vec_distance,
                   randn(2, num_vectors) + vec_distance))
# Label 0 for the first cloud, label 1 for the second.
label_traindat = hstack((zeros(num_vectors), ones(num_vectors)))

# Argument sets the example function can be called with.
parameter_list = [[traindat, label_traindat]]
|
||
def evaluation_cross_validation_multiclass_storage(traindat=traindat, label_traindat=label_traindat):
    """Cross-validate a multiclass MKL machine and print one stored ROC graph.

    The same data is appended three times to a CombinedFeatures object and
    paired with a CombinedKernel of three Gaussian kernels. Three runs of
    stratified 5-fold cross-validation are evaluated with a
    CrossValidationMulticlassStorage attached, and the ROC graph stored for
    run 0, fold 0, class 0 is printed.

    traindat: real-valued matrix handed to RealFeatures.
    label_traindat: label vector handed to MulticlassLabels.
    """
    # Imports are local so that merely importing this module does not require
    # the shogun modular interface to be installed.
    from shogun.Evaluation import CrossValidation
    from shogun.Evaluation import CrossValidationMulticlassStorage
    from shogun.Evaluation import MulticlassAccuracy
    from shogun.Evaluation import StratifiedCrossValidationSplitting
    from shogun.Features import MulticlassLabels
    from shogun.Features import RealFeatures, CombinedFeatures
    from shogun.Kernel import GaussianKernel, CombinedKernel
    from shogun.Classifier import MKLMulticlass
    from shogun.Mathematics import MSG_DEBUG

    # training data, combined features all on same data
    features = RealFeatures(traindat)
    comb_features = CombinedFeatures()
    comb_features.append_feature_obj(features)
    comb_features.append_feature_obj(features)
    comb_features.append_feature_obj(features)
    labels = MulticlassLabels(label_traindat)

    # kernel, different Gaussians combined (they differ only in their second
    # constructor argument: 0.1, 1 and 2)
    kernel = CombinedKernel()
    kernel.append_kernel(GaussianKernel(10, 0.1))
    kernel.append_kernel(GaussianKernel(10, 1))
    kernel.append_kernel(GaussianKernel(10, 2))

    # multiclass MKL machine working on the combined kernel
    svm = MKLMulticlass(1.0, kernel, labels)
    # NOTE(review): the kernel was already passed to the constructor; this
    # call looks redundant but is kept to leave the example's behaviour as is
    svm.set_kernel(kernel)

    # splitting strategy for 5-fold cross-validation; the stratified variant
    # is used because this is a classification problem
    splitting_strategy = StratifiedCrossValidationSplitting(labels, 5)

    # evaluation method
    evaluation_criterium = MulticlassAccuracy()

    # cross-validation instance
    cross_validation = CrossValidation(svm, comb_features, labels,
                                       splitting_strategy, evaluation_criterium)
    cross_validation.set_autolock(False)

    # append cross-validation output class that stores per-fold ROC graphs
    multiclass_storage = CrossValidationMulticlassStorage()
    cross_validation.add_cross_validation_output(multiclass_storage)
    cross_validation.set_num_runs(3)

    # perform cross-validation with debug logging enabled; the returned result
    # is not needed here because the ROC graphs are read from the storage
    cross_validation.io.set_loglevel(MSG_DEBUG)
    cross_validation.evaluate()

    # ROC graph of run 0, fold 0, class 0; print() with a single argument
    # behaves the same on Python 2 and Python 3
    roc_0_0_0 = multiclass_storage.get_fold_ROC(0, 0, 0)
    print(roc_0_0_0)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run the example on the first predefined argument set.
    print('Evaluation CrossValidationMulticlassStorage')
    example_args = parameter_list[0]
    evaluation_cross_validation_multiclass_storage(*example_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
src/shogun/evaluation/CrossValidationMulticlassStorage.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann | ||
*/ | ||
|
||
#include <shogun/evaluation/CrossValidationMulticlassStorage.h> | ||
#include <shogun/evaluation/ROCEvaluation.h> | ||
|
||
using namespace shogun; | ||
|
||
void CCrossValidationMulticlassStorage::post_init() | ||
{ | ||
SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes); | ||
m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes); | ||
for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) | ||
new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>(); | ||
} | ||
|
||
/** Reads the number of classes from the exposed labels.
 *
 * The cast is a plain C-style cast, so the ASSERT only rejects a NULL
 * pointer; the dynamic type of the labels is not verified.
 *
 * @param labels labels to expose, treated as CMulticlassLabels
 */
void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
{
	ASSERT((CMulticlassLabels*)labels);

	CMulticlassLabels* multiclass_labels = (CMulticlassLabels*)labels;
	m_num_classes = multiclass_labels->get_num_classes();
}
|
||
void CCrossValidationMulticlassStorage::post_update_results() | ||
{ | ||
CROCEvaluation eval; | ||
for (int32_t c=0; c<m_num_classes; c++) | ||
{ | ||
SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c); | ||
CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c); | ||
CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c); | ||
pred_labels_binary->get_labels().display_vector(); | ||
eval.evaluate(pred_labels_binary, true_labels_binary); | ||
m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = | ||
eval.get_ROC(); | ||
SG_UNREF(pred_labels_binary); | ||
SG_UNREF(true_labels_binary); | ||
} | ||
} | ||
|
||
/** Remembers the predicted labels of the current test/validation run.
 *
 * Only the pointer is stored; no reference is taken here. The prefix
 * argument is not used by this implementation.
 *
 * @param results predicted labels, treated as CMulticlassLabels
 * @param prefix prefix for output (unused)
 */
void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix)
{
	m_pred_labels = static_cast<CMulticlassLabels*>(results);
}
|
||
/** Remembers the ground truth labels of the current test/validation run.
 *
 * Only the pointer is stored; no reference is taken here. The prefix
 * argument is not used by this implementation.
 *
 * @param results ground truth labels, treated as CMulticlassLabels
 * @param prefix prefix for output (unused)
 */
void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix)
{
	m_true_labels = static_cast<CMulticlassLabels*>(results);
}
|
114 changes: 114 additions & 0 deletions
114
src/shogun/evaluation/CrossValidationMulticlassStorage.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn | ||
* | ||
*/ | ||
|
||
#ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_ | ||
#define CROSSVALIDATIONMULTICLASSSTORAGE_H_ | ||
|
||
#include <shogun/evaluation/CrossValidationOutput.h> | ||
#include <shogun/labels/MulticlassLabels.h> | ||
#include <shogun/lib/SGMatrix.h> | ||
|
||
namespace shogun | ||
{ | ||
|
||
class CMachine; | ||
class CLabels; | ||
class CEvaluation; | ||
|
||
/** @brief Class for storing multiclass evaluation information in every fold of cross-validation */ | ||
class CCrossValidationMulticlassStorage: public CCrossValidationOutput | ||
{ | ||
public: | ||
|
||
/** constructor */ | ||
CCrossValidationMulticlassStorage() : CCrossValidationOutput() | ||
{ | ||
m_pred_labels = NULL; | ||
m_true_labels = NULL; | ||
m_num_classes = 0; | ||
} | ||
|
||
/** destructor */ | ||
virtual ~CCrossValidationMulticlassStorage() | ||
{ | ||
for (int32_t i=0; i<m_num_folds*m_num_runs; i++) | ||
{ | ||
m_fold_ROC_graphs[i].~SGMatrix<float64_t>(); | ||
} | ||
SG_FREE(m_fold_ROC_graphs); | ||
}; | ||
|
||
/** returns ROC | ||
* | ||
* @param run run | ||
* @param fold fold | ||
* @param c class | ||
* @return ROC of 'run' run, 'fold' fold and 'c' class | ||
*/ | ||
SGMatrix<float64_t> get_fold_ROC(int32_t run, int32_t fold, int32_t c) | ||
{ | ||
ASSERT(0<=run); | ||
ASSERT(run<m_num_runs); | ||
ASSERT(0<=fold); | ||
ASSERT(fold<m_num_folds); | ||
ASSERT(0<=c); | ||
ASSERT(c<m_num_classes); | ||
return m_fold_ROC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c]; | ||
} | ||
|
||
/** post init */ | ||
virtual void post_init(); | ||
|
||
/** post update results */ | ||
virtual void post_update_results(); | ||
|
||
/** expose labels | ||
* @param labels labels to expose | ||
*/ | ||
virtual void init_expose_labels(CLabels* labels); | ||
|
||
/** update test result | ||
* | ||
* @param results result labels for test/validation run | ||
* @param prefix prefix for output | ||
*/ | ||
virtual void update_test_result(CLabels* results, | ||
const char* prefix=""); | ||
|
||
/** update test true result | ||
* | ||
* @param results ground truth labels for test/validation run | ||
* @param prefix prefix for output | ||
*/ | ||
virtual void update_test_true_result(CLabels* results, | ||
const char* prefix=""); | ||
|
||
/** @return name of SG_SERIALIZABLE */ | ||
virtual const char* get_name() const { return "CrossValidationMulticlassStorage"; } | ||
|
||
protected: | ||
|
||
/** fold ROC graphs */ | ||
SGMatrix<float64_t>* m_fold_ROC_graphs; | ||
|
||
/** predicted results */ | ||
CMulticlassLabels* m_pred_labels; | ||
|
||
/** true labels */ | ||
CMulticlassLabels* m_true_labels; | ||
|
||
/** number of classes */ | ||
int32_t m_num_classes; | ||
|
||
}; | ||
|
||
} | ||
|
||
#endif /* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters