Skip to content

Commit

Permalink
introduce macros to easily apply renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Soeren Sonnenburg committed May 21, 2012
1 parent 54eef53 commit 7407407
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 63 deletions.
76 changes: 39 additions & 37 deletions src/interfaces/modular/Machine.i
Expand Up @@ -2,63 +2,65 @@
%warnfilter(302) apply_generic;*/
%rename(apply_generic) apply(CFeatures*);

namespace shogun {
%extend CMulticlassMachine
%define APPLY_MULTICLASS(CLASS)
%extend CLASS
{
CMulticlassLabels* apply(CFeatures* data=NULL)
{
return CMulticlassLabels::obtain_from_generic($self->apply_multiclass(data));
}
}
%enddef

%extend CKernelMulticlassMachine
%define APPLY_BINARY(CLASS)
%extend CLASS
{
CMulticlassLabels* apply(CFeatures* data=NULL)
CBinaryLabels* apply(CFeatures* data=NULL)
{
return CMulticlassLabels::obtain_from_generic($self->apply_multiclass(data));
}
}

%extend CLinearMulticlassMachine
{
CMulticlassLabels* apply(CFeatures* data=NULL)
{
return CMulticlassLabels::obtain_from_generic($self->apply_multiclass(data));
return CBinaryLabels::obtain_from_generic($self->apply_binary(data));
}
}
%enddef


/*%extend COnlineLinearMachine
{
CRealLabels* apply(CFeatures* data=NULL)
{
return CRealLabels::obtain_from_generic($self->apply_binary(data));
}
}*/

%extend CLinearMachine
%define APPLY_REGRESSION(CLASS)
%extend CLASS
{
CRegressionLabels* apply(CFeatures* data=NULL)
{
return CRegressionLabels::obtain_from_generic($self->apply_regression(data));
}
}
%enddef

%extend CKernelMachine
{
CRegressionLabels* apply(CFeatures* data=NULL)
{
return CRegressionLabels::obtain_from_generic($self->apply_regression(data));
}
}
namespace shogun {
APPLY_MULTICLASS(CMulticlassMachine);
APPLY_MULTICLASS(CKernelMulticlassMachine);
APPLY_MULTICLASS(CLinearMulticlassMachine);
APPLY_MULTICLASS(CDistanceMachine);

%extend CDistanceMachine
{
CMulticlassLabels* apply(CFeatures* data=NULL)
{
return CMulticlassLabels::obtain_from_generic($self->apply_multiclass(data));
}
}
APPLY_BINARY(CLinearMachine);
APPLY_BINARY(CKernelMachine);
APPLY_BINARY(CWDSVMOcas);
APPLY_BINARY(CPluginEstimate);

/* Regression machines: APPLY_REGRESSION %extends each class with an
 * apply(CFeatures*) that returns CRegressionLabels via apply_regression().
 * Each class must appear only once -- repeating a class here would %extend
 * it with a second, identical apply(CFeatures*). */
APPLY_REGRESSION(CKernelRidgeRegression);
APPLY_REGRESSION(CSVRLight);
APPLY_REGRESSION(CMKLRegression);
APPLY_REGRESSION(CLinearRidgeRegression);
APPLY_REGRESSION(CLeastSquaresRegression);
APPLY_REGRESSION(CLeastAngleRegression);
APPLY_REGRESSION(CGaussianProcessRegression);
}

#undef APPLY_MULTICLASS
#undef APPLY_BINARY
#undef APPLY_REGRESSION

/*%extend COnlineLinearMachine
{
CRegressionLabels* apply(CFeatures* data=NULL)
{
return CRegressionLabels::obtain_from_generic($self->apply_binary(data));
}
}*/
4 changes: 2 additions & 2 deletions src/shogun/classifier/PluginEstimate.cpp
Expand Up @@ -114,12 +114,12 @@ CLabels* CPluginEstimate::apply(CFeatures* data)
ASSERT(result->get_num_labels()==features->get_num_vectors());

for (int32_t vec=0; vec<features->get_num_vectors(); vec++)
result->set_label(vec, apply(vec));
result->set_label(vec, apply_one(vec));

return result;
}

float64_t CPluginEstimate::apply(int32_t vec_idx)
float64_t CPluginEstimate::apply_one(int32_t vec_idx)
{
ASSERT(features);

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/PluginEstimate.h
Expand Up @@ -66,7 +66,7 @@ class CPluginEstimate: public CMachine
virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }

/// classify the test feature vector indexed by vec_idx
float64_t apply(int32_t vec_idx);
float64_t apply_one(int32_t vec_idx);

/** obsolete posterior log odds
*
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/WDSVMOcas.cpp
Expand Up @@ -124,7 +124,7 @@ CLabels* CWDSVMOcas::apply(CFeatures* data)
SG_REF(output);

for (int32_t i=0; i<num; i++)
output->set_label(i, apply(i));
output->set_label(i, apply_one(i));

return output;
}
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/svm/WDSVMOcas.h
Expand Up @@ -160,7 +160,7 @@ class CWDSVMOcas : public CMachine
* @param num number of example to classify
* @return classified result
*/
inline virtual float64_t apply(int32_t num)
inline virtual float64_t apply_one(int32_t num)
{
ASSERT(features);
if (!wd_weights)
Expand Down
9 changes: 0 additions & 9 deletions src/shogun/regression/GaussianProcessRegression.cpp
Expand Up @@ -100,15 +100,6 @@ CLabels* CGaussianProcessRegression::apply(CFeatures* data)
return mean_prediction(features);
}

/** Per-example apply is not implemented for this machine.
 *
 * @param num index of the example (unused)
 * @return 0; the call first reports an error through SG_ERROR
 */
float64_t CGaussianProcessRegression::apply(int32_t num)
{
	/* Report the missing implementation, naming the concrete class. */
	SG_ERROR("apply(int32_t num) is not yet implemented "
			"for %s\n", get_name());

	/* NOTE(review): presumably SG_ERROR does not return normally;
	 * this value only satisfies the declared return type -- confirm. */
	return 0;
}


bool CGaussianProcessRegression::train_machine(CFeatures* data)
{
if (!data->has_property(FP_DOT))
Expand Down
7 changes: 0 additions & 7 deletions src/shogun/regression/GaussianProcessRegression.h
Expand Up @@ -116,13 +116,6 @@ class CGaussianProcessRegression : public CMachine
*/
virtual CLabels* apply(CFeatures* data=NULL);

/** apply regression to one example
*
* @param num which example to apply to
* @return classified value
*/
virtual float64_t apply(int32_t num);

/** get classifier type
*
* @return classifier type GaussianProcessRegression
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/regression/LeastAngleRegression.cpp
Expand Up @@ -93,7 +93,7 @@ bool CLeastAngleRegression::train_machine(CFeatures* data)
if (!m_labels)
SG_ERROR("No labels set\n");
if (m_labels->get_label_type() != LT_REAL)
SG_ERROR("Expected RealLabels\n");
SG_ERROR("Expected RegressionLabels\n");

if (!data)
data=features;
Expand Down
4 changes: 2 additions & 2 deletions src/shogun/ui/GUIPluginEstimate.cpp
Expand Up @@ -101,7 +101,7 @@ CLabels* CGUIPluginEstimate::apply()
return estimator->apply();
}

float64_t CGUIPluginEstimate::apply(int32_t idx)
float64_t CGUIPluginEstimate::apply_one(int32_t idx)
{
CFeatures* testfeatures=ui->ui_features->get_test_features();

Expand All @@ -119,5 +119,5 @@ float64_t CGUIPluginEstimate::apply(int32_t idx)

estimator->set_features((CStringFeatures<uint16_t>*) testfeatures);

return estimator->apply(idx);
return estimator->apply_one(idx);
}
2 changes: 1 addition & 1 deletion src/shogun/ui/GUIPluginEstimate.h
Expand Up @@ -60,7 +60,7 @@ class CGUIPluginEstimate : public CSGObject
/** apply
* @param idx
*/
float64_t apply(int32_t idx);
float64_t apply_one(int32_t idx);

/** @return object name */
inline virtual const char* get_name() const { return "GUIPluginEstimate"; }
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/ui/SGInterface.cpp
Expand Up @@ -5387,7 +5387,7 @@ bool CSGInterface::cmd_plugin_estimate_classify_example()
return false;

int32_t idx=get_int();
float64_t result=ui_pluginestimate->apply(idx);
float64_t result=ui_pluginestimate->apply_one(idx);

set_vector(&result, 1);
return true;
Expand Down

0 comments on commit 7407407

Please sign in to comment.