Skip to content

Commit

Permalink
A bunch of fixes for multitask algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jul 30, 2012
1 parent 7893db8 commit 3bda438
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/shogun/lib/malsar/malsar_clustered.cpp
Expand Up @@ -300,7 +300,7 @@ malsar_result_t malsar_clustered(
for (int i=0; i<n_feats; i++)
{
for (task=0; task<n_tasks; task++)
tasks_w[i] = Wzp(i,task);
tasks_w(i,task) = Wzp(i,task);
}
//tasks_w.display_matrix();
SGVector<float64_t> tasks_c(n_tasks);
Expand Down
19 changes: 14 additions & 5 deletions src/shogun/lib/malsar/malsar_joint_feature_learning.cpp
Expand Up @@ -97,11 +97,20 @@ malsar_result_t malsar_joint_feature_learning(
while (inner_iter <= 1000)
{
// compute lasso projection of Ws - gWs/gamma
for (task=0; task<n_tasks; task++)
for (int i=0; i<n_feats; i++)
{
Wzp.col(task) = Ws.col(task) - gWs.col(task)/gamma;
double norm = Wzp.col(task).lpNorm<2>();
Wzp.col(task) *= CMath::max(0.0,norm-rho1/gamma)/norm;
Wzp.row(i).noalias() = Ws.row(i) - gWs.row(i)/gamma;
double norm = Wzp.row(i).lpNorm<2>();
if (norm == 0.0)
Wzp.row(i).setZero();
else
{
double threshold = norm - rho1/gamma;
if (threshold < 0.0)
Wzp.row(i).setZero();
else
Wzp.row(i) *= threshold/norm;
}
}
// walk in direction of antigradient
Czp = Cs - gCs/gamma;
Expand Down Expand Up @@ -210,7 +219,7 @@ malsar_result_t malsar_joint_feature_learning(
for (int i=0; i<n_feats; i++)
{
for (task=0; task<n_tasks; task++)
tasks_w[i] = Wzp(i,task);
tasks_w(i,task) = Wzp(i,task);
}
SGVector<float64_t> tasks_c(n_tasks);
for (int i=0; i<n_tasks; i++) tasks_c[i] = Czp[i];
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/lib/malsar/malsar_low_rank.cpp
Expand Up @@ -204,7 +204,7 @@ malsar_result_t malsar_low_rank(
for (int i=0; i<n_feats; i++)
{
for (task=0; task<n_tasks; task++)
tasks_w[i] = Wzp(i,task);
tasks_w(i,task) = Wzp(i,task);
}
SGVector<float64_t> tasks_c(n_tasks);
for (int i=0; i<n_tasks; i++) tasks_c[i] = Czp[i];
Expand Down
Expand Up @@ -207,14 +207,15 @@ CBinaryLabels* CMultitaskLogisticRegression::apply_locked_binary(SGVector<index_
{
int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
SGVector<float64_t> result(indices.vlen);
result.zero();
for (int32_t i=0; i<indices.vlen; i++)
{
for (int32_t j=0; j<n_tasks; j++)
{
if (m_tasks_indices[j].count(indices[i]))
{
set_current_task(j);
result[i] = apply_one(i);
result[i] = apply_one(indices[i]);
break;
}
}
Expand Down

0 comments on commit 3bda438

Please sign in to comment.