Skip to content

Commit

Permalink
Updated multitask logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jun 25, 2012
1 parent 19a6e35 commit a937587
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/shogun/lib/slep/slep_mt_lr.cpp
Expand Up @@ -16,7 +16,7 @@
namespace shogun
{

SGMatrix<double> slep_mt_lr(
slep_result_t slep_mt_lr(
CDotFeatures* features,
double* y,
double z,
Expand Down Expand Up @@ -104,7 +104,7 @@ SGMatrix<double> slep_mt_lr(
for (i=0; i<n_feats; i++)
w(i,j) = options.initial_w[j*n_feats+i];
}
double* c = SG_CALLOC(double, n_tasks);
SGVector<double> c(n_tasks);
for (t=0; t<n_tasks; t++)
c[t] = CMath::log(m1[t]/m2[t]);

Expand Down Expand Up @@ -266,10 +266,10 @@ SGMatrix<double> slep_mt_lr(

funcp = func;
func = fun_x + lambda*regularizer;
SG_SPRINT("Obj = %f + %f * %f = %f \n",fun_x, lambda, regularizer, func);
//SG_SPRINT("Obj = %f + %f * %f = %f \n",fun_x, lambda, regularizer, func);

//if (gradient_break)
// break;
if (gradient_break)
break;

double norm_wp, norm_wwp;
double step;
Expand Down Expand Up @@ -334,6 +334,6 @@ SGMatrix<double> slep_mt_lr(
SG_FREE(m2);
SG_FREE(ATb);

return w;
return slep_result_t(w,c);
};
};
2 changes: 1 addition & 1 deletion src/shogun/lib/slep/slep_mt_lr.h
Expand Up @@ -17,7 +17,7 @@
namespace shogun
{

SGMatrix<double> slep_mt_lr(
slep_result_t slep_mt_lr(
CDotFeatures* features,
double* y,
double z,
Expand Down
14 changes: 14 additions & 0 deletions src/shogun/lib/slep/slep_options.h
Expand Up @@ -14,6 +14,8 @@
#define IGNORE_IN_CLASSLIST

#include <stdlib.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>

namespace shogun
{
Expand Down Expand Up @@ -52,6 +54,18 @@ IGNORE_IN_CLASSLIST struct slep_options
return opts;
}
};

/** Result bundle returned by SLEP solvers such as slep_mt_lr: the learned
 * weights together with the learned intercepts.
 */
IGNORE_IN_CLASSLIST struct slep_result_t
{
	/** weight matrix; slep_mt_lr fills it as w(feature, task) */
	SGMatrix<double> w;

	/** intercept vector; slep_mt_lr stores one entry per task */
	SGVector<double> c;

	/** Bundles a weight matrix and an intercept vector.
	 *
	 * Members are set up in the initializer list so they are constructed
	 * directly from the arguments instead of being default-constructed
	 * first and then overwritten by assignment.
	 *
	 * @param w_ weight matrix to store
	 * @param c_ intercept vector to store
	 */
	slep_result_t(SGMatrix<double> w_, SGVector<double> c_)
		: w(w_), c(c_)
	{
	}
};
#endif
}
#endif /* ----- #ifndef SLEP_OPTIONS_H_ ----- */
Expand Down
Expand Up @@ -58,6 +58,8 @@ void CMultitaskLogisticRegression::set_current_task(int32_t task)
w = SGVector<float64_t>(n_feats);
for (int32_t i=0; i<n_feats; i++)
w[i] = m_tasks_w(i,task);

bias = m_tasks_c[task];
}

CTaskRelation* CMultitaskLogisticRegression::get_task_relation() const
Expand Down Expand Up @@ -100,7 +102,9 @@ bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
options.ind = ind.vector;
options.n_nodes = ind.vlen-1;

m_tasks_w = slep_mt_lr(features, y.vector, m_z, options);
slep_result_t result = slep_mt_lr(features, y.vector, m_z, options);
m_tasks_w = result.w;
m_tasks_c = result.c;
}
break;
case TREE:
Expand Down
3 changes: 3 additions & 0 deletions src/shogun/transfer/multitask/MultitaskLogisticRegression.h
Expand Up @@ -86,6 +86,9 @@ class CMultitaskLogisticRegression : public CSLEPMachine

/** tasks w's */
SGMatrix<float64_t> m_tasks_w;

/** tasks intercepts */
SGVector<float64_t> m_tasks_c;

};
}
Expand Down

0 comments on commit a937587

Please sign in to comment.