Skip to content

Commit

Permalink
Browse files (browse the repository at this point in the history)
Refactored apply of latent machines
  • Loading branch information
lisitsyn committed Aug 21, 2012
1 parent 416d360 commit d6b3a7d
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 14 deletions.
14 changes: 14 additions & 0 deletions src/interfaces/modular/Machine.i
Expand Up @@ -42,6 +42,16 @@
}
%enddef

/* APPLY_LATENT(CLASS): %extend CLASS in the generated wrappers with an
 * apply(CFeatures* data=NULL) method that forwards to
 * $self->apply_latent(data) and returns its CLatentLabels* result.
 * Counterpart of the APPLY_MULTICLASS / APPLY_STRUCTURED macros used in
 * the shogun namespace block of this interface. Presumably this gives
 * target-language callers typed latent labels in place of the generic
 * CMachine::apply, which this interface renames to apply_generic.
 * TODO(review): confirm against the generated wrappers. */
%define APPLY_LATENT(CLASS)
%extend CLASS
{
CLatentLabels* apply(CFeatures* data=NULL)
{
return $self->apply_latent(data);
}
}
%enddef

namespace shogun {
APPLY_MULTICLASS(CMulticlassMachine);
APPLY_MULTICLASS(CKernelMulticlassMachine);
Expand Down Expand Up @@ -70,6 +80,9 @@ APPLY_STRUCTURED(CKernelStructuredOutputMachine);
#ifdef USE_MOSEK
APPLY_STRUCTURED(CPrimalMosekSOSVM);
#endif
APPLY_STRUCTURED(CDualLibQPBMSOSVM);

APPLY_LATENT(CLatentSVM);
}

%rename(apply_generic) CMachine::apply(CFeatures* data=NULL);
Expand Down Expand Up @@ -103,4 +116,5 @@ APPLY_STRUCTURED(CPrimalMosekSOSVM);
#undef APPLY_BINARY
#undef APPLY_REGRESSION
#undef APPLY_STRUCTURED
#undef APPLY_LATENT
#endif
3 changes: 1 addition & 2 deletions src/shogun/latent/LatentSOSVM.cpp
Expand Up @@ -31,9 +31,8 @@ CLatentSOSVM::~CLatentSOSVM()
SG_UNREF(m_so_solver);
}

// NOTE(review): this is a rendered diff without +/- markers. Of the two
// signature lines below, the first (apply()) is presumably the removed
// pre-commit version and the second (apply_latent()) its replacement,
// matching the same rename in LatentSOSVM.h.
// The body unconditionally returns NULL, i.e. it is still a stub.
CLatentLabels* CLatentSOSVM::apply()
CLatentLabels* CLatentSOSVM::apply_latent()
{

return NULL;
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/latent/LatentSOSVM.h
Expand Up @@ -40,7 +40,7 @@ namespace shogun
*
* @return classified labels
*/
virtual CLatentLabels* apply();
virtual CLatentLabels* apply_latent();

/** set SO solver that is going to be used
*
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/latent/LatentSVM.cpp
Expand Up @@ -29,7 +29,7 @@ CLatentSVM::~CLatentSVM()
{
}

CLatentLabels* CLatentSVM::apply()
CLatentLabels* CLatentSVM::apply_latent()
{
if (!m_model)
SG_ERROR("LatentModel is not set!\n");
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/latent/LatentSVM.h
Expand Up @@ -48,9 +48,9 @@ namespace shogun
*
* @return resulting labels
*/
virtual CLatentLabels* apply();
virtual CLatentLabels* apply_latent();

using CLinearLatentMachine::apply;
using CLinearLatentMachine::apply_latent;

/** Returns the name of the SGSerializable instance.
*
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/machine/LinearLatentMachine.cpp
Expand Up @@ -38,15 +38,15 @@ CLinearLatentMachine::~CLinearLatentMachine()
SG_UNREF(m_model);
}

// NOTE(review): this is a rendered diff without +/- markers. Each adjacent
// near-identical pair is presumably (removed, added): the apply(CFeatures*)
// signature and `return apply();` are the pre-commit versions, replaced by
// apply_latent(CFeatures*) and `return apply_latent();`.
// Post-commit flow:
//  1. SG_ERROR if no LatentModel is set.
//  2. Convert `data` with CLatentFeatures::obtain_from_generic.
//  3. Hand the result to the model via set_features.
//  4. Delegate to the zero-argument apply_latent().
// NOTE(review): `data` is not NULL-checked here, yet CMachine declares
// apply_latent(CFeatures* data=NULL). Whether obtain_from_generic(NULL) and
// set_features(NULL) are safe is not visible from here. Confirm.
CLatentLabels* CLinearLatentMachine::apply(CFeatures* data)
CLatentLabels* CLinearLatentMachine::apply_latent(CFeatures* data)
{
if (m_model == NULL)
SG_ERROR("LatentModel is not set!\n");

CLatentFeatures* lf = CLatentFeatures::obtain_from_generic(data);
m_model->set_features(lf);

return apply();
return apply_latent();
}

void CLinearLatentMachine::set_model(CLatentModel* latent_model)
Expand Down
10 changes: 5 additions & 5 deletions src/shogun/machine/LinearLatentMachine.h
Expand Up @@ -39,19 +39,19 @@ namespace shogun
CLinearLatentMachine(CLatentModel* model, float64_t C);

virtual ~CLinearLatentMachine();

/** apply linear machine to all examples
/** apply linear machine to data set before
*
* @return resulting labels
* @return classified labels
*/
virtual CLatentLabels* apply() = 0;
virtual CLatentLabels* apply_latent() = 0;

/** apply linear machine to data
*
* @param data (test)data to be classified
* @return classified labels
*/
virtual CLatentLabels* apply(CFeatures* data);
virtual CLatentLabels* apply_latent(CFeatures* data);

/** Returns the name of the SGSerializable instance.
*
Expand Down
26 changes: 25 additions & 1 deletion src/shogun/machine/Machine.cpp
Expand Up @@ -201,6 +201,10 @@ CLabels* CMachine::apply_locked(SGVector<index_t> indices)
return apply_locked_regression(indices);
case PT_MULTICLASS:
return apply_locked_multiclass(indices);
case PT_STRUCTURED:
return apply_locked_structured(indices);
case PT_LATENT:
return apply_locked_latent(indices);
default:
SG_ERROR("Unknown problem type");
break;
Expand Down Expand Up @@ -228,7 +232,13 @@ CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data)

// NOTE(review): this is a rendered diff without +/- markers. The first
// SG_ERROR line (naming apply_multiclass()) is the removed, incorrect message.
// The second (naming apply_structured()) is its replacement in this commit.
// This is the Machine.cpp file's single deletion.
CStructuredLabels* CMachine::apply_structured(CFeatures* data)
{
SG_ERROR("This machine does not support apply_multiclass()\n");
SG_ERROR("This machine does not support apply_structured()\n");
return NULL;
}

/* Base-class fallback for latent problems.
 *
 * CMachine itself has no latent-model logic, so this only reports the
 * unsupported call through SG_ERROR. The method is declared virtual,
 * so subclasses may override it.
 *
 * @param data features to apply the machine to (ignored here)
 * @return NULL, if control returns from SG_ERROR
 */
CLatentLabels* CMachine::apply_latent(CFeatures* /*data*/)
{
	SG_ERROR("This machine does not support apply_latent()\n");

	// Only reached if SG_ERROR returns; satisfies the pointer return type.
	return NULL;
}

Expand All @@ -253,4 +263,18 @@ CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices)
return NULL;
}

/* Base-class fallback reached from CMachine::apply_locked for
 * PT_STRUCTURED problems.
 *
 * No generic locked implementation exists, so this reports an error that
 * includes get_name(). The method is virtual; subclasses may override it.
 *
 * @param indices subset indices to apply on (ignored here)
 * @return NULL, if control returns from SG_ERROR
 */
CStructuredLabels* CMachine::apply_locked_structured(
		SGVector<index_t> /*indices*/)
{
	SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet "
			"implemented for %s\n", get_name());

	// Only reached if SG_ERROR returns; satisfies the pointer return type.
	return NULL;
}

/* Base-class fallback reached from CMachine::apply_locked for
 * PT_LATENT problems.
 *
 * No generic locked implementation exists, so this reports an error that
 * includes get_name(). The method is virtual; subclasses may override it.
 *
 * @param indices subset indices to apply on (ignored here)
 * @return NULL, if control returns from SG_ERROR
 */
CLatentLabels* CMachine::apply_locked_latent(
		SGVector<index_t> /*indices*/)
{
	SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet "
			"implemented for %s\n", get_name());

	// Only reached if SG_ERROR returns; satisfies the pointer return type.
	return NULL;
}


9 changes: 9 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -19,6 +19,7 @@
#include <shogun/labels/RegressionLabels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/labels/StructuredLabels.h>
#include <shogun/labels/LatentLabels.h>
#include <shogun/features/Features.h>

namespace shogun
Expand Down Expand Up @@ -165,6 +166,8 @@ class CMachine : public CSGObject
virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
/** apply machine to data in means of SO classification problem */
virtual CStructuredLabels* apply_structured(CFeatures* data=NULL);
/** apply machine to data in means of latent problem */
virtual CLatentLabels* apply_latent(CFeatures* data=NULL);

/** set labels
*
Expand Down Expand Up @@ -253,6 +256,12 @@ class CMachine : public CSGObject
/** applies a locked machine on a set of indices for multiclass problems */
virtual CMulticlassLabels* apply_locked_multiclass(
SGVector<index_t> indices);
/** applies a locked machine on a set of indices for structured problems */
virtual CStructuredLabels* apply_locked_structured(
SGVector<index_t> indices);
/** applies a locked machine on a set of indices for latent problems */
virtual CLatentLabels* apply_locked_latent(
SGVector<index_t> indices);

/** Locks the machine on given labels and data. After this call, only
* train_locked and apply_locked may be called
Expand Down

0 comments on commit d6b3a7d

Please sign in to comment.