Skip to content

Commit

Permalink
Introduced CV MC storage
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 24, 2012
1 parent cbc0150 commit cfc8322
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 1 deletion.
@@ -0,0 +1,84 @@
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Written (W) 2012 Heiko Strathmann
# Copyright (C) 2012 Berlin Institute of Technology and Max-Planck-Society
#

from numpy.random import randn
from numpy import *

# Toy data set: two overlapping 2-D Gaussian clouds of num_vectors points
# each, shifted by -vec_distance and +vec_distance respectively, stored
# column-wise (one column per vector).
num_vectors = 100
vec_distance = 1
traindat = hstack((randn(2, num_vectors) - vec_distance,
                   randn(2, num_vectors) + vec_distance))
# label 0 for the first cloud, label 1 for the second one
label_traindat = hstack((zeros(num_vectors), ones(num_vectors)))

# argument sets the example is run with
parameter_list = [[traindat, label_traindat]]

def evaluation_cross_validation_multiclass_storage(traindat=traindat, label_traindat=label_traindat):
    """Cross-validate a multiclass MKL machine and print one stored ROC graph.

    A CrossValidationMulticlassStorage output is attached to the
    cross-validation so that ROC graphs are recorded per run, fold and
    class; the graph of run 0, fold 0, class 0 is printed at the end.

    traindat       -- feature matrix, one column per vector
    label_traindat -- one class label per column of traindat
    """
    from shogun.Evaluation import CrossValidation, CrossValidationResult
    from shogun.Evaluation import CrossValidationPrintOutput
    from shogun.Evaluation import CrossValidationMKLStorage, CrossValidationMulticlassStorage
    from shogun.Evaluation import MulticlassAccuracy
    from shogun.Evaluation import StratifiedCrossValidationSplitting
    from shogun.Features import MulticlassLabels
    from shogun.Features import RealFeatures, CombinedFeatures
    from shogun.Kernel import GaussianKernel, CombinedKernel
    from shogun.Classifier import MKLMulticlass
    from shogun.Mathematics import Statistics, MSG_DEBUG

    # training data, combined features all on same data
    # (one feature object per sub-kernel appended below)
    features = RealFeatures(traindat)
    comb_features = CombinedFeatures()
    comb_features.append_feature_obj(features)
    comb_features.append_feature_obj(features)
    comb_features.append_feature_obj(features)
    labels = MulticlassLabels(label_traindat)

    # kernel, different Gaussians combined
    kernel = CombinedKernel()
    kernel.append_kernel(GaussianKernel(10, 0.1))
    kernel.append_kernel(GaussianKernel(10, 1))
    kernel.append_kernel(GaussianKernel(10, 2))

    # multiclass MKL machine working on the combined kernel
    svm = MKLMulticlass(1.0, kernel, labels)
    svm.set_kernel(kernel)

    # splitting strategy for 5 fold cross-validation; the stratified
    # variant is used here since this is a classification problem
    splitting_strategy = StratifiedCrossValidationSplitting(labels, 5)

    # evaluation method
    evaluation_criterium = MulticlassAccuracy()

    # cross-validation instance
    cross_validation = CrossValidation(svm, comb_features, labels,
                                       splitting_strategy, evaluation_criterium)
    cross_validation.set_autolock(False)

    # append cross validation output classes
    #cross_validation.add_cross_validation_output(CrossValidationPrintOutput())
    #mkl_storage=CrossValidationMKLStorage()
    #cross_validation.add_cross_validation_output(mkl_storage)
    multiclass_storage = CrossValidationMulticlassStorage()
    cross_validation.add_cross_validation_output(multiclass_storage)
    cross_validation.set_num_runs(3)

    # perform cross-validation
    cross_validation.io.set_loglevel(MSG_DEBUG)
    result = cross_validation.evaluate()

    # ROC graph of run 0, fold 0, class 0. The call form of print works on
    # both Python 2 and Python 3 and matches the print() in the main guard.
    roc_0_0_0 = multiclass_storage.get_fold_ROC(0, 0, 0)
    print(roc_0_0_0)


# Run the example with the first (and only) parameter set when executed
# as a script.
if __name__ == '__main__':
    print('Evaluation CrossValidationMulticlassStorage')
    example_args = parameter_list[0]
    evaluation_cross_validation_multiclass_storage(*example_args)
2 changes: 2 additions & 0 deletions src/interfaces/modular/Evaluation.i
Expand Up @@ -47,6 +47,7 @@
%rename(CrossValidationOutput) CCrossValidationOutput;
%rename(CrossValidationPrintOutput) CCrossValidationPrintOutput;
%rename(CrossValidationMKLStorage) CCrossValidationMKLStorage;
%rename(CrossValidationMulticlassStorage) CCrossValidationMulticlassStorage;
%rename(StructuredAccuracy) CStructuredAccuracy;

/* Include Class Headers to make them visible from within the target language */
Expand Down Expand Up @@ -76,4 +77,5 @@
%include <shogun/evaluation/CrossValidationOutput.h>
%include <shogun/evaluation/CrossValidationPrintOutput.h>
%include <shogun/evaluation/CrossValidationMKLStorage.h>
%include <shogun/evaluation/CrossValidationMulticlassStorage.h>
%include <shogun/evaluation/StructuredAccuracy.h>
1 change: 1 addition & 0 deletions src/interfaces/modular/Evaluation_includes.i
Expand Up @@ -25,6 +25,7 @@
#include <shogun/evaluation/CrossValidationOutput.h>
#include <shogun/evaluation/CrossValidationPrintOutput.h>
#include <shogun/evaluation/CrossValidationMKLStorage.h>
#include <shogun/evaluation/CrossValidationMulticlassStorage.h>
#include <shogun/evaluation/StructuredAccuracy.h>
%}

4 changes: 4 additions & 0 deletions src/shogun/evaluation/CrossValidation.cpp
Expand Up @@ -117,6 +117,8 @@ CEvaluationResult* CCrossValidation::evaluate()
{
current->init_num_runs(m_num_runs);
current->init_num_folds(m_splitting_strategy->get_num_subsets());
current->init_expose_labels(m_labels);
current->post_init();
SG_UNREF(current);
current=(CCrossValidationOutput*)
m_xval_outputs->get_next_element();
Expand Down Expand Up @@ -268,6 +270,7 @@ float64_t CCrossValidation::evaluate_one_run()
current->update_test_indices(subset_indices, "\t");
current->update_test_result(result_labels, "\t");
current->update_test_true_result(m_labels, "\t");
current->post_update_results();
current->update_evaluation_result(results[i], "\t");
SG_UNREF(current);
current=(CCrossValidationOutput*)
Expand Down Expand Up @@ -378,6 +381,7 @@ float64_t CCrossValidation::evaluate_one_run()
current->update_test_indices(subset_indices, "\t");
current->update_test_result(result_labels, "\t");
current->update_test_true_result(m_labels, "\t");
current->post_update_results();
current->update_evaluation_result(results[i], "\t");
SG_UNREF(current);
current=(CCrossValidationOutput*)
Expand Down
55 changes: 55 additions & 0 deletions src/shogun/evaluation/CrossValidationMulticlassStorage.cpp
@@ -0,0 +1,55 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
*/

#include <shogun/evaluation/CrossValidationMulticlassStorage.h>
#include <shogun/evaluation/ROCEvaluation.h>

using namespace shogun;

void CCrossValidationMulticlassStorage::post_init()
{
SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes);
m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>();
}

/** Reads the number of classes from the exposed labels and stores it.
 *
 * @param labels labels used in cross-validation, treated as multiclass labels
 */
void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
{
	// NOTE(review): the cast itself performs no type check, so this assert
	// only rejects NULL - confirm callers always pass multiclass labels
	CMulticlassLabels* multiclass_labels = (CMulticlassLabels*)labels;
	ASSERT(multiclass_labels);
	m_num_classes = multiclass_labels->get_num_classes();
}

void CCrossValidationMulticlassStorage::post_update_results()
{
CROCEvaluation eval;
for (int32_t c=0; c<m_num_classes; c++)
{
SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
pred_labels_binary->get_labels().display_vector();
eval.evaluate(pred_labels_binary, true_labels_binary);
m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] =
eval.get_ROC();
SG_UNREF(pred_labels_binary);
SG_UNREF(true_labels_binary);
}
}

/** Remembers the predicted labels of the current fold.
 *
 * Only the pointer is stored; no reference is taken here.
 *
 * @param results result labels for test/validation run
 * @param prefix prefix for output (unused)
 */
void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix)
{
	CMulticlassLabels* predicted = static_cast<CMulticlassLabels*>(results);
	m_pred_labels = predicted;
}

/** Remembers the ground truth labels of the current fold.
 *
 * Only the pointer is stored; no reference is taken here.
 *
 * @param results ground truth labels for test/validation run
 * @param prefix prefix for output (unused)
 */
void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix)
{
	CMulticlassLabels* ground_truth = static_cast<CMulticlassLabels*>(results);
	m_true_labels = ground_truth;
}

114 changes: 114 additions & 0 deletions src/shogun/evaluation/CrossValidationMulticlassStorage.h
@@ -0,0 +1,114 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn
*
*/

#ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_
#define CROSSVALIDATIONMULTICLASSSTORAGE_H_

#include <shogun/evaluation/CrossValidationOutput.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/lib/SGMatrix.h>

namespace shogun
{

class CMachine;
class CLabels;
class CEvaluation;

/** @brief Class for storing multiclass evaluation information in every fold of cross-validation */
class CCrossValidationMulticlassStorage: public CCrossValidationOutput
{
public:

/** constructor */
CCrossValidationMulticlassStorage() : CCrossValidationOutput()
{
m_pred_labels = NULL;
m_true_labels = NULL;
m_num_classes = 0;
}

/** destructor */
virtual ~CCrossValidationMulticlassStorage()
{
for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
{
m_fold_ROC_graphs[i].~SGMatrix<float64_t>();
}
SG_FREE(m_fold_ROC_graphs);
};

/** returns ROC
*
* @param run run
* @param fold fold
* @param c class
* @return ROC of 'run' run, 'fold' fold and 'c' class
*/
SGMatrix<float64_t> get_fold_ROC(int32_t run, int32_t fold, int32_t c)
{
ASSERT(0<=run);
ASSERT(run<m_num_runs);
ASSERT(0<=fold);
ASSERT(fold<m_num_folds);
ASSERT(0<=c);
ASSERT(c<m_num_classes);
return m_fold_ROC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c];
}

/** post init */
virtual void post_init();

/** post update results */
virtual void post_update_results();

/** expose labels
* @param labels labels to expose
*/
virtual void init_expose_labels(CLabels* labels);

/** update test result
*
* @param results result labels for test/validation run
* @param prefix prefix for output
*/
virtual void update_test_result(CLabels* results,
const char* prefix="");

/** update test true result
*
* @param results ground truth labels for test/validation run
* @param prefix prefix for output
*/
virtual void update_test_true_result(CLabels* results,
const char* prefix="");

/** @return name of SG_SERIALIZABLE */
virtual const char* get_name() const { return "CrossValidationMulticlassStorage"; }

protected:

/** fold ROC graphs */
SGMatrix<float64_t>* m_fold_ROC_graphs;

/** predicted results */
CMulticlassLabels* m_pred_labels;

/** true labels */
CMulticlassLabels* m_true_labels;

/** number of classes */
int32_t m_num_classes;

};

}

#endif /* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */
12 changes: 12 additions & 0 deletions src/shogun/evaluation/CrossValidationOutput.h
Expand Up @@ -75,6 +75,14 @@ class CCrossValidationOutput: public CSGObject
m_num_folds=num_folds;
}

/** initially expose labels before usage. This default implementation
 * does nothing; subclasses may override it to inspect the labels.
 * @param labels labels to expose to CV output
 */
virtual void init_expose_labels(CLabels* labels) { }

/** post init action (called once). This default implementation does
 * nothing; subclasses may override it to finish their set-up. */
virtual void post_init() { }

/** update run index (called every iteration). saves to local variable
*
* @param run_index index of current run
Expand Down Expand Up @@ -137,6 +145,10 @@ class CCrossValidationOutput: public CSGObject
virtual void update_test_true_result(CLabels* results,
const char* prefix="") {}

/** post update test and true results. This default implementation does
 * nothing; subclasses may override it to process both label sets.
 */
virtual void post_update_results() {}

/** update evaluate result
*
* @param result evaluation result
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/evaluation/CrossValidationPrintOutput.cpp
Expand Up @@ -81,8 +81,8 @@ void CCrossValidationPrintOutput::update_trained_machine(
CMulticlassMachine* mc_machine=(CMulticlassMachine*)machine;
for (int i=0; i<mc_machine->get_num_machines(); i++)
{
SG_PRINT("%smulti-class machine %d:\n", i);
CMachine* sub_machine=mc_machine->get_machine(i);
//SG_PRINT("%smulti-class machine %d:\n", i, sub_machine);
this->update_trained_machine(sub_machine, new_prefix);
SG_UNREF(sub_machine);
}
Expand Down

0 comments on commit cfc8322

Please sign in to comment.