Skip to content

Commit

Permalink
Fixes for MALSAR JFL
Browse files: browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jul 25, 2012
1 parent 086ee0a commit 09b53b5
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions src/shogun/lib/malsar/malsar_joint_feature_learning.cpp
Expand Up @@ -15,6 +15,7 @@
#include <iostream>

using namespace Eigen;
using namespace std;

namespace shogun
{
Expand All @@ -39,21 +40,6 @@ malsar_result_t malsar_joint_feature_learning(
// initialize weight vector and bias for each task
MatrixXd Ws = MatrixXd::Zero(n_feats, n_tasks);
VectorXd Cs = VectorXd::Zero(n_tasks);
for (task=0; task<n_tasks; task++)
{
int n_pos = 0;
int n_neg = 0;
SGVector<index_t> task_idx = options.tasks_indices[task];
for (int i=0; i<task_idx.vlen; i++)
{
if (y[task_idx[i]] > 0)
n_pos++;
else
n_neg++;
}
Cs[task] = CMath::log(double(n_pos)/n_neg);
}

MatrixXd Wz=Ws, Wzp=Ws, Wz_old=Ws, delta_Wzp=Ws, gWs=Ws;
VectorXd Cz=Cs, Czp=Cs, Cz_old=Cs, delta_Czp=Cs, gCs=Cs;

Expand Down Expand Up @@ -99,6 +85,10 @@ malsar_result_t malsar_joint_feature_learning(
// add regularizer
Fs += Ws.squaredNorm();

//cout << "gWs" << endl << gWs << endl;
//cout << "gCs" << endl << gCs << endl;
//SG_SPRINT("Fs = %f\n",Fs);

double Fzp = 0.0;

int inner_iter = 0;
Expand All @@ -123,13 +113,13 @@ malsar_result_t malsar_joint_feature_learning(
int n_task_vecs = task_idx.vlen;
for (int i=0; i<n_task_vecs; i++)
{
double aa = -y[task_idx[i]]*(features->dense_dot(task_idx[i], Wzp.col(task).data(), n_feats)+Cs[task]);
double aa = -y[task_idx[i]]*(features->dense_dot(task_idx[i], Wzp.col(task).data(), n_feats)+Czp[task]);
double bb = CMath::max(aa,0.0);

Fzp += (CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb)/n_task_vecs;
}
}
Fzp += Wzp.squaredNorm();
Fzp += rho2*Wzp.squaredNorm();

// compute delta between line search point and search point
delta_Wzp = Wzp - Ws;
Expand All @@ -149,6 +139,7 @@ malsar_result_t malsar_joint_feature_learning(
// break if delta is getting too small
if (r_sum <= 1e-20)
{
SG_SDEBUG("Line search point is too close to search point\n");
done = true;
break;
}
Expand All @@ -170,8 +161,10 @@ malsar_result_t malsar_joint_feature_learning(
// compute objective value
obj_old = obj;
obj = Fzp;
for (task=0; task<n_tasks; task++)
obj += rho1*(Wz.col(task).norm());
for (int i=0; i<n_feats; i++)
obj += rho1*(Wz.row(i).lpNorm<2>());
//for (task=0; task<n_tasks; task++)
// obj += rho1*(Wz.col(task).norm());
SG_SDEBUG("Obj = %f\n",obj);

// check if process should be terminated
Expand All @@ -181,7 +174,10 @@ malsar_result_t malsar_joint_feature_learning(
if (iter>=2)
{
if ( CMath::abs(obj-obj_old) <= options.tolerance )
{
SG_SDEBUG("Objective changes less than tolerance\n");
done = true;
}
}
break;
case 1:
Expand Down

0 comments on commit 09b53b5

Please sign in to comment.