Skip to content

Commit

Permalink
added label distinctions for apply_locked
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed May 22, 2012
1 parent 4d58627 commit 95f980f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 11 deletions.
18 changes: 16 additions & 2 deletions src/shogun/machine/KernelMachine.cpp
Expand Up @@ -498,7 +498,21 @@ bool CKernelMachine::train_locked(SGVector<index_t> indices)
return result;
}

CLabels* CKernelMachine::apply_locked(SGVector<index_t> indices)
/** Locked prediction for the binary case: wraps the output of
 * apply_locked_get_output() in a newly allocated CBinaryLabels object.
 *
 * @param indices index vector (of locked features) that is predicted
 * @return newly allocated binary labels built from the machine output
 */
CBinaryLabels* CKernelMachine::apply_locked_binary(SGVector<index_t> indices)
{
	/* machine output for the requested indices */
	SGVector<float64_t> raw_scores(apply_locked_get_output(indices));

	CBinaryLabels* predicted = new CBinaryLabels(raw_scores);
	return predicted;
}

/** Locked prediction for the regression case: wraps the output of
 * apply_locked_get_output() in a newly allocated CRegressionLabels object.
 *
 * @param indices index vector (of locked features) that is predicted
 * @return newly allocated regression labels built from the machine output
 */
CRegressionLabels* CKernelMachine::apply_locked_regression(
		SGVector<index_t> indices)
{
	/* machine output for the requested indices */
	SGVector<float64_t> raw_values(apply_locked_get_output(indices));

	CRegressionLabels* predicted = new CRegressionLabels(raw_values);
	return predicted;
}

SGVector<float64_t> CKernelMachine::apply_locked_get_output(
SGVector<index_t> indices)
{
if (!is_data_locked())
SG_ERROR("CKernelMachine::apply_locked() call data_lock() before!\n");
Expand Down Expand Up @@ -586,7 +600,7 @@ CLabels* CKernelMachine::apply_locked(SGVector<index_t> indices)
#endif
SG_DONE();

return new CRegressionLabels(output);
return output;
}

void CKernelMachine::data_lock(CLabels* labs, CFeatures* features)
Expand Down
21 changes: 20 additions & 1 deletion src/shogun/machine/KernelMachine.h
Expand Up @@ -242,12 +242,31 @@ class CKernelMachine : public CMachine
*/
virtual bool train_locked(SGVector<index_t> indices);

/** Applies a locked machine on a set of indices. Error if machine is
* not locked. Binary case
*
* @param indices index vector (of locked features) that is predicted
* @return resulting labels
*/
virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices);

/** Applies a locked machine on a set of indices. Error if machine is
* not locked. Regression case
*
* @param indices index vector (of locked features) that is predicted
* @return resulting labels
*/
virtual CRegressionLabels* apply_locked_regression(
SGVector<index_t> indices);

/** Applies a locked machine on a set of indices. Error if machine is
* not locked
*
* @param indices index vector (of locked features) that is predicted
* @return raw output of machine
*/
virtual CLabels* apply_locked(SGVector<index_t> indices);
virtual SGVector<float64_t> apply_locked_get_output(
SGVector<index_t> indices);

/** Locks the machine on given labels and data. After this call, only
* train_locked and apply_locked may be called.
Expand Down
38 changes: 38 additions & 0 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -173,6 +173,23 @@ CLabels* CMachine::apply(CFeatures* data)
return NULL;
}

/** Applies a locked machine on a set of indices by forwarding to the
 * apply_locked_* variant that matches get_machine_problem_type().
 *
 * @param indices index vector (of locked features) that is predicted
 * @return labels produced by the selected variant; NULL is the fallback
 * return value after the unknown-problem-type error has been reported
 */
CLabels* CMachine::apply_locked(SGVector<index_t> indices)
{
	switch (get_machine_problem_type())
	{
		case PT_BINARY:
			return apply_locked_binary(indices);
		case PT_REGRESSION:
			return apply_locked_regression(indices);
		case PT_MULTICLASS:
			return apply_locked_multiclass(indices);
		default:
			/* newline-terminated, like the other SG_ERROR messages of
			 * this class */
			SG_ERROR("Unknown problem type\n");
			break;
	}

	return NULL;
}

CBinaryLabels* CMachine::apply_binary(CFeatures* data)
{
SG_ERROR("This machine does not support apply_binary()\n");
Expand All @@ -191,4 +208,25 @@ CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data)
return NULL;
}

/** Default implementation of locked binary prediction: reports that the
 * method is not yet implemented for this machine (named via get_name()).
 *
 * @param indices index vector (of locked features); not used here
 * @return NULL
 */
CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> indices)
{
	SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet "
			"implemented for %s\n",
			get_name());

	/* fallback return value after the error has been reported */
	return NULL;
}

/** Default implementation of locked regression prediction: reports that the
 * method is not yet implemented for this machine (named via get_name()).
 *
 * @param indices index vector (of locked features); not used here
 * @return NULL
 */
CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> indices)
{
	SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet "
			"implemented for %s\n",
			get_name());

	/* fallback return value after the error has been reported */
	return NULL;
}

/** Default implementation of locked multiclass prediction: reports that the
 * method is not yet implemented for this machine (named via get_name()).
 *
 * @param indices index vector (of locked features); not used here
 * @return NULL
 */
CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices)
{
	SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet "
			"implemented for %s\n",
			get_name());

	/* fallback return value after the error has been reported */
	return NULL;
}


15 changes: 7 additions & 8 deletions src/shogun/machine/Machine.h
Expand Up @@ -217,16 +217,15 @@ class CMachine : public CSGObject
/** Applies a locked machine on a set of indices. Error if machine is
* not locked
*
* NOT IMPLEMENTED
*
* @param indices index vector (of locked features) that is predicted
*/
virtual CLabels* apply_locked(SGVector<index_t> indices)
{
SG_ERROR("apply_locked(SGVector<index_t>) is not yet implemented "
"for %s\n", get_name());
return NULL;
}
virtual CLabels* apply_locked(SGVector<index_t> indices);

/** Applies a locked machine on a set of indices, binary case.
* The CMachine implementation only reports that this is not yet
* implemented for the machine and returns NULL.
*
* @param indices index vector (of locked features) that is predicted
* @return resulting binary labels
*/
virtual CBinaryLabels* apply_locked_binary(SGVector<index_t> indices);

/** Applies a locked machine on a set of indices, regression case.
* The CMachine implementation only reports that this is not yet
* implemented for the machine and returns NULL.
*
* @param indices index vector (of locked features) that is predicted
* @return resulting regression labels
*/
virtual CRegressionLabels* apply_locked_regression(
SGVector<index_t> indices);

/** Applies a locked machine on a set of indices, multiclass case.
* The CMachine implementation only reports that this is not yet
* implemented for the machine and returns NULL.
*
* @param indices index vector (of locked features) that is predicted
* @return resulting multiclass labels
*/
virtual CMulticlassLabels* apply_locked_multiclass(
SGVector<index_t> indices);

/** Locks the machine on given labels and data. After this call, only
* train_locked and apply_locked may be called
Expand Down

0 comments on commit 95f980f

Please sign in to comment.