Skip to content

Commit

Permalink
Extended CV MC storage more
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 25, 2012
1 parent d30bf9e commit 56d4a32
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 7 deletions.
Expand Up @@ -24,7 +24,7 @@ def evaluation_cross_validation_multiclass_storage(traindat=traindat, label_trai
from shogun.Evaluation import CrossValidation, CrossValidationResult
from shogun.Evaluation import CrossValidationPrintOutput
from shogun.Evaluation import CrossValidationMKLStorage, CrossValidationMulticlassStorage
from shogun.Evaluation import MulticlassAccuracy
from shogun.Evaluation import MulticlassAccuracy, F1Measure
from shogun.Evaluation import StratifiedCrossValidationSplitting
from shogun.Features import MulticlassLabels
from shogun.Features import RealFeatures, CombinedFeatures
Expand Down Expand Up @@ -68,6 +68,7 @@ def evaluation_cross_validation_multiclass_storage(traindat=traindat, label_trai
#mkl_storage=CrossValidationMKLStorage()
#cross_validation.add_cross_validation_output(mkl_storage)
multiclass_storage=CrossValidationMulticlassStorage()
multiclass_storage.append_binary_evaluation(F1Measure())
cross_validation.add_cross_validation_output(multiclass_storage)
cross_validation.set_num_runs(3)

Expand All @@ -77,6 +78,8 @@ def evaluation_cross_validation_multiclass_storage(traindat=traindat, label_trai

roc_0_0_0 = multiclass_storage.get_fold_ROC(0,0,0)
print roc_0_0_0
auc_0_0_0 = multiclass_storage.get_fold_evaluation_result(0,0,0,0)
print auc_0_0_0


if __name__=='__main__':
Expand Down
28 changes: 25 additions & 3 deletions src/shogun/evaluation/CrossValidationMulticlassStorage.cpp
Expand Up @@ -9,6 +9,7 @@

#include <shogun/evaluation/CrossValidationMulticlassStorage.h>
#include <shogun/evaluation/ROCEvaluation.h>
#include <shogun/evaluation/PRCEvaluation.h>

using namespace shogun;

Expand All @@ -18,6 +19,13 @@ void CCrossValidationMulticlassStorage::post_init()
m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>();

SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes);
m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>();

m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements());
}

void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
Expand All @@ -28,15 +36,29 @@ void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)

void CCrossValidationMulticlassStorage::post_update_results()
{
CROCEvaluation eval;
CROCEvaluation eval_ROC;
CPRCEvaluation eval_PRC;
int32_t n_evals = m_binary_evaluations->get_num_elements();
for (int32_t c=0; c<m_num_classes; c++)
{
SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
eval.evaluate(pred_labels_binary, true_labels_binary);
eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] =
eval.get_ROC();
eval_ROC.get_ROC();
eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] =
eval_PRC.get_PRC();

for (int32_t i=0; i<n_evals; i++)
{
CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i);
m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
evaluator->evaluate(pred_labels_binary, true_labels_binary);
SG_UNREF(evaluator);
}

SG_UNREF(pred_labels_binary);
SG_UNREF(true_labels_binary);
}
Expand Down
84 changes: 81 additions & 3 deletions src/shogun/evaluation/CrossValidationMulticlassStorage.h
Expand Up @@ -12,8 +12,10 @@
#define CROSSVALIDATIONMULTICLASSSTORAGE_H_

#include <shogun/evaluation/CrossValidationOutput.h>
#include <shogun/evaluation/BinaryClassEvaluation.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/DynamicObjectArray.h>

namespace shogun
{
Expand All @@ -22,7 +24,10 @@ class CMachine;
class CLabels;
class CEvaluation;

/** @brief Class for storing multiclass evaluation information in every fold of cross-validation */
/** @brief Class for storing multiclass evaluation information in every fold of cross-validation.
*
* Be careful - can be very expensive memory-wise.
*/
class CCrossValidationMulticlassStorage: public CCrossValidationOutput
{
public:
Expand All @@ -33,19 +38,26 @@ class CCrossValidationMulticlassStorage: public CCrossValidationOutput
m_pred_labels = NULL;
m_true_labels = NULL;
m_num_classes = 0;
m_binary_evaluations = new CDynamicObjectArray();
}

/** destructor
*
* Destroys every stored ROC and PRC matrix exactly once, releases the two
* raw buffers that hold them and drops the reference on the array of
* custom binary evaluators.
*/
virtual ~CCrossValidationMulticlassStorage()
{
	// post_init() allocates one matrix per (run, fold, class) triple in each
	// buffer, so the teardown has to cover the very same element count
	const int32_t num_graphs = m_num_folds*m_num_runs*m_num_classes;
	for (int32_t i=0; i<num_graphs; i++)
	{
		// the matrices were constructed in place inside SG_MALLOC'ed memory,
		// hence the explicit destructor calls before SG_FREE
		m_fold_ROC_graphs[i].~SGMatrix<float64_t>();
		m_fold_PRC_graphs[i].~SGMatrix<float64_t>();
	}
	SG_FREE(m_fold_ROC_graphs);
	SG_FREE(m_fold_PRC_graphs);
	SG_UNREF(m_binary_evaluations);
}

/** returns ROC
/** returns ROC of 1-v-R in given fold and run
*
* @param run run
* @param fold fold
Expand All @@ -62,6 +74,63 @@ class CCrossValidationMulticlassStorage: public CCrossValidationOutput
ASSERT(c<m_num_classes);
return m_fold_ROC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c];
}

/** returns PRC of 1-v-R in given fold and run
*
* Every index is range-checked with ASSERT before the lookup.
*
* @param run run index, must satisfy 0 <= run < m_num_runs
* @param fold fold index, must satisfy 0 <= fold < m_num_folds
* @param c class index, must satisfy 0 <= c < m_num_classes
* @return PRC of 'run' run, 'fold' fold and 'c' class
*/
SGMatrix<float64_t> get_fold_PRC(int32_t run, int32_t fold, int32_t c)
{
ASSERT(0<=run);
ASSERT(run<m_num_runs);
ASSERT(0<=fold);
ASSERT(fold<m_num_folds);
ASSERT(0<=c);
ASSERT(c<m_num_classes);
// flat buffer layout: run is the slowest index, then fold, then class
return m_fold_PRC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c];
}

/** appends a binary evaluation instance
*
* The evaluator is pushed to the end of m_binary_evaluations; its position
* in that array is the evaluation number expected by
* get_fold_evaluation_result().
*
* NOTE(review): post_init() sizes the result vector from the number of
* evaluators present at that moment, so evaluators presumably have to be
* appended before cross-validation starts - confirm against the caller.
*
* @param evaluation binary evaluation to add
*/
void append_binary_evaluation(CBinaryClassEvaluation* evaluation)
{
m_binary_evaluations->push_back(evaluation);
}

/** returns binary evaluation appended before
*
* NOTE(review): post_update_results() calls SG_UNREF on the object it gets
* from get_element_safe(), so the pointer returned here presumably carries
* a reference the caller has to release as well - confirm.
*
* @param idx position of the evaluation in the order of appending
* @return evaluation stored at position 'idx', cast to CBinaryClassEvaluation
*/
CBinaryClassEvaluation* get_binary_evaluation(int32_t idx)
{
return (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(idx);
}

/** returns evaluation result of 1-v-R in given fold and run
*
* Every index is range-checked with ASSERT before the lookup.
*
* @param run run index, must satisfy 0 <= run < m_num_runs
* @param fold fold index, must satisfy 0 <= fold < m_num_folds
* @param c class index, must satisfy 0 <= c < m_num_classes
* @param e evaluation number, must be non-negative and smaller than the
* number of appended binary evaluations
* @return value stored for evaluation 'e' of 'run' run, 'fold' fold and
* 'c' class
*/
float64_t get_fold_evaluation_result(int32_t run, int32_t fold, int32_t c, int32_t e)
{
ASSERT(0<=run);
ASSERT(run<m_num_runs);
ASSERT(0<=fold);
ASSERT(fold<m_num_folds);
ASSERT(0<=c);
ASSERT(c<m_num_classes);
ASSERT(0<=e);
// upper bound for 'e' is the current number of appended evaluators
int32_t n_evals = m_binary_evaluations->get_num_elements();
ASSERT(e<n_evals);
// flat vector layout: run is the slowest index, then fold, class and
// finally the evaluation number (same formula as in post_update_results)
return m_evaluations_results[run*m_num_folds*m_num_classes*n_evals+fold*m_num_classes*n_evals+c*n_evals+e];
}

/** post init */
virtual void post_init();
Expand Down Expand Up @@ -95,8 +164,17 @@ class CCrossValidationMulticlassStorage: public CCrossValidationOutput

protected:

/** custom binary evaluators */
CDynamicObjectArray* m_binary_evaluations;

/** fold evaluation results */
SGVector<float64_t> m_evaluations_results;

/** fold ROC graphs */
SGMatrix<float64_t>* m_fold_ROC_graphs;

/** fold PRC graphs */
SGMatrix<float64_t>* m_fold_PRC_graphs;

/** predicted results */
CMulticlassLabels* m_pred_labels;
Expand Down

0 comments on commit 56d4a32

Please sign in to comment.