Skip to content

Commit

Permalink
Added multiclass OVR evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 13, 2012
1 parent df80e5c commit c49b8cf
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 1 deletion.
@@ -0,0 +1,28 @@
from tools.load import LoadMatrix
from numpy import random
# Loader used to read the label file below.
lm = LoadMatrix()

# Fixed seed for numpy's random module.
random.seed(17)
ground_truth = lm.load_labels('../data/label_train_multiclass.dat')

# One inner list of positional arguments per example run.
parameter_list = [[ground_truth]]

def evaluation_multiclassovrevaluation_modular(ground_truth):
    """Run a one-vs-rest multiclass evaluation with ROC as the binary measure.

    Both the predicted and the reference labels are built from the same
    ``ground_truth`` vector, then handed to MulticlassOVREvaluation wrapping
    a ROCEvaluation.

    :param ground_truth: vector of multiclass labels
    :return: value returned by ``MulticlassOVREvaluation.evaluate``
    """
    from shogun.Features import MulticlassLabels
    # MulticlassOVREvaluation (not MulticlassAccuracy) is the evaluator that
    # wraps a binary evaluation and applies it per class.
    from shogun.Evaluation import MulticlassOVREvaluation,ROCEvaluation

    ground_truth_labels = MulticlassLabels(ground_truth)
    # NOTE(review): these labels are built directly from the ground truth and
    # no per-class confidences are set on them here -- confirm this is the
    # intended input for the OvR evaluation.
    predicted_labels = MulticlassLabels(ground_truth)

    binary_evaluator = ROCEvaluation()
    evaluator = MulticlassOVREvaluation(binary_evaluator)
    mean_roc = evaluator.evaluate(predicted_labels,ground_truth_labels)
    # Function-call form works on both Python 2 and 3 and matches __main__.
    print(mean_roc)

    return mean_roc


# Script entry point: announce the example, then run it with the first
# argument set from parameter_list.
if __name__ == '__main__':
    print('MulticlassOVREvaluation')
    evaluation_multiclassovrevaluation_modular(
        *parameter_list[0]
    )

3 changes: 2 additions & 1 deletion src/interfaces/modular/Evaluation.i
Expand Up @@ -42,7 +42,7 @@
%rename(DifferentiableFunction) CDifferentiableFunction;
%rename(GradientCriterion) CGradientCriterion;
%rename(GradientEvaluation) CGradientEvaluation;

%rename(MulticlassOVREvaluation) CMulticlassOVREvaluation;



Expand All @@ -67,5 +67,6 @@
%include <shogun/evaluation/GradientCriterion.h>
%include <shogun/evaluation/GradientEvaluation.h>
%include <shogun/evaluation/GradientResult.h>
%include <shogun/evaluation/MulticlassOVREvaluation.h>
%include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
%include <shogun/evaluation/CrossValidationSplitting.h>
1 change: 1 addition & 0 deletions src/interfaces/modular/Evaluation_includes.i
Expand Up @@ -18,6 +18,7 @@
#include <shogun/evaluation/GradientCriterion.h>
#include <shogun/evaluation/GradientEvaluation.h>
#include <shogun/evaluation/GradientResult.h>
#include <shogun/evaluation/MulticlassOVREvaluation.h>
#include <shogun/evaluation/SplittingStrategy.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/evaluation/CrossValidationSplitting.h>
Expand Down
95 changes: 95 additions & 0 deletions src/shogun/evaluation/MulticlassOVREvaluation.cpp
@@ -0,0 +1,95 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (C) 2012 Sergey Lisitsyn
*/

#include <shogun/evaluation/MulticlassOVREvaluation.h>
#include <shogun/evaluation/ROCEvaluation.h>
#include <shogun/evaluation/PRCEvaluation.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/mathematics/Statistics.h>

using namespace shogun;

/** Default constructor: no binary evaluation attached and no stored graphs. */
CMulticlassOVREvaluation::CMulticlassOVREvaluation() : CEvaluation()
{
	m_binary_evaluation = NULL;
	m_graph_results = NULL;
	m_num_graph_results = 0;
}

/** Constructor taking the binary evaluation to apply per class.
 *
 * A reference on binary_evaluation is taken, as set_binary_evaluation()
 * does; no previous evaluation exists yet, so nothing is released.
 */
CMulticlassOVREvaluation::CMulticlassOVREvaluation(CBinaryClassEvaluation* binary_evaluation) :
	CEvaluation(), m_binary_evaluation(binary_evaluation), m_graph_results(NULL), m_num_graph_results(0)
{
	SG_REF(m_binary_evaluation);
}

/** Destructor: drops the reference on the binary evaluation and destroys
 * every stored graph matrix before freeing the array that holds them.
 */
CMulticlassOVREvaluation::~CMulticlassOVREvaluation()
{
	SG_UNREF(m_binary_evaluation);
	if (m_graph_results!=NULL)
	{
		// the array came from SG_MALLOC with placement-new'd elements,
		// so each element is destroyed by hand before the raw free
		SGMatrix<float64_t>* const end = m_graph_results+m_num_graph_results;
		for (SGMatrix<float64_t>* it=m_graph_results; it<end; ++it)
			it->~SGMatrix<float64_t>();
		SG_FREE(m_graph_results);
	}
}

/** Evaluates multiclass predictions via a one-vs-rest decomposition.
 *
 * For every class c the per-label confidences of class c are wrapped into
 * binary labels, the ground truth is mapped to +1 (label == c) / -1
 * (otherwise), and the configured binary evaluation is run on the pair.
 * The per-class values are kept in m_last_results; when the binary
 * evaluation is a CROCEvaluation or CPRCEvaluation the per-class graph is
 * stored as well.
 *
 * @param predicted labels to be evaluated; treated as CMulticlassLabels
 * @param ground_truth labels assumed to be correct; treated as CMulticlassLabels
 * @return mean of the per-class binary evaluation results
 */
float64_t CMulticlassOVREvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
{
	ASSERT(m_binary_evaluation);
	ASSERT(predicted);
	ASSERT(ground_truth);
	int32_t n_labels = predicted->get_num_labels();
	ASSERT(n_labels);
	// ground truth is indexed with the same positions as the predictions
	ASSERT(ground_truth->get_num_labels()==n_labels);
	CMulticlassLabels* predicted_mc = static_cast<CMulticlassLabels*>(predicted);
	CMulticlassLabels* ground_truth_mc = static_cast<CMulticlassLabels*>(ground_truth);
	// number of classes is taken from the confidence vector of the first label
	int32_t n_classes = predicted_mc->get_multiclass_confidences(0).size();
	m_last_results = SGVector<float64_t>(n_classes);

	// gather confidences into a labels x classes matrix so that each
	// column is the contiguous confidence vector of one class
	SGMatrix<float64_t> all(n_labels,n_classes);
	for (int32_t i=0; i<n_labels; i++)
	{
		SGVector<float64_t> confs = predicted_mc->get_multiclass_confidences(i);
		for (int32_t j=0; j<n_classes; j++)
		{
			all(i,j) = confs[j];
		}
	}

	// resolve the evaluation type once instead of once per class
	CROCEvaluation* roc = dynamic_cast<CROCEvaluation*>(m_binary_evaluation);
	CPRCEvaluation* prc = dynamic_cast<CPRCEvaluation*>(m_binary_evaluation);
	if (roc || prc)
	{
		for (int32_t i=0; i<m_num_graph_results; i++)
			m_graph_results[i].~SGMatrix<float64_t>();
		SG_FREE(m_graph_results);
		m_graph_results = SG_MALLOC(SGMatrix<float64_t>, n_classes);
		// construct every slot right away so that the destructor never
		// runs on raw memory if an evaluation below raises an error
		for (int32_t i=0; i<n_classes; i++)
			new (&m_graph_results[i]) SGMatrix<float64_t>();
		m_num_graph_results = n_classes;
	}

	for (int32_t c=0; c<n_classes; c++)
	{
		// non-owning view on column c of the confidence matrix
		CLabels* pred = new CBinaryLabels(SGVector<float64_t>(all.get_column_vector(c),n_labels,false));
		SG_REF(pred);
		SGVector<float64_t> gt_vec(n_labels);
		for (int32_t i=0; i<n_labels; i++)
		{
			if (ground_truth_mc->get_label(i)==c)
				gt_vec[i] = +1.0;
			else
				gt_vec[i] = -1.0;
		}
		CLabels* gt = new CBinaryLabels(gt_vec);
		SG_REF(gt);
		m_last_results[c] = m_binary_evaluation->evaluate(pred, gt);
		// release the temporary label objects; they were leaked before
		SG_UNREF(pred);
		SG_UNREF(gt);

		if (roc)
			m_graph_results[c] = roc->get_ROC();
		if (prc)
			m_graph_results[c] = prc->get_PRC();
	}
	return CStatistics::mean(m_last_results);
}
104 changes: 104 additions & 0 deletions src/shogun/evaluation/MulticlassOVREvaluation.h
@@ -0,0 +1,104 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (C) 2012 Sergey Lisitsyn
*/

#ifndef MULTICLASSOVREVALUATION_H_
#define MULTICLASSOVREVALUATION_H_

#include <shogun/evaluation/Evaluation.h>
#include <shogun/evaluation/BinaryClassEvaluation.h>
#include <shogun/labels/Labels.h>

namespace shogun
{

class CLabels;

/** @brief The class MulticlassOVREvaluation
 * used to compute evaluation parameters
 * of multiclass classification via
 * binary OvR decomposition and given binary
 * evaluation technique.
 */
class CMulticlassOVREvaluation: public CEvaluation
{
public:
	/** default constructor; no binary evaluation is attached */
	CMulticlassOVREvaluation();

	/** constructor
	 * @param binary_evaluation binary evaluation applied to each class
	 */
	CMulticlassOVREvaluation(CBinaryClassEvaluation* binary_evaluation);

	/** destructor */
	virtual ~CMulticlassOVREvaluation();

	/** set evaluation
	 * @param binary_evaluation binary evaluation applied to each class
	 */
	void set_binary_evaluation(CBinaryClassEvaluation* binary_evaluation)
	{
		// reference the new object before releasing the old one so that
		// re-setting the same object cannot drop its last reference
		SG_REF(binary_evaluation);
		SG_UNREF(m_binary_evaluation);
		m_binary_evaluation = binary_evaluation;
	}

	/** get evaluation
	 * @return binary evaluation in use, with an extra reference taken
	 */
	CBinaryClassEvaluation* get_binary_evaluation()
	{
		SG_REF(m_binary_evaluation);
		return m_binary_evaluation;
	}

	/** evaluate accuracy
	 * @param predicted labels to be evaluated
	 * @param ground_truth labels assumed to be correct
	 * @return mean of OvR binary evaluations
	 */
	virtual float64_t evaluate(CLabels* predicted, CLabels* ground_truth);

	/** returns last results per class
	 * @return vector with one binary evaluation result per class
	 */
	SGVector<float64_t> get_last_results()
	{
		return m_last_results;
	}

	/** returns graph for ith class
	 * @param i class index, must lie in [0, number of stored graphs)
	 * @return stored graph (ROC,PRC) of class i
	 */
	SGMatrix<float64_t> get_graph_for_class(int32_t i)
	{
		ASSERT(m_graph_results);
		ASSERT(i>=0);
		ASSERT(i<m_num_graph_results);
		return m_graph_results[i];
	}

	/** returns evaluation direction
	 * @return evaluation direction of the underlying binary evaluation
	 */
	virtual EEvaluationDirection get_evaluation_direction()
	{
		// a default-constructed object has no binary evaluation yet
		ASSERT(m_binary_evaluation);
		return m_binary_evaluation->get_evaluation_direction();
	}

	/** get name */
	virtual const char* get_name() const { return "MulticlassOVREvaluation"; }

protected:

	/** binary evaluation to be used */
	CBinaryClassEvaluation* m_binary_evaluation;

	/** last per class results */
	SGVector<float64_t> m_last_results;

	/** stores graph (ROC,PRC) results per class */
	SGMatrix<float64_t>* m_graph_results;

	/** number of graph results */
	int32_t m_num_graph_results;

};

}

#endif /* MULTICLASSOVREVALUATION_H_ */
27 changes: 27 additions & 0 deletions src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -6,19 +6,46 @@ using namespace shogun;

/** Default constructor: creates labels without any confidence storage. */
CMulticlassLabels::CMulticlassLabels() : CDenseLabels()
{
	// must be a defined value: the destructor hands this pointer to SG_FREE
	m_multiclass_confidences = NULL;
	m_num_multiclass_confidences = 0;
}

/** Constructor for num_labels labels.
 *
 * Allocates one (initially empty) confidence vector per label.
 */
CMulticlassLabels::CMulticlassLabels(int32_t num_labels) : CDenseLabels(num_labels)
{
	m_num_multiclass_confidences = num_labels;
	m_multiclass_confidences = SG_MALLOC(SGVector<float64_t>, num_labels);
	// SG_MALLOC yields raw memory, so every element is constructed in place
	int32_t idx = 0;
	while (idx<num_labels)
	{
		new (&m_multiclass_confidences[idx]) SGVector<float64_t>();
		++idx;
	}
}

/** Constructor from a label vector.
 *
 * Stores src as the labels and allocates one (initially empty)
 * confidence vector per entry of src.
 */
CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src) : CDenseLabels()
{
	set_labels(src);

	const int32_t count = src.vlen;
	m_num_multiclass_confidences = count;
	m_multiclass_confidences = SG_MALLOC(SGVector<float64_t>, count);
	// SG_MALLOC yields raw memory, so every element is constructed in place
	int32_t idx = 0;
	while (idx<count)
	{
		new (&m_multiclass_confidences[idx]) SGVector<float64_t>();
		++idx;
	}
}

/** Constructor loading labels through the base class from a file.
 *
 * No confidence storage is allocated for labels created this way.
 */
CMulticlassLabels::CMulticlassLabels(CFile* loader) : CDenseLabels(loader)
{
	// must be a defined value: the destructor hands this pointer to SG_FREE
	m_multiclass_confidences = NULL;
	m_num_multiclass_confidences = 0;
}

/** Destructor: destroys each confidence vector, then frees the array. */
CMulticlassLabels::~CMulticlassLabels()
{
	// elements were placement-constructed, so they are destroyed explicitly
	int32_t idx = 0;
	while (idx<m_num_multiclass_confidences)
	{
		m_multiclass_confidences[idx].~SGVector<float64_t>();
		++idx;
	}
	SG_FREE(m_multiclass_confidences);
}

/** Stores the confidence vector of the ith label.
 *
 * @param i index, must lie in [0, number of allocated confidence vectors)
 * @param confidences confidences to be set for ith result
 */
void CMulticlassLabels::set_multiclass_confidences(int32_t i, SGVector<float64_t> confidences)
{
	// labels created by the default or file constructor have no slots at all
	ASSERT(i>=0 && i<m_num_multiclass_confidences);
	m_multiclass_confidences[i] = confidences;
}

/** Returns the confidence vector of the ith label.
 *
 * @param i index, must lie in [0, number of allocated confidence vectors)
 * @return confidences of ith result
 */
SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i)
{
	// labels created by the default or file constructor have no slots at all
	ASSERT(i>=0 && i<m_num_multiclass_confidences);
	return m_multiclass_confidences[i];
}

CMulticlassLabels* CMulticlassLabels::obtain_from_generic(CLabels* base_labels)
Expand Down
25 changes: 25 additions & 0 deletions src/shogun/labels/MulticlassLabels.h
Expand Up @@ -54,6 +54,9 @@ class CMulticlassLabels : public CDenseLabels
*/
CMulticlassLabels(CFile* loader);

/** destructor */
~CMulticlassLabels();

/** helper method used to specialize a base class instance
*
* @param base_labels its dynamic type must be CMulticlassLabels
Expand Down Expand Up @@ -97,8 +100,30 @@ class CMulticlassLabels : public CDenseLabels
*/
int32_t get_num_classes();

/** returns multiclass confidences
*
* @param i index
* @return confidences of ith result
*/
SGVector<float64_t> get_multiclass_confidences(int32_t i);

/** sets multiclass confidences
*
* @param i index
* @param confidences confidences to be set for ith result
*/
void set_multiclass_confidences(int32_t i, SGVector<float64_t> confidences);

/** @return object name */
inline virtual const char* get_name() const { return "MulticlassLabels"; }

protected:

/** multiclass confidences: array holding one SGVector per label,
* indexed by label position; elements are destroyed and the array
* is freed in the destructor
*/
SGVector<float64_t>* m_multiclass_confidences;

/** number of multiclass confidences, i.e. number of entries
* in m_multiclass_confidences
*/
int32_t m_num_multiclass_confidences;
};
}
#endif
1 change: 1 addition & 0 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -124,6 +124,7 @@ CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
output_for_i[j] = outputs[j]->get_confidence(i);

result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
result->set_multiclass_confidences(i, output_for_i);
}

for (int32_t i=0; i < num_machines; ++i)
Expand Down

0 comments on commit c49b8cf

Please sign in to comment.