Skip to content

Commit

Permalink
Browse files (browse the repository at this point in the history)
Fixes for MALSAR's joint feature learning
  • Loading branch information
lisitsyn committed Jul 15, 2012
1 parent f146cab commit f205180
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/shogun/lib/slep/malsar_joint_feature_learning.cpp
Expand Up @@ -31,11 +31,11 @@ slep_result_t malsar_joint_feature_learning(
{
int task;
int n_feats = features->get_dim_feature_space();
SG_SPRINT("n feats = %d\n", n_feats);
SG_SDEBUG("n feats = %d\n", n_feats);
int n_vecs = features->get_num_vectors();
SG_SPRINT("n vecs = %d\n", n_vecs);
SG_SDEBUG("n vecs = %d\n", n_vecs);
int n_tasks = options.n_tasks;
SG_SPRINT("n tasks = %d\n", n_tasks);
SG_SDEBUG("n tasks = %d\n", n_tasks);

int iter = 0;

Expand Down Expand Up @@ -64,7 +64,8 @@ slep_result_t malsar_joint_feature_learning(
double obj=0.0, obj_old=0.0;

internal::set_is_malloc_allowed(false);
while (iter < options.max_iter)
bool done = false;
while (!done && iter <= options.max_iter)
{
double alpha = double(t_old - 1)/t;

Expand Down Expand Up @@ -99,7 +100,6 @@ slep_result_t malsar_joint_feature_learning(
Fs += Ws.squaredNorm();

double Fzp = 0.0;
double gradient_break = false;

// line search, Armijo-Goldstein scheme
while (true)
Expand Down Expand Up @@ -146,7 +146,7 @@ slep_result_t malsar_joint_feature_learning(
// break if delta is getting too small
if (r_sum <= 1e-20)
{
gradient_break = true;
done = true;
break;
}

Expand All @@ -157,15 +157,13 @@ slep_result_t malsar_joint_feature_learning(
gamma *= gamma_inc;
}

if (gradient_break)
break;

Wz_old = Wz;
Cz_old = Cz;
Wz = Wzp;
Cz = Czp;

// compute objective value
obj_old = obj;
obj = Fzp;
for (task=0; task<n_tasks; task++)
obj += rho1*(Wz.col(task).norm());
Expand All @@ -176,32 +174,32 @@ slep_result_t malsar_joint_feature_learning(
case 0:
if (iter>=2)
{
if ( (CMath::abs(obj)-CMath::abs(obj_old)) <= options.tolerance)
break;
if ( CMath::abs(obj-obj_old) <= options.tolerance )
done = true;
}
break;
case 1:
if (iter>=2)
{
if ( (CMath::abs(obj)-CMath::abs(obj_old)) <= options.tolerance*CMath::abs(obj_old))
break;
if ( CMath::abs(obj-obj_old) <= options.tolerance*CMath::abs(obj_old))
done = true;
}
break;
case 2:
if (CMath::abs(obj) <= options.tolerance)
break;
done = true;
break;
case 3:
if (iter>=options.max_iter)
break;
done = true;
break;
}

iter++;
t_old = t;
t = 0.5 * (1 + CMath::sqrt(1.0 + 4*t*t));
}
SG_SDEBUG("%d iteration passed\n",iter);
SG_SDEBUG("%d iteration passed, objective = %f\n",iter,obj);

SGMatrix<float64_t> tasks_w(n_feats, n_tasks);
for (int i=0; i<n_feats; i++)
Expand Down

0 comments on commit f205180

Please sign in to comment.