Skip to content

Commit

Permalink
Fixed crashers in PRC/ROC
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed May 8, 2012
1 parent 585ebed commit 276fa54
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 44 deletions.
21 changes: 6 additions & 15 deletions src/shogun/evaluation/PRCEvaluation.cpp
Expand Up @@ -15,7 +15,6 @@ using namespace shogun;

// Destructor. NOTE(review): this commit view shows old and new lines without
// +/- markers. The hunk header (-15,7 +15,6) says exactly one line of this hunk
// was removed — most likely the SG_FREE below, since PRCEvaluation.h changes
// m_PRC_graph from float64_t* to SGMatrix<float64_t> in this commit. If so, the
// destructor body is empty after the commit.
CPRCEvaluation::~CPRCEvaluation()
{
// Pre-commit line (removed): manual release of the old raw buffer.
SG_FREE(m_PRC_graph);
}

// Computes the PRC graph and auPRC; only the changed hunks of the body are shown.
float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
Expand Down Expand Up @@ -46,9 +45,8 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)

// clean and initialize graph and auPRC
SG_FREE(labels);
// Pre-commit lines (removed): free and re-allocate the raw float64_t* members.
SG_FREE(m_PRC_graph);
m_PRC_graph = SG_MALLOC(float64_t, length*2);
m_thresholds = SG_MALLOC(float64_t, length);
// Post-commit lines (added): assign freshly constructed SGMatrix/SGVector
// objects instead. NOTE(review): the removed get_PRC() exposed this data as a
// 2 x m_PRC_length matrix; if SGMatrix's constructor takes (rows, cols),
// (length,2) yields the transposed shape — confirm whether (2,length) was
// intended.
m_PRC_graph = SGMatrix<float64_t>(length,2);
m_thresholds = SGVector<float64_t>(length);
m_auPRC = 0.0;

// get total numbers of positive and negative labels
Expand Down Expand Up @@ -77,10 +75,9 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
}

// calc auPRC using area under curve
// Pre-commit line (removed), then its replacement (added), which passes the
// SGMatrix's underlying .matrix buffer instead of the old raw pointer.
m_auPRC = CMath::area_under_curve(m_PRC_graph,length,true);
m_auPRC = CMath::area_under_curve(m_PRC_graph.matrix,length,true);

// Pre-commit lines (removed): m_PRC_length bookkeeping; the member itself is
// dropped from PRCEvaluation.h in this commit.
// set PRC length and computed indicator
m_PRC_length = length;
// Post-commit comment (added).
// set computed indicator
m_computed = true;

return m_auPRC;
Expand All @@ -91,19 +88,15 @@ SGMatrix<float64_t> CPRCEvaluation::get_PRC()
// Returns the PRC graph stored by evaluate(); raises SG_ERROR if evaluate()
// has not completed yet.
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

// Pre-commit lines (removed): assert on the raw pointer and wrap it as a
// 2 x m_PRC_length matrix.
ASSERT(m_PRC_graph);

return SGMatrix<float64_t>(m_PRC_graph,2,m_PRC_length);
// Post-commit line (added): return the stored SGMatrix directly.
// NOTE(review): evaluate() now builds it as SGMatrix<float64_t>(length,2), so
// the shape callers receive may be transposed relative to the removed
// 2 x m_PRC_length return — confirm.
return m_PRC_graph;
}

// Returns the thresholds stored by evaluate(); raises SG_ERROR if evaluate()
// has not completed yet.
SGVector<float64_t> CPRCEvaluation::get_thresholds()
{
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

// Pre-commit lines (removed): assert on the raw pointer and wrap it with
// m_PRC_length entries.
ASSERT(m_thresholds);

return SGVector<float64_t>(m_thresholds,m_PRC_length);
// Post-commit line (added): return the stored SGVector, which evaluate()
// allocates with `length` entries.
return m_thresholds;
}

float64_t CPRCEvaluation::get_auPRC()
Expand All @@ -113,5 +106,3 @@ float64_t CPRCEvaluation::get_auPRC()

return m_auPRC;
}


15 changes: 8 additions & 7 deletions src/shogun/evaluation/PRCEvaluation.h
Expand Up @@ -27,8 +27,12 @@ class CPRCEvaluation: public CBinaryClassEvaluation
public:
/** constructor */
CPRCEvaluation() :
// Pre-commit initializer list (removed): raw pointer set to NULL, m_auPRC to
// 0.0 and m_PRC_length to 0.
CBinaryClassEvaluation(), m_PRC_graph(NULL),
m_auPRC(0.0), m_PRC_length(0), m_computed(false) {};
// Post-commit initializer list and body (added): graph and thresholds start
// as default-constructed SGMatrix/SGVector objects and m_auPRC as 0.0.
CBinaryClassEvaluation(), m_computed(false)
{
m_PRC_graph = SGMatrix<float64_t>();
m_thresholds = SGVector<float64_t>();
m_auPRC = 0.0;
};

/** destructor */
virtual ~CPRCEvaluation();
Expand Down Expand Up @@ -65,17 +69,14 @@ class CPRCEvaluation: public CBinaryClassEvaluation
protected:

/** 2-d array used to store PRC graph */
// Pre-commit type (removed):
float64_t* m_PRC_graph;
// Post-commit type (added):
SGMatrix<float64_t> m_PRC_graph;

/** vector with thresholds corresponding to points on the PRC graph */
// Pre-commit type (removed):
float64_t* m_thresholds;
// Post-commit type (added):
SGVector<float64_t> m_thresholds;

/** area under PRC graph */
float64_t m_auPRC;

// Removed by this commit: the explicit point count previously passed to the
// getters when wrapping the raw buffers.
/** number of points in PRC graph */
int32_t m_PRC_length;

/** indicator of PRC and auPRC being computed already */
bool m_computed;
};
Expand Down
21 changes: 6 additions & 15 deletions src/shogun/evaluation/ROCEvaluation.cpp
Expand Up @@ -15,7 +15,6 @@ using namespace shogun;

// Destructor. NOTE(review): this commit view shows old and new lines without
// +/- markers. The hunk header (-15,7 +15,6) says exactly one line of this hunk
// was removed — most likely the SG_FREE below, since ROCEvaluation.h changes
// m_ROC_graph from float64_t* to SGMatrix<float64_t> in this commit. If so, the
// destructor body is empty after the commit.
CROCEvaluation::~CROCEvaluation()
{
// Pre-commit line (removed): manual release of the old raw buffer.
SG_FREE(m_ROC_graph);
}

// Computes the ROC graph and auROC; only the changed hunks of the body are shown.
float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
Expand Down Expand Up @@ -58,12 +57,11 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
diff_count++;
}

// Pre-commit line (removed), then its replacement (added): labels is now
// released with SG_FREE instead of delete[]. NOTE(review): presumably labels
// is obtained from SG_MALLOC (its allocation is not shown here), in which case
// pairing it with delete[] was undefined behaviour.
delete [] labels;
SG_FREE(labels);

// initialize graph and auROC
// Pre-commit lines (removed): raw buffers with 2*(diff_count+1) graph values
// and `length` thresholds.
SG_FREE(m_ROC_graph);
m_ROC_graph = SG_MALLOC(float64_t, diff_count*2+2);
m_thresholds = SG_MALLOC(float64_t, length);
// Post-commit lines (added). NOTE(review): the stores below use flat indices
// 2*diff_count and 2*diff_count+1 (two values per point), and the removed
// get_ROC() exposed the data as 2 x m_ROC_length. If SGMatrix's constructor
// takes (rows, cols), (diff_count+1,2) gives the transposed shape — confirm
// whether (2,diff_count+1) was intended.
m_ROC_graph = SGMatrix<float64_t>(diff_count+1,2);
m_thresholds = SGVector<float64_t>(length);
m_auROC = 0.0;

// get total numbers of positive and negative labels
Expand Down Expand Up @@ -106,11 +104,8 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
// Last graph point (index diff_count) is set to 1.0, 1.0.
m_ROC_graph[2*diff_count] = 1.0;
m_ROC_graph[2*diff_count+1] = 1.0;

// Pre-commit lines (removed): m_ROC_length bookkeeping; the member itself is
// dropped from ROCEvaluation.h in this commit.
// set ROC length
m_ROC_length = diff_count+1;

// calc auROC using area under curve
// Pre-commit line (removed), then its replacement (added), which passes the
// SGMatrix's underlying .matrix buffer and diff_count+1 points directly.
m_auROC = CMath::area_under_curve(m_ROC_graph,m_ROC_length,false);
m_auROC = CMath::area_under_curve(m_ROC_graph.matrix,diff_count+1,false);

m_computed = true;

Expand All @@ -122,19 +117,15 @@ SGMatrix<float64_t> CROCEvaluation::get_ROC()
// Returns the ROC graph stored by evaluate(); raises SG_ERROR if evaluate()
// has not completed yet.
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

// Pre-commit lines (removed): assert on the raw pointer and wrap it as a
// 2 x m_ROC_length matrix.
ASSERT(m_ROC_graph);

return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length);
// Post-commit line (added): return the stored SGMatrix directly.
// NOTE(review): evaluate() now builds it as SGMatrix<float64_t>(diff_count+1,2)
// while filling it with flat indices 2*i / 2*i+1, so the shape callers receive
// may be transposed relative to the removed 2 x m_ROC_length return — confirm.
return m_ROC_graph;
}

// Returns the thresholds stored by evaluate(); raises SG_ERROR if evaluate()
// has not completed yet.
SGVector<float64_t> CROCEvaluation::get_thresholds()
{
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

// Pre-commit lines (removed): assert on the raw pointer and wrap it with
// m_ROC_length (= diff_count+1) entries.
ASSERT(m_thresholds);

return SGVector<float64_t>(m_thresholds,m_ROC_length);
// Post-commit line (added): return the stored SGVector, which evaluate()
// allocates with `length` entries rather than diff_count+1. NOTE(review): the
// trailing entries may never be written by evaluate() (the filling code is not
// shown here) — confirm callers expect `length` values.
return m_thresholds;
}

float64_t CROCEvaluation::get_auROC()
Expand Down
14 changes: 7 additions & 7 deletions src/shogun/evaluation/ROCEvaluation.h
Expand Up @@ -32,8 +32,11 @@ class CROCEvaluation: public CBinaryClassEvaluation
public:
/** constructor */
CROCEvaluation() :
// Pre-commit initializer list (removed): raw pointer set to NULL, m_auROC to
// 0.0 and m_ROC_length to 0.
CBinaryClassEvaluation(), m_ROC_graph(NULL),
m_auROC(0.0), m_ROC_length(0), m_computed(false) {};
// Post-commit initializer list and body (added): graph and thresholds start
// as default-constructed SGMatrix/SGVector objects. NOTE(review): unlike the
// removed version, and unlike CPRCEvaluation's new constructor (which sets
// m_auPRC = 0.0), m_auROC is no longer initialized here; consider adding
// m_auROC = 0.0 so it has a defined value before evaluate() runs.
CBinaryClassEvaluation(), m_computed(false)
{
m_ROC_graph = SGMatrix<float64_t>();
m_thresholds = SGVector<float64_t>();
};

/** destructor */
virtual ~CROCEvaluation();
Expand Down Expand Up @@ -71,17 +74,14 @@ class CROCEvaluation: public CBinaryClassEvaluation
protected:

/** 2-d array used to store ROC graph */
// Pre-commit type (removed):
float64_t* m_ROC_graph;
// Post-commit type (added):
SGMatrix<float64_t> m_ROC_graph;

/** vector with thresholds corresponding to points on the ROC graph */
// Pre-commit type (removed):
float64_t* m_thresholds;
// Post-commit type (added):
SGVector<float64_t> m_thresholds;

/** area under ROC graph */
float64_t m_auROC;

// Removed by this commit: the explicit point count previously passed to the
// getters when wrapping the raw buffers.
/** number of points in ROC graph */
int32_t m_ROC_length;

/** indicator of ROC and auROC being computed already */
bool m_computed;
};
Expand Down

0 comments on commit 276fa54

Please sign in to comment.