Skip to content

Commit

Permalink
Merge pull request #460 from uricamic/GSKRR
Browse files Browse the repository at this point in the history
Gauss-Seidel iterative method for Kernel Ridge Regression learning
  • Loading branch information
Soeren Sonnenburg committed Apr 18, 2012
2 parents 024fc0a + 4b4bde8 commit 6653329
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
84 changes: 82 additions & 2 deletions src/shogun/regression/KernelRidgeRegression.cpp
Expand Up @@ -24,14 +24,16 @@ CKernelRidgeRegression::CKernelRidgeRegression()
init();
}

CKernelRidgeRegression::CKernelRidgeRegression(float64_t tau, CKernel* k, CLabels* lab)
/* Construct a kernel ridge regression machine.
 *
 * tau is the regularisation constant, k the kernel, lab the training labels
 * and m selects the solver used by train_machine() (PINV or GS).
 */
CKernelRidgeRegression::CKernelRidgeRegression(float64_t tau, CKernel* k, CLabels* lab, ETrainingType m)
: CKernelMachine()
{
	// common initialisation shared with the default constructor
	init();

	// hyper-parameters of the model / solver
	m_tau=tau;
	m_train_func=m;
	// stopping tolerance of the Gauss-Seidel solver (see train_machine_gs)
	// NOTE(review): m_epsilon and m_train_func are assigned only in this
	// constructor; confirm init() initialises them as well, otherwise a
	// default-constructed object trains with indeterminate values.
	set_epsilon(0.0001);

	// training data
	set_labels(lab);
	set_kernel(k);
}

void CKernelRidgeRegression::init()
Expand All @@ -41,7 +43,7 @@ void CKernelRidgeRegression::init()
SG_ADD(&m_tau, "tau", "Regularization parameter", MS_AVAILABLE);
}

bool CKernelRidgeRegression::train_machine(CFeatures* data)
bool CKernelRidgeRegression::train_machine_pinv(CFeatures* data)
{
if (!m_labels)
SG_ERROR("No labels set\n");
Expand Down Expand Up @@ -86,6 +88,84 @@ bool CKernelRidgeRegression::train_machine(CFeatures* data)
return true;
}

/* Train by solving (K + tau*I) alpha = y with Gauss-Seidel sweeps.
 *
 * Each sweep updates alpha[i] in place, using the already-updated
 * alpha[0..i-1]. Iteration stops once the largest change of any alpha
 * within one sweep is <= m_epsilon. Gauss-Seidel converges when
 * K + tau*I is symmetric positive definite (e.g. a symmetric positive
 * semi-definite kernel with tau > 0); that property is assumed, not
 * checked, so a pathological kernel may not converge.
 *
 * @param data training features; if given, the kernel is initialised on
 *        (data, data). If NULL, the kernel must already be initialised
 *        on the training data.
 * @return true on success; failures are reported via SG_ERROR/ASSERT
 */
bool CKernelRidgeRegression::train_machine_gs(CFeatures* data)
{
	if (!m_labels)
		SG_ERROR("No labels set\n");

	// check the kernel before it is dereferenced by init() below
	ASSERT(kernel);

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
			SG_ERROR("Number of training vectors does not match number of labels\n");
		kernel->init(data, data);
	}
	ASSERT(kernel->has_features());

	int32_t n=kernel->get_num_vec_rhs();
	int32_t m=kernel->get_num_vec_lhs();
	ASSERT(m>0 && n>0);

	// the sweep indexes K(j,i) with both j and i in [0,n), so the kernel
	// must be square, i.e. initialised on the training data on both sides
	if (m!=n)
	{
		SG_ERROR("Kernel matrix is not square (lhs=%d rhs=%d)\n", m, n);
	}

	int32_t num_labels=m_labels->get_num_labels();
	if (num_labels!=n)
	{
		SG_ERROR("Number of labels does not match number of kernel"
				" columns (num_labels=%d cols=%d)\n", num_labels, n);
	}

	// diagonal of K + tau*I: the divisor of every update. It does not
	// change between sweeps, so compute it once. A non-positive (or NaN)
	// entry would silently yield inf/NaN alphas, so reject it up front.
	SGVector<float64_t> diag(n);
	for (int32_t i=0; i<n; i++)
	{
		diag[i]=kernel->kernel(i, i)+m_tau;
		if (!(diag[i]>0))
		{
			float64_t bad=diag[i];
			diag.destroy_vector();
			SG_ERROR("Non-positive diagonal entry K(%d,%d)+tau=%f, "
					"Gauss-Seidel needs a positive definite system\n",
					i, i, bad);
		}
	}

	// re-set alphas of kernel machine (initial guess alpha = 0)
	m_alpha.destroy_vector();
	m_alpha=m_labels->get_labels_copy();
	m_alpha.zero();

	// tell kernel machine that all alphas are needed as 'support vectors'
	m_svs.destroy_vector();
	m_svs=SGVector<index_t>(n);
	m_svs.range_fill();

	// right-hand side y of the linear system; owned by this function
	SGVector<float64_t> b=m_labels->get_labels_copy();

	// Gauss-Seidel iterative method
	float64_t err;
	do
	{
		err=0.0;
		for (int32_t i=0; i<n; i++)
		{
			float64_t sigma=b[i];
			for (int32_t j=0; j<n; j++)
			{
				if (j!=i)
					sigma-=kernel->kernel(j, i)*m_alpha[j];
			}

			float64_t alpha_old=m_alpha[i];
			m_alpha[i]=sigma/diag[i];

			float64_t d=fabs(alpha_old-m_alpha[i]);
			if (d>err)
				err=d;
		}
	}
	while (err>m_epsilon);

	// release the temporaries (previously b was leaked on every call)
	b.destroy_vector();
	diag.destroy_vector();

	return true;
}

/* Train the machine with the solver chosen at construction time.
 *
 * GS selects the Gauss-Seidel solver; PINV and any other value fall back
 * to the pinv-based solver.
 *
 * @param data training features, forwarded unchanged to the solver
 * @return the solver's result
 */
bool CKernelRidgeRegression::train_machine(CFeatures *data)
{
	if (m_train_func==GS)
		return train_machine_gs(data);

	// PINV, and the fallback for unknown values
	return train_machine_pinv(data);
}

bool CKernelRidgeRegression::load(FILE* srcfile)
{
SG_SET_LOCALE_C;
Expand Down
41 changes: 40 additions & 1 deletion src/shogun/regression/KernelRidgeRegression.h
Expand Up @@ -57,13 +57,19 @@ class CKernelRidgeRegression : public CKernelMachine
/** default constructor */
CKernelRidgeRegression();

/** solver used by train_machine() */
enum ETrainingType
{
	/** pinv-based solver (train_machine_pinv) */
	PINV=1,
	/** Gauss-Seidel iterative solver (train_machine_gs) */
	GS=2
};

/** constructor
*
* @param tau regularization constant tau
* @param k kernel
* @param lab labels
* @param m training method used by train_machine() (PINV or GS), PINV by default
*/
CKernelRidgeRegression(float64_t tau, CKernel* k, CLabels* lab);
CKernelRidgeRegression(float64_t tau, CKernel* k, CLabels* lab, ETrainingType m=PINV);
virtual ~CKernelRidgeRegression() {}

/** set regularization constant
Expand All @@ -72,6 +78,12 @@ class CKernelRidgeRegression : public CKernelMachine
*/
inline void set_tau(float64_t tau) { m_tau = tau; };

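/** set convergence tolerance of the Gauss-Seidel solver
*
* @param epsilon new epsilon; Gauss-Seidel training stops once the largest
* change of any alpha within one sweep is <= epsilon
*/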
inline void set_epsilon(float64_t epsilon) { m_epsilon = epsilon; }

/** load regression from file
*
* @param srcfile file to load from
Expand Down Expand Up @@ -109,13 +121,40 @@ class CKernelRidgeRegression : public CKernelMachine
*/
virtual bool train_machine(CFeatures* data=NULL);

/** train regression using Gauss-Seidel iterative method
*
* @param data training data (parameter can be avoided if distance or
* kernel-based regressors are used and distance/kernels are
* initialized with train data)
*
* @return whether training was successful
*/
bool train_machine_gs(CFeatures* data=NULL);

/** train regression using pinv
*
* @param data training data (parameter can be avoided if distance or
* kernel-based regressors are used and distance/kernels are
* initialized with train data)
*
* @return whether training was successful
*/
bool train_machine_pinv(CFeatures* data=NULL);

private:
void init();

private:
/** regularization parameter tau */
float64_t m_tau;

/** convergence tolerance of the Gauss-Seidel solver */
float64_t m_epsilon;

/** training method used by train_machine() (PINV or GS) */
ETrainingType m_train_func;
};
}

#endif // HAVE_LAPACK
#endif // _KERNELRIDGEREGRESSION_H__

0 comments on commit 6653329

Please sign in to comment.