Navigation Menu

Skip to content

Commit

Permalink
Improved multiclass tree guided logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 21, 2012
1 parent 45840d7 commit f253152
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/shogun/lib/slep/slep_mc_tree_lr.cpp
Expand Up @@ -46,6 +46,18 @@ slep_result_t slep_mc_tree_lr(
MatrixXd w = MatrixXd::Zero(n_feats, n_classes);
// intercepts (biases)
VectorXd c = VectorXd::Zero(n_classes);

if (options.last_result)
{
SGMatrix<float64_t> last_w = options.last_result->w;
SGVector<float64_t> last_c = options.last_result->c;
for (i=0; i<n_classes; i++)
{
c[i] = last_c[i];
for (j=0; j<n_feats; j++)
w(j,i) = last_w(j,i);
}
}
// iterative process matrices and vectors
MatrixXd wp = w, wwp = MatrixXd::Zero(n_feats, n_classes);
VectorXd cp = c, ccp = VectorXd::Zero(n_classes);
Expand All @@ -55,6 +67,8 @@ slep_result_t slep_mc_tree_lr(
VectorXd search_c = VectorXd::Zero(n_classes);
// dot products
MatrixXd Aw = MatrixXd::Zero(n_vecs, n_classes);
for (j=0; j<n_classes; j++)
features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
MatrixXd As = MatrixXd::Zero(n_vecs, n_classes);
MatrixXd Awp = MatrixXd::Zero(n_vecs, n_classes);
// gradients
Expand Down Expand Up @@ -205,7 +219,7 @@ slep_result_t slep_mc_tree_lr(
cout << "Objective = " << objective << endl;

// check for termination of whole process
if ((CMath::abs(objective - objective_p) < options.tolerance) && (iter>2))
if ((CMath::abs(objective - objective_p) < options.tolerance*CMath::abs(objective)) && (iter>2))
{
SG_SINFO("Objective changes less than tolerance\n");
done = true;
Expand Down
16 changes: 16 additions & 0 deletions src/shogun/multiclass/MulticlassTreeGuidedLogisticRegression.cpp
Expand Up @@ -65,6 +65,22 @@ bool CMulticlassTreeGuidedLogisticRegression::train_machine(CFeatures* data)
int32_t n_feats = m_features->get_dim_feature_space();

slep_options options = slep_options::default_options();
if (m_machines->get_num_elements()!=0)
{
SGMatrix<float64_t> all_w_old(n_feats, n_classes);
SGVector<float64_t> all_c_old(n_classes);
for (int32_t i=0; i<n_classes; i++)
{
CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
SGVector<float64_t> w = machine->get_w();
for (int32_t j=0; j<n_feats; j++)
all_w_old(j,i) = w[j];
all_c_old[i] = machine->get_bias();
SG_UNREF(machine);
}
options.last_result = new slep_result_t(all_w_old,all_c_old);
m_machines->reset_array();
}
if (m_index_tree->is_general())
{
SGVector<float64_t> G = m_index_tree->get_SLEP_G();
Expand Down

0 comments on commit f253152

Please sign in to comment.