Skip to content

Commit

Permalink
Fixes for OnlineLinearMachine
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed May 22, 2012
1 parent 4ead6f0 commit 32ce04d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
18 changes: 15 additions & 3 deletions src/shogun/machine/OnlineLinearMachine.cpp
Expand Up @@ -31,7 +31,19 @@ COnlineLinearMachine::~COnlineLinearMachine()
SG_UNREF(features);
}

CRegressionLabels* COnlineLinearMachine::apply(CFeatures* data)
/** Apply the machine to data and wrap the raw outputs as binary labels.
 *
 * @param data features to apply to; forwarded unchanged to apply_get_outputs()
 * @return newly allocated CBinaryLabels built from the computed outputs
 */
CBinaryLabels* COnlineLinearMachine::apply_binary(CFeatures* data)
{
	SGVector<float64_t> scores = apply_get_outputs(data);
	CBinaryLabels* result = new CBinaryLabels(scores);
	return result;
}

/** Apply the machine to data and wrap the raw outputs as regression labels.
 *
 * @param data features to apply to; forwarded unchanged to apply_get_outputs()
 * @return newly allocated CRegressionLabels holding the computed outputs
 */
CRegressionLabels* COnlineLinearMachine::apply_regression(CFeatures* data)
{
	SGVector<float64_t> scores = apply_get_outputs(data);
	CRegressionLabels* result = new CRegressionLabels(scores);
	return result;
}

SGVector<float64_t> COnlineLinearMachine::apply_get_outputs(CFeatures* data)
{
if (data)
{
Expand Down Expand Up @@ -63,10 +75,10 @@ CRegressionLabels* COnlineLinearMachine::apply(CFeatures* data)
for (int32_t i=0; i<num_labels; i++)
labels_array.vector[i]=(*labels_dynarray)[i];

return new CRegressionLabels(labels_array);
return labels_array;
}

float32_t COnlineLinearMachine::apply(float32_t* vec, int32_t len)
/** Score a single dense vector with the linear model.
 *
 * @param vec pointer to the feature values of one example
 * @param len number of entries passed to CMath::dot; NOTE(review):
 *        presumably must not exceed the dimension of w — confirm
 * @return bias + <w, vec>
 */
float32_t COnlineLinearMachine::apply_one(float32_t* vec, int32_t len)
{
	return bias + CMath::dot(vec, w, len);
}
Expand Down
21 changes: 18 additions & 3 deletions src/shogun/machine/OnlineLinearMachine.h
Expand Up @@ -149,14 +149,23 @@ class COnlineLinearMachine : public CMachine
}

/** apply linear machine to data
 * for regression problems
 *
 * @param data (test)data to apply the machine to
 * @return newly allocated regression labels holding the machine's
 *         real-valued outputs
 */
virtual CRegressionLabels* apply(CFeatures* data=NULL);
virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);

/** apply linear machine to data
 * for binary classification problems
 *
 * @param data (test)data to be classified; defaults to NULL
 *        (NOTE(review): presumably NULL means "use the previously set
 *        features" — confirm in apply_get_outputs)
 * @return newly allocated binary labels built from the machine's
 *         real-valued outputs
 */
virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);

/// get output for example "vec_idx"
virtual float64_t apply(int32_t vec_idx)
virtual float64_t apply_one(int32_t vec_idx)
{
SG_NOTIMPLEMENTED;
return CMath::INFTY;
Expand All @@ -170,7 +179,7 @@ class COnlineLinearMachine : public CMachine
*
* @return classified label
*/
virtual float32_t apply(float32_t* vec, int32_t len);
virtual float32_t apply_one(float32_t* vec, int32_t len);

/**
* apply linear machine to vector currently being processed
Expand All @@ -192,6 +201,12 @@ class COnlineLinearMachine : public CMachine
*/
virtual const char* get_name() const
{
	return "OnlineLinearMachine";
}

protected:

SGVector<float64_t> apply_get_outputs(CFeatures* data);

virtual bool train_require_labels() const { return false; }

protected:
/** dimension of w */
int32_t w_dim;
Expand Down

0 comments on commit 32ce04d

Please sign in to comment.