Skip to content

Commit

Permalink
Fixed multiclass ROC
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed Aug 17, 2012
1 parent 6684bae commit 80685b0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
Expand Up @@ -3,23 +3,32 @@
lm=LoadMatrix()

random.seed(17)
ground_truth = lm.load_labels('../data/label_train_multiclass.dat')
import classifier_multiclass_shared
[traindat, label_traindat, testdat, label_testdat] = classifier_multiclass_shared.prepare_data(False)

parameter_list = [[ground_truth]]
parameter_list = [[traindat, label_traindat, testdat, label_testdat]]

def evaluation_multiclassovrevaluation_modular(ground_truth):
def evaluation_multiclassovrevaluation_modular(traindat, label_traindat, testdat, label_testdat):
from shogun.Features import MulticlassLabels
from shogun.Evaluation import MulticlassAccuracy,ROCEvaluation
from shogun.Evaluation import MulticlassOVREvaluation,ROCEvaluation
from modshogun import MulticlassLibLinear,RealFeatures,ContingencyTableEvaluation,ACCURACY

ground_truth_labels = MulticlassLabels(ground_truth)
predicted_labels = MulticlassLabels(ground_truth)
ground_truth_labels = MulticlassLabels(label_traindat)
svm = MulticlassLibLinear(1.0,RealFeatures(traindat),MulticlassLabels(label_traindat))
svm.train()
predicted_labels = svm.apply()

binary_evaluator = ROCEvaluation()
evaluator = MulticlassAccuracy(binary_evaluator)
evaluator = MulticlassOVREvaluation(binary_evaluator)
mean_roc = evaluator.evaluate(predicted_labels,ground_truth_labels)
print mean_roc

binary_evaluator = ContingencyTableEvaluation(ACCURACY)
evaluator = MulticlassOVREvaluation(binary_evaluator)
mean_accuracy = evaluator.evaluate(predicted_labels,ground_truth_labels)
print mean_accuracy

return mean_roc
return mean_roc, mean_accuracy


if __name__=='__main__':
Expand Down
1 change: 1 addition & 0 deletions src/shogun/evaluation/MulticlassOVREvaluation.cpp
Expand Up @@ -47,6 +47,7 @@ float64_t CMulticlassOVREvaluation::evaluate(CLabels* predicted, CLabels* ground
CMulticlassLabels* predicted_mc = (CMulticlassLabels*)predicted;
CMulticlassLabels* ground_truth_mc = (CMulticlassLabels*)ground_truth;
int32_t n_classes = predicted_mc->get_multiclass_confidences(0).size();
ASSERT(n_classes>0);
m_last_results = SGVector<float64_t>(n_classes);

SGMatrix<float64_t> all(n_labels,n_classes);
Expand Down
7 changes: 3 additions & 4 deletions src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -21,14 +21,13 @@ CMulticlassLabels::CMulticlassLabels(int32_t num_labels) : CDenseLabels(num_labe
/** Construct multiclass labels from a vector of label values.
 *
 * The label values are stored through set_labels(). No per-example
 * multiclass confidence storage is allocated here: the confidence array
 * starts empty (NULL pointer, zero count), the same initial state the
 * CFile-based constructor uses.
 *
 * @param src vector of label values
 */
CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src) : CDenseLabels()
{
	set_labels(src);
	// Leave the confidence storage empty. Allocating src.vlen SGVectors
	// here and then resetting the pointer to NULL would leak the array and
	// skip destruction of the placement-new'd vectors.
	m_multiclass_confidences = NULL;
	m_num_multiclass_confidences = 0;
}

/** Construct multiclass labels from a file.
 *
 * The loader is forwarded to the CDenseLabels(CFile*) base constructor;
 * this class only resets its multiclass confidence storage to empty.
 *
 * @param loader file handle passed through to the base class
 */
CMulticlassLabels::CMulticlassLabels(CFile* loader) : CDenseLabels(loader)
{
	// empty confidence storage: zero entries and no backing array
	m_num_multiclass_confidences = 0;
	m_multiclass_confidences = NULL;
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -124,7 +124,7 @@ CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
output_for_i[j] = outputs[j]->get_confidence(i);

result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
result->set_multiclass_confidences(i, output_for_i);
result->set_multiclass_confidences(i, output_for_i.clone());
}

for (int32_t i=0; i < num_machines; ++i)
Expand Down

0 comments on commit 80685b0

Please sign in to comment.