Skip to content

Commit

Permalink
Merge branch 'multiclass-ecoc' of https://github.com/pluskid/shogun
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed May 22, 2012
2 parents e27e8f2 + 91de8f9 commit c3b256b
Show file tree
Hide file tree
Showing 14 changed files with 68 additions and 50 deletions.
23 changes: 14 additions & 9 deletions src/shogun/labels/BinaryLabels.cpp
Expand Up @@ -35,9 +35,9 @@ CBinaryLabels* CBinaryLabels::obtain_from_generic(CLabels* base_labels)
}


bool CBinaryLabels::is_valid()
{
ASSERT(m_labels.vector);
void CBinaryLabels::ensure_valid(const char* context)
{
CDenseLabels::ensure_valid(context);
bool found_plus_one=false;
bool found_minus_one=false;

Expand All @@ -51,17 +51,22 @@ bool CBinaryLabels::is_valid()
found_minus_one=true;
else
{
SG_ERROR("Not a two class labeling label[%d]=%f (only +1/-1 "
"allowed)\n", i, m_labels[real_i]);
SG_ERROR("%s%sNot a two class labeling label[%d]=%f (only +1/-1 "
"allowed)\n", context?context:"", context?": ":"", i, m_labels[real_i]);
}
}

if (!found_plus_one)
SG_ERROR("Not a two class labeling - no positively labeled examples found\n");
if (!found_minus_one)
SG_ERROR("Not a two class labeling - no negatively labeled examples found\n");
{
SG_ERROR("%s%sNot a two class labeling - no positively labeled examples found\n",
context?context:"", context?": ":"");
}

return true;
if (!found_minus_one)
{
SG_ERROR("%s%sNot a two class labeling - no negatively labeled examples found\n",
context?context:"", context?": ":"");
}
}

ELabelType CBinaryLabels::get_label_type()
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/labels/BinaryLabels.h
Expand Up @@ -60,13 +60,13 @@ class CBinaryLabels : public CDenseLabels
*/
static CBinaryLabels* obtain_from_generic(CLabels* base_labels);

/** is_valid checks if labeling is a two-class labeling
/** Make sure the label is valid, otherwise raise SG_ERROR.
*
* possible with subset
*
* @return if this is two-class labeling
*
* @param context optional message to convey the context
*/
virtual bool is_valid();
virtual void ensure_valid(const char* context=NULL);

/** get label type
*
Expand Down
8 changes: 6 additions & 2 deletions src/shogun/labels/DenseLabels.cpp
Expand Up @@ -73,8 +73,6 @@ void CDenseLabels::set_labels(SGVector<float64_t> v)
SG_ERROR("A subset is set, cannot set labels\n");

m_labels = v;

is_valid();
}

SGVector<float64_t> CDenseLabels::get_labels()
Expand Down Expand Up @@ -121,6 +119,12 @@ void CDenseLabels::set_int_labels(SGVector<int32_t> lab)
set_int_label(i, lab.vector[i]);
}

/** Raise SG_ERROR when no label vector is attached.
 *
 * @param context optional caller description; when non-NULL it is
 *        prepended to the error message as "<context>: ".
 */
void CDenseLabels::ensure_valid(const char* context)
{
	if (!m_labels.vector)
	{
		// Only emit the "context: " prefix when a context was supplied.
		const char* prefix = context ? context : "";
		const char* separator = context ? ": " : "";
		SG_ERROR("%s%sempty content (NULL) for labels\n", prefix, separator);
	}
}

void CDenseLabels::load(CFile* loader)
{
remove_subset();
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/labels/DenseLabels.h
Expand Up @@ -51,13 +51,13 @@ class CDenseLabels : public CLabels
/** destructor */
virtual ~CDenseLabels();

/** check if labeling is valid
/** Make sure the label is valid, otherwise raise SG_ERROR.
*
* possible with subset
*
* @return if labeling is valid (e.g. binary labeling)
*
* @param context optional message to convey the context
*/
virtual bool is_valid()=0;
virtual void ensure_valid(const char* context=NULL);

/** load labels from file
*
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/labels/Labels.h
Expand Up @@ -40,13 +40,13 @@ class CLabels : public CSGObject
/** destructor */
virtual ~CLabels();

/** check if labeling is valid
/** Make sure the label is valid, otherwise raise SG_ERROR.
*
* possible with subset
*
* @return if labeling is valid (e.g. binary labeling)
*
* @param context optional message to convey the context
*/
virtual bool is_valid()=0;
virtual void ensure_valid(const char* context=NULL)=0;

/** get number of labels, depending on whether a subset is set
*
Expand Down
10 changes: 4 additions & 6 deletions src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -31,9 +31,9 @@ CMulticlassLabels* CMulticlassLabels::obtain_from_generic(CLabels* base_labels)
return NULL;
}

bool CMulticlassLabels::is_valid()
void CMulticlassLabels::ensure_valid(const char* context)
{
ASSERT(m_labels.vector);
CDenseLabels::ensure_valid(context);

int32_t subset_size=get_num_labels();
for (int32_t i=0; i<subset_size; i++)
Expand All @@ -43,12 +43,10 @@ bool CMulticlassLabels::is_valid()

if (label<0 || float64_t(label)!=m_labels[real_i])
{
SG_ERROR("Multiclass Labels must be in range 0...<nr_classes-1> and integers!\n");
return false;
SG_ERROR("%s%sMulticlass Labels must be in range 0...<nr_classes-1> and integers!\n",
context?context:"", context?": ":"");
}
}

return true;
}

ELabelType CMulticlassLabels::get_label_type()
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/labels/MulticlassLabels.h
Expand Up @@ -60,13 +60,13 @@ class CMulticlassLabels : public CDenseLabels
*/
static CMulticlassLabels* obtain_from_generic(CLabels* base_labels);

/** is_valid checks if labeling is a multi-class labeling
/** Make sure the label is valid, otherwise raise SG_ERROR.
*
* possible with subset
*
* @return if this is multi-class labeling
*
* @param context optional message to convey the context
*/
virtual bool is_valid();
virtual void ensure_valid(const char* context=NULL);

/** get label type
*
Expand Down
6 changes: 0 additions & 6 deletions src/shogun/labels/RegressionLabels.cpp
Expand Up @@ -30,12 +30,6 @@ CRegressionLabels* CRegressionLabels::obtain_from_generic(CLabels* base_labels)
return NULL;
}

/** Legacy validity check for regression labels (deleted by this commit).
 *
 * No per-value check is performed; the only requirement is that a label
 * vector has been set. The replacement API inherits the equivalent NULL
 * check from CDenseLabels::ensure_valid.
 *
 * @return always true once the ASSERT on the label buffer passes
 */
bool CRegressionLabels::is_valid()
{
ASSERT(m_labels.vector); // fails via ASSERT when no label vector is set
return true;
}

ELabelType CRegressionLabels::get_label_type()
{
return LT_REGRESSION;
Expand Down
8 changes: 0 additions & 8 deletions src/shogun/labels/RegressionLabels.h
Expand Up @@ -61,14 +61,6 @@ class CRegressionLabels : public CDenseLabels
*/
static CRegressionLabels* obtain_from_generic(CLabels* base_labels);

/** is_valid checks if labeling is a multi-class labeling
*
* possible with subset
*
* @return if this is multi-class labeling
*/
virtual bool is_valid();

/** get label type
*
* @return label type real
Expand Down
1 change: 0 additions & 1 deletion src/shogun/machine/KernelMachine.h
Expand Up @@ -265,7 +265,6 @@ class CKernelMachine : public CMachine
/** @return whether machine supports locking */
virtual bool supports_locking() const { return true; }


protected:

SGVector<float64_t> apply_get_outputs(CFeatures* data);
Expand Down
8 changes: 8 additions & 0 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -56,6 +56,10 @@ bool CMachine::train(CFeatures* data)
get_name());
}

if (m_labels == NULL)
SG_ERROR("%s@%p: No labels given", get_name(), this);
m_labels->ensure_valid(get_name());

bool result = train_machine(data);

if (m_store_model_features)
Expand All @@ -66,6 +70,10 @@ bool CMachine::train(CFeatures* data)

void CMachine::set_labels(CLabels* lab)
{
if (lab != NULL)
if (!is_label_valid(lab))
SG_ERROR("Invalid label for %s", get_name());

SG_UNREF(m_labels);
SG_REF(lab);
m_labels = lab;
Expand Down
11 changes: 11 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -291,6 +291,17 @@ class CMachine : public CSGObject
" work though.\n", get_name());
}

/** check whether the labels is valid.
*
* Subclasses can override this to implement their check of label types.
*
* @param lab the labels being checked, guaranteed to be non-NULL
*/
virtual bool is_label_valid(CLabels *lab) const
{
return true;
}

protected:
/** maximum training time */
float64_t m_max_train_time;
Expand Down
2 changes: 0 additions & 2 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -52,8 +52,6 @@ CMulticlassMachine::~CMulticlassMachine()

void CMulticlassMachine::set_labels(CLabels* lab)
{
if (lab)
ASSERT(lab->get_label_type() == LT_MULTICLASS);
CMachine::set_labels(lab);
if (lab)
init_strategy();
Expand Down
9 changes: 9 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -134,6 +134,15 @@ class CMulticlassMachine : public CMachine
return CT_MULTICLASS;
}

/** Accept only multiclass labels for a multiclass machine.
 *
 * @param lab the labels being checked, guaranteed to be non-NULL
 * @return true iff lab reports LT_MULTICLASS as its label type
 */
virtual bool is_label_valid(CLabels *lab) const
{
	const ELabelType type = lab->get_label_type();
	return LT_MULTICLASS == type;
}

protected:
/** init strategy */
void init_strategy();
Expand Down

0 comments on commit c3b256b

Please sign in to comment.