Skip to content

Commit

Permalink
An interface for checking label type
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed May 22, 2012
1 parent d204fcd commit 4d35f91
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/shogun/machine/KernelMachine.h
Expand Up @@ -265,7 +265,6 @@ class CKernelMachine : public CMachine
/** report whether this machine can be locked
 *
 * @return true; this implementation unconditionally reports locking support
 */
virtual bool supports_locking() const
{
	return true;
}


protected:

SGVector<float64_t> apply_get_outputs(CFeatures* data);
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -66,6 +66,10 @@ bool CMachine::train(CFeatures* data)

void CMachine::set_labels(CLabels* lab)
{
if (lab != NULL)
if (!is_label_valid(lab))
SG_ERROR("Invalid label for %s", get_name());

SG_UNREF(m_labels);
SG_REF(lab);
m_labels = lab;
Expand Down
11 changes: 11 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -291,6 +291,17 @@ class CMachine : public CSGObject
" work though.\n", get_name());
}

/** check whether the given labels are valid for this machine.
 *
 * This base implementation accepts any labels. Subclasses can override
 * it to implement their own check of label types.
 *
 * @param lab the labels being checked, guaranteed to be non-NULL
 * @return true if the labels are acceptable, false otherwise
 */
virtual bool is_label_valid(CLabels *lab) const
{
	// the default check accepts everything; reference the parameter so
	// builds with -Wunused-parameter stay quiet while keeping its name
	// visible for the documentation and for overriding subclasses
	(void) lab;
	return true;
}

protected:
/** maximum training time */
float64_t m_max_train_time;
Expand Down
2 changes: 0 additions & 2 deletions src/shogun/machine/MulticlassMachine.cpp
Expand Up @@ -52,8 +52,6 @@ CMulticlassMachine::~CMulticlassMachine()

void CMulticlassMachine::set_labels(CLabels* lab)
{
if (lab)
ASSERT(lab->get_label_type() == LT_MULTICLASS);
CMachine::set_labels(lab);
if (lab)
init_strategy();
Expand Down
9 changes: 9 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -134,6 +134,15 @@ class CMulticlassMachine : public CMachine
return CT_MULTICLASS;
}

/** decide whether the given labels can be used by a multiclass machine.
 *
 * Only labels that report the LT_MULTICLASS label type are accepted.
 *
 * @param lab the labels being checked, guaranteed to be non-NULL
 * @return true when lab's label type is LT_MULTICLASS, false otherwise
 */
virtual bool is_label_valid(CLabels *lab) const
{
	const bool is_multiclass = (LT_MULTICLASS == lab->get_label_type());
	return is_multiclass;
}

protected:
/** init strategy */
void init_strategy();
Expand Down

0 comments on commit 4d35f91

Please sign in to comment.