Skip to content

Commit

Permalink
Improved MC OCAS
Browse files (browse the repository at this point in the history)
  • Loading branch information
lisitsyn committed Mar 16, 2012
1 parent 5b66c20 commit 173493d
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/shogun/multiclass/MulticlassOCAS.cpp
Expand Up @@ -21,6 +21,7 @@ struct mocas_data
float64_t* oldW;
float64_t* full_A;
float64_t* data_y;
float64_t* output_values;
uint32_t nY;
uint32_t nData;
uint32_t nDim;
Expand Down Expand Up @@ -87,11 +88,13 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
user_data.oldW = SG_MALLOC(float64_t, num_features*num_classes);
user_data.new_a = SG_MALLOC(float64_t, num_features*num_classes);
user_data.full_A = SG_MALLOC(float64_t, num_features*num_classes*m_buf_size);
user_data.output_values = SG_MALLOC(float64_t, num_vectors);
user_data.data_y = data_y;
user_data.nY = num_classes;
user_data.nDim = num_features;
user_data.nData = num_vectors;

ocas_return_value_T value =
msvm_ocas_solver(C, data_y, nY, nData, TolRel, TolAbs,
QPBound, MaxTime, BufSize, Method,
&CMulticlassOCAS::msvm_full_compute_W,
Expand All @@ -102,6 +105,22 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
&CMulticlassOCAS::msvm_print,
&user_data);

SG_DEBUG("Number of iterations [nIter] = %d \n",value.nIter);
SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes);
SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha);
SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err);
SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P);
SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D);
SG_DEBUG("Output time [output_time] = %f \n",value.output_time);
SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time);
SG_DEBUG("Add time [add_time] = %f \n",value.add_time);
SG_DEBUG("W time [w_time] = %f \n",value.w_time);
SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time);
SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time);
SG_DEBUG("Print time [print_time] = %f \n",value.print_time);
SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag);
SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag);

clear_machines();
m_machines = SGVector<CMachine*>(num_classes);
for (int32_t i=0; i<num_classes; i++)
Expand All @@ -116,6 +135,7 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
SG_FREE(user_data.oldW);
SG_FREE(user_data.new_a);
SG_FREE(user_data.full_A);
SG_FREE(user_data.output_values);

return true;
}
Expand Down Expand Up @@ -206,7 +226,7 @@ int CMulticlassOCAS::msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_c

new_col_H[i] = tmp;
}

return 0;
}

Expand All @@ -216,15 +236,16 @@ int CMulticlassOCAS::msvm_full_compute_output(float64_t *output, void* user_data
uint32_t nY = ((mocas_data*)user_data)->nY;
uint32_t nDim = ((mocas_data*)user_data)->nDim;
uint32_t nData = ((mocas_data*)user_data)->nData;
float64_t* output_values = ((mocas_data*)user_data)->output_values;
CDotFeatures* features = ((mocas_data*)user_data)->features;

uint32_t i, y;

for(i=0; i < nData; i++)
{
for(y=0; y < nY; y++)
output[LIBOCAS_INDEX(y,i,nY)] =
features->dense_dot(i,&W[nDim*y],nDim);
for(y=0; y<nY; y++)
{
features->dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0);
for (i=0; i<nData; i++)
output[LIBOCAS_INDEX(y,i,nY)] = output_values[i];
}

return 0;
Expand All @@ -238,7 +259,4 @@ int CMulticlassOCAS::msvm_sort_data(float64_t* vals, float64_t* data, uint32_t s

/** Progress-report callback handed to msvm_ocas_solver() by train_machine().
 *
 * Intentionally does nothing: no per-iteration solver output is produced.
 *
 * @param value solver status snapshot supplied by the solver (ignored)
 */
void CMulticlassOCAS::msvm_print(ocas_return_value_T value)
{
	// Explicitly discard the argument; this callback is a deliberate no-op.
	(void)value;
}


0 comments on commit 173493d

Please sign in to comment.