GS KRR work on comments
uricamic committed Apr 18, 2012
1 parent 0f16159 commit 7a4b0cc
Showing 2 changed files with 20 additions and 41 deletions.
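The Gauss-Seidel (GS) path changed below solves the regularized kernel system (K + tau*I) * alpha = y one coefficient at a time, sweeping until the largest per-coefficient change drops below the configured epsilon; after this commit the ridge term is applied inside the update as kernel(i,i) + tau instead of being added onto a materialized kernel matrix. The snippet below is a minimal standalone sketch of that iteration, not Shogun code: the dense row-major matrix K, the function name gauss_seidel_krr and the tolerance handling are illustrative assumptions.

// Minimal sketch (illustrative, not Shogun code): Gauss-Seidel sweep for
// kernel ridge regression, solving (K + tau*I) * alpha = y coefficient-wise.
#include <algorithm>
#include <cmath>
#include <vector>

void gauss_seidel_krr(const std::vector<double>& K, int n, double tau,
                      const std::vector<double>& y, double eps,
                      std::vector<double>& alpha)
{
    alpha.assign(n, 0.0);
    double err = eps + 1.0;
    while (err > eps)
    {
        err = 0.0;
        for (int i = 0; i < n; i++)
        {
            // sigma = y_i - sum over j != i of K(i,j) * alpha_j
            double sigma = y[i];
            for (int j = 0; j < n; j++)
                if (j != i)
                    sigma -= K[i * n + j] * alpha[j];

            double alpha_old = alpha[i];
            // the regularizer tau sits on the diagonal, matching the new update
            alpha[i] = sigma / (K[i * n + i] + tau);
            err = std::max(err, std::fabs(alpha_old - alpha[i]));
        }
    }
}

For a positive semi-definite kernel and tau > 0 the system matrix is symmetric positive definite, which is what makes the Gauss-Seidel sweep converge.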
48 changes: 18 additions & 30 deletions src/shogun/regression/KernelRidgeRegression.cpp
@@ -33,7 +33,7 @@ CKernelRidgeRegression::CKernelRidgeRegression(float64_t tau, CKernel* k, CLabel
set_labels(lab);
set_kernel(k);
set_epsilon(0.0001);
- m_train_func=getTrainFunction(m);
+ m_train_func=m;
}

void CKernelRidgeRegression::init()
@@ -43,22 +43,6 @@ void CKernelRidgeRegression::init()
SG_ADD(&m_tau, "tau", "Regularization parameter", MS_AVAILABLE);
}

- CKernelRidgeRegression::train_method CKernelRidgeRegression::getTrainFunction(ETrainingType id)
- {
-     switch (id)
-     {
-         case PINV:
-             return &CKernelRidgeRegression::train_machine_pinv;
-             break;
-         case GS:
-             return &CKernelRidgeRegression::train_machine_gs;
-             break;
-         default:
-             return &CKernelRidgeRegression::train_machine_pinv;
-             break;
-     }
- }

bool CKernelRidgeRegression::train_machine_pinv(CFeatures* data)
{
if (!m_labels)
@@ -117,14 +101,9 @@ bool CKernelRidgeRegression::train_machine_gs(CFeatures* data)
}
ASSERT(kernel && kernel->has_features());

- // Get kernel matrix
- SGMatrix<float64_t> A = kernel->get_kernel_matrix<float64_t>();
- int32_t n = A.num_cols;
- int32_t m = A.num_rows;
- ASSERT(A.matrix && m>0 && n>0);
-
- for(int32_t i=0; i < n; i++)
-     A.matrix[i+i*n]+=m_tau;
+ int32_t n = kernel->get_num_vec_rhs();
+ int32_t m = kernel->get_num_vec_lhs();
+ ASSERT(m>0 && n>0);

// re-set alphas of kernel machine
m_alpha.destroy_vector();
@@ -157,9 +136,9 @@ bool CKernelRidgeRegression::train_machine_gs(CFeatures* data)
sigma=b[i];
for(int32_t j=0; j<n; j++)
if (i!=j)
- sigma-=A.matrix[j+i*n]*m_alpha[j];
+ sigma-=kernel->kernel(j, i)*m_alpha[j];
alpha_old=m_alpha[i];
- m_alpha[i]=sigma/A.matrix[i+i*n];
+ m_alpha[i]=sigma/(kernel->kernel(i, i)+m_tau);
d=fabs(alpha_old-m_alpha[i]);
if(d>err)
err=d;
@@ -168,14 +147,23 @@ bool CKernelRidgeRegression::train_machine_gs(CFeatures* data)
flag=false;
}

- SG_FREE(A.matrix);

return true;
}

bool CKernelRidgeRegression::train_machine(CFeatures *data)
{
- return (this->*m_train_func)(data);
+ switch (m_train_func)
+ {
+     case PINV:
+         return train_machine_pinv(data);
+         break;
+     case GS:
+         return train_machine_gs(data);
+         break;
+     default:
+         return train_machine_pinv(data);
+         break;
+ }
}

bool CKernelRidgeRegression::load(FILE* srcfile)
13 changes: 2 additions & 11 deletions src/shogun/regression/KernelRidgeRegression.h
@@ -54,11 +54,9 @@ namespace shogun
class CKernelRidgeRegression : public CKernelMachine
{
public:
- /** default constructor */
+ /** default constructor */
CKernelRidgeRegression();

- typedef bool (CKernelRidgeRegression::*train_method)(CFeatures* data);

enum ETrainingType
{
PINV=1,
@@ -74,11 +72,6 @@ class CKernelRidgeRegression : public CKernelMachine
CKernelRidgeRegression(float64_t tau, CKernel* k, CLabels* lab, ETrainingType m=PINV);
virtual ~CKernelRidgeRegression() {}

- /** set training function
-  *
-  */
- inline void set_train_func(train_method m) { m_train_func = m; };
-
/** set regularization constant
*
* @param tau new tau
@@ -159,9 +152,7 @@ class CKernelRidgeRegression : public CKernelMachine
float64_t m_epsilon;

/** training function */
- train_method m_train_func;
-
- static train_method getTrainFunction(ETrainingType id);
+ ETrainingType m_train_func;
};
}

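With the train_method typedef and set_train_func() removed, the solver is chosen once through the constructor's ETrainingType argument and dispatched inside train_machine(). An illustrative call site follows; it assumes kernel and labels are existing CKernel* and CLabels* objects, and the tau and epsilon values are arbitrary.

// Illustrative usage only; variable names and constants are assumptions.
CKernelRidgeRegression* krr =
    new CKernelRidgeRegression(1e-3, kernel, labels, CKernelRidgeRegression::GS);
krr->set_epsilon(1e-4);  // convergence tolerance for the Gauss-Seidel sweep
krr->train();            // dispatches to train_machine_gs through the enum
SG_UNREF(krr);

Omitting the last constructor argument keeps the default PINV (pseudo-inverse) solver, per the ETrainingType m=PINV default in the declaration above.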
