Skip to content

Commit

Permalink
Merged tree and group guided least squares SLEP solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jul 2, 2012
1 parent 0e836c4 commit 725d6ad
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 331 deletions.
153 changes: 111 additions & 42 deletions src/shogun/lib/slep/slep_mt_lsr.cpp
Expand Up @@ -12,20 +12,112 @@
#include <shogun/mathematics/Math.h>
#include <shogun/lib/slep/q1/eppMatrix.h>
#include <shogun/mathematics/lapack.h>
#include <shogun/lib/slep/tree/altra.h>
#include <shogun/lib/slep/tree/general_altra.h>

namespace shogun
{

/** Computes the value of the multitask regularizer for the weights w.
 *
 * w holds an n_feats x n_tasks matrix in column-major order: the weight of
 * feature i for task t is w[i+t*n_feats].
 *
 * - MULTITASK_GROUP: sum over features i of the L_q norm of row i,
 *   sum_i (sum_t |w(i,t)|^q)^(1/q). When q>1e6 (the same threshold
 *   compute_ls_lambda uses to treat q as infinite) the L_inf norm
 *   max_t |w(i,t)| is used instead, since pow() with such exponents
 *   under/overflows.
 * - MULTITASK_TREE: sum over features i of the tree norm of row i, computed
 *   by general_treeNorm (options.general) or treeNorm.
 *
 * @param w weight matrix, n_feats x n_tasks, column-major
 * @param n_vecs number of vectors (not used by this function)
 * @param n_feats number of features (rows of w)
 * @param n_tasks number of tasks (columns of w)
 * @param options solver options (mode, q, general, G, ind_t, n_nodes)
 * @return value of the regularizer
 */
double compute_ls_regularizer(double* w, int n_vecs, int n_feats,
		int n_tasks, const slep_options& options)
{
	double regularizer = 0.0;
	switch (options.mode)
	{
		case MULTITASK_GROUP:
		{
			for (int i=0; i<n_feats; i++)
			{
				double row_norm = 0.0;
				if (options.q>1e6)
				{
					// q treated as infinite: the L_q norm tends to max_t |w(i,t)|
					for (int t=0; t<n_tasks; t++)
						row_norm = CMath::max(row_norm, fabs(w[i+t*n_feats]));
				}
				else
				{
					// the absolute value is required: pow() of a negative base
					// yields a negative term (odd q) or NaN (non-integer q)
					double sum = 0.0;
					for (int t=0; t<n_tasks; t++)
						sum += CMath::pow(fabs(w[i+t*n_feats]),options.q);
					row_norm = CMath::pow(sum,1.0/options.q);
				}
				regularizer += row_norm;
			}
		}
		break;
		case MULTITASK_TREE:
		{
			for (int i=0; i<n_feats; i++)
			{
				double tree_norm = 0.0;

				// Row i of the column-major w starts at w+i and its task entries
				// are n_feats apart, so n_feats is passed as the stride argument.
				// NOTE(review): assumes the 2nd parameter of treeNorm /
				// general_treeNorm is the leading dimension (stride) of x --
				// confirm against tree/altra.h and tree/general_altra.h.
				if (options.general)
					tree_norm = general_treeNorm(w+i, n_feats, n_tasks, options.G, options.ind_t, options.n_nodes);
				else
					tree_norm = treeNorm(w+i, n_feats, n_tasks, options.ind_t, options.n_nodes);

				regularizer += tree_norm;
			}
		}
		break;
		default:
			SG_SERROR("WHOA?\n");
	}
	return regularizer;
}

/** Computes the regularization parameter lambda = z * lambda_max.
 *
 * ATy holds an n_feats x n_tasks matrix in column-major order (entry for
 * feature i and task t at ATy[t*n_feats+i]).
 *
 * - MULTITASK_GROUP: lambda_max is the maximum over features i of the
 *   dual L_q_bar norm of row i of ATy, with q_bar = q/(q-1). For q==1 the
 *   dual is the L_inf norm (max_t |ATy(i,t)|), computed directly. For
 *   q>1e6 (q treated as infinite) the dual is the L_1 norm.
 * - MULTITASK_TREE: lambda_max comes from general_findLambdaMax_mt
 *   (options.general) or findLambdaMax_mt.
 *
 * Ownership: this function releases ATy with SG_FREE before returning, so
 * the caller must not use or free ATy afterwards.
 *
 * @param z ratio in [0,1]; an error is raised otherwise
 * @param features not used by this function
 * @param y not used by this function
 * @param ATy n_feats x n_tasks column-major buffer; freed by this function
 * @param n_vecs not used by this function
 * @param n_feats number of features
 * @param n_tasks number of tasks
 * @param options solver options (mode, q, general, G, ind_t, n_nodes)
 * @return z*lambda_max
 */
double compute_ls_lambda(double z, CDotFeatures* features, double* y, double* ATy, int n_vecs,
		int n_feats, int n_tasks, const slep_options& options)
{
	double lambda_max = 0.0;

	if (z<0 || z>1)
		SG_SERROR("z is not in range [0,1]");

	switch (options.mode)
	{
		case MULTITASK_GROUP:
		{
			// q==1 -> dual norm is L_inf; evaluating it through pow() with an
			// "almost infinite" exponent overflows to inf or underflows to 0,
			// so it is handled as an explicit max instead.
			const bool dual_is_inf_norm = (options.q==1);
			double q_bar = 1.0;
			if (!dual_is_inf_norm && !(options.q>1e6))
				q_bar = options.q/(options.q-1);

			for (int i=0; i<n_feats; i++)
			{
				double row_dual_norm = 0.0;
				if (dual_is_inf_norm)
				{
					for (int t=0; t<n_tasks; t++)
						row_dual_norm = CMath::max(row_dual_norm, fabs(ATy[t*n_feats+i]));
				}
				else
				{
					double sum = 0.0;
					for (int t=0; t<n_tasks; t++)
						sum += CMath::pow(fabs(ATy[t*n_feats+i]),q_bar);
					row_dual_norm = CMath::pow(sum,1.0/q_bar);
				}
				lambda_max = CMath::max(lambda_max, row_dual_norm);
			}
		}
		break;
		case MULTITASK_TREE:
		{
			if (options.general)
				lambda_max = general_findLambdaMax_mt(ATy, n_feats, n_tasks,
						options.G, options.ind_t,
						options.n_nodes);
			else
				lambda_max = findLambdaMax_mt(ATy, n_feats, n_tasks,
						options.ind_t, options.n_nodes);
		}
		break;
		default:
			SG_SERROR("WHOAA!\n");
	}

	// NOTE(review): ATy is released here, i.e. ownership passes to this
	// function. Whether the caller frees ATy on the path that skips this call
	// is not visible in this excerpt -- confirm there is no leak/double free.
	SG_FREE(ATy);
	return z*lambda_max;
}


SGMatrix<double> slep_mt_lsr(
CDotFeatures* features,
double* y,
double z,
const slep_options& options)
{
int i,j,t;
int i,t;
int n_feats = features->get_dim_feature_space();
int n_vecs = features->get_num_vectors();
double lambda, lambda_max, beta;
double lambda, beta;
double funcp = 0.0, func = 0.0;

int n_tasks = options.n_tasks;
Expand All @@ -44,42 +136,13 @@ SGMatrix<double> slep_mt_lsr(
}

if (options.regularization!=0)
{
if (z<0 || z>1)
SG_SERROR("z is not in range [0,1]");

double q_bar = 0.0;
if (options.q==1)
q_bar = CMath::ALMOST_INFTY;
else if (options.q>1e6)
q_bar = 1;
else
q_bar = options.q/(options.q-1);

lambda_max = 0.0;

for (i=0; i<n_feats; i++)
{
double sum = 0.0;
for (t=0; t<n_tasks; t++)
sum += CMath::pow(fabs(ATy[t*n_feats+i]),q_bar);
lambda_max =
CMath::max(lambda_max, CMath::pow(sum,1.0/q_bar));
}

lambda = z*lambda_max;
}
lambda = compute_ls_lambda(z, features, y, ATy, n_vecs,
n_feats, n_tasks, options);
else
lambda = z;

SGMatrix<double> w(n_feats,n_tasks);
w.zero();
if (options.initial_w)
{
for (j=0; j<n_tasks; j++)
for (i=0; i<n_feats; i++)
w(i,j) = options.initial_w[j*n_feats+i];
}

double* s = SG_CALLOC(double, n_feats*n_tasks);
double* g = SG_CALLOC(double, n_feats*n_tasks);
Expand Down Expand Up @@ -151,7 +214,20 @@ SGMatrix<double> slep_mt_lsr(
for (i=0; i<n_feats*n_tasks; i++)
v[i] = s[i] - g[i]*(1.0/L);

eppMatrix(w.matrix, v, n_feats, n_tasks, lambda/L, options.q);
switch (options.mode)
{
case MULTITASK_GROUP:
eppMatrix(w.matrix, v, n_feats, n_tasks, lambda/L, options.q);
break;
case MULTITASK_TREE:
if (options.general)
general_altra_mt(w.matrix, v, n_feats, n_tasks, options.G, options.ind_t, options.n_nodes, lambda/L);
else
altra_mt(w.matrix, v, n_feats, n_tasks, options.ind_t, options.n_nodes, lambda/L);
break;
default:
SG_SERROR("WHOA?!\n");
}

// v = x - s
for (i=0; i<n_feats*n_tasks; i++)
Expand Down Expand Up @@ -194,14 +270,7 @@ SGMatrix<double> slep_mt_lsr(
for (i=0; i<n_vecs; i++)
resid[i] = Aw[i] - y[i];

double regularizer = 0.0;
for (i=0; i<n_feats; i++)
{
double w_row_norm = 0.0;
for (t=0; t<n_tasks; t++)
w_row_norm += CMath::pow(w(i,t),options.q);
regularizer += CMath::pow(w_row_norm,1.0/options.q);
}
double regularizer = compute_ls_regularizer(w.matrix, n_vecs, n_feats, n_tasks, options);

funcp = func;
func = 0.5*SGVector<float64_t>::dot(resid,resid,n_vecs) + lambda*regularizer;
Expand Down

0 comments on commit 725d6ad

Please sign in to comment.