Skip to content

Commit

Permalink
add option to get corresponding bias-thresholds together with PRC / ROC curve
Browse files Browse the repository at this point in the history
  • Loading branch information
Soeren Sonnenburg committed Nov 30, 2011
1 parent 4bce5ac commit afa4852
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/shogun/evaluation/PRCEvaluation.cpp
Expand Up @@ -49,6 +49,7 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
SG_FREE(labels);
SG_FREE(m_PRC_graph);
m_PRC_graph = SG_MALLOC(float64_t, length*2);
m_thresholds = SG_MALLOC(float64_t, length);
m_auPRC = 0.0;

// get total numbers of positive and negative labels
Expand All @@ -72,6 +73,8 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
m_PRC_graph[2*i] = tp/float64_t(i+1);
// recall (y)
m_PRC_graph[2*i+1] = tp/float64_t(pos_count);

m_thresholds[i]= predicted->get_label(idxs[i]);
}

// calc auRPC using area under curve
Expand All @@ -94,6 +97,16 @@ SGMatrix<float64_t> CPRCEvaluation::get_PRC()
return SGMatrix<float64_t>(m_PRC_graph,2,m_PRC_length);
}

SGVector<float64_t> CPRCEvaluation::get_thresholds()
{
	// The thresholds are produced by evaluate(); refuse to hand anything
	// out before it has been run at least once.
	if (!m_computed)
	{
		SG_ERROR("Uninitialized, please call evaluate first");
	}

	ASSERT(m_thresholds);

	// m_PRC_length entries, i.e. the same count as the columns returned by
	// get_PRC(); evaluate() writes one threshold per PRC point.
	// NOTE(review): the vector is built directly on the internal buffer
	// pointer, which evaluate() replaces on every call. Whether SGVector
	// copies or takes ownership depends on its constructor defaults, so
	// confirm that callers neither free it nor keep it across evaluate().
	return SGVector<float64_t>(m_thresholds,m_PRC_length);
}

float64_t CPRCEvaluation::get_auPRC()
{
if (!m_computed)
Expand Down
7 changes: 7 additions & 0 deletions src/shogun/evaluation/PRCEvaluation.h
Expand Up @@ -58,11 +58,18 @@ class CPRCEvaluation: public CBinaryClassEvaluation
*/
SGMatrix<float64_t> get_PRC();

/** get thresholds corresponding to points on the PRC graph
 *
 * evaluate() must have been called first; otherwise an error is raised.
 *
 * @return thresholds (predicted outputs), one per point of the graph
 * returned by get_PRC()
 */
SGVector<float64_t> get_thresholds();
protected:

/** 2-d array used to store PRC graph */
float64_t* m_PRC_graph;

/** vector with thresholds corresponding to points on the PRC graph;
 * allocated by evaluate() on every call.
 * NOTE(review): this commit does not appear to initialise this pointer
 * in the constructor or free it (in the destructor or before
 * reallocation in evaluate()) -- looks like a leak, please verify.
 */
float64_t* m_thresholds;

/** area under PRC graph */
float64_t m_auPRC;

Expand Down
13 changes: 13 additions & 0 deletions src/shogun/evaluation/ROCEvaluation.cpp
Expand Up @@ -64,6 +64,7 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
// initialize graph and auROC
SG_FREE(m_ROC_graph);
m_ROC_graph = SG_MALLOC(float64_t, diff_count*2+2);
m_thresholds = SG_MALLOC(float64_t, length);
m_auROC = 0.0;

// get total numbers of positive and negative labels
Expand Down Expand Up @@ -94,6 +95,8 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
j++;
}

m_thresholds[i]=threshold;

if (ground_truth->get_label(idxs[i]) > 0)
tp+=1.0;
else
Expand Down Expand Up @@ -125,6 +128,16 @@ SGMatrix<float64_t> CROCEvaluation::get_ROC()
return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length);
}

SGVector<float64_t> CROCEvaluation::get_thresholds()
{
	// Nothing to return until evaluate() has filled the buffers.
	if (!m_computed)
	{
		SG_ERROR("Uninitialized, please call evaluate first");
	}

	ASSERT(m_thresholds);

	// Returned length matches the column count of get_ROC() (m_ROC_length).
	// NOTE(review): evaluate() allocates `length` entries (one per sample)
	// and writes m_thresholds[i] per sample rather than per ROC point.
	// The entries may therefore not line up with the ROC points, and if
	// m_ROC_length exceeds `length` this reads past the buffer. Verify
	// evaluate() before relying on this.
	// The vector is built on the internal buffer pointer; ownership
	// depends on SGVector's constructor defaults -- confirm.
	return SGVector<float64_t>(m_thresholds,m_ROC_length);
}

float64_t CROCEvaluation::get_auROC()
{
if (!m_computed)
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/evaluation/ROCEvaluation.h
Expand Up @@ -63,11 +63,19 @@ class CROCEvaluation: public CBinaryClassEvaluation
*/
SGMatrix<float64_t> get_ROC();

/** get thresholds corresponding to points on the ROC graph
 *
 * evaluate() must have been called first; otherwise an error is raised.
 *
 * @return thresholds, with as many entries as get_ROC() has columns
 */
SGVector<float64_t> get_thresholds();

protected:

/** 2-d array used to store ROC graph */
float64_t* m_ROC_graph;

/** vector with thresholds corresponding to points on the ROC graph;
 * allocated by evaluate() on every call.
 * NOTE(review): evaluate() sizes and fills this per sample (index i),
 * not per ROC point. It is also not freed before reallocation and not
 * initialised or freed in the constructor/destructor by this commit --
 * please verify.
 */
float64_t* m_thresholds;

/** area under ROC graph */
float64_t m_auROC;

Expand Down

0 comments on commit afa4852

Please sign in to comment.