Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Conjugate index improvement
  • Loading branch information
lisitsyn committed Feb 1, 2012
1 parent aaf479e commit 0be4076
Showing 1 changed file with 9 additions and 39 deletions.
48 changes: 9 additions & 39 deletions src/shogun/classifier/ConjugateIndex.cpp
Expand Up @@ -84,51 +84,21 @@ bool CConjugateIndex::train(CFeatures* train_features)

m_feature_vector = SGVector<float64_t>(num_features);

//float64_t* evals = SG_MALLOC(float64_t, num_features);
//float64_t* evecs = SG_MALLOC(float64_t, num_features*num_features);
SGMatrix<float64_t> matrix(CMath::max(num_features,num_vectors),CMath::max(num_features,num_vectors));
SGMatrix<float64_t> class_feature_matrix(num_features,CMath::max(num_features,num_vectors));
SGMatrix<float64_t> helper_matrix(CMath::max(num_features,num_vectors),num_features);

SG_PROGRESS(0,0,m_num_classes-1);

for (int32_t label=0; label<m_num_classes; label++)
{
/*
int32_t count = 0;
for (int32_t i=0; i<num_vectors; i++)
{
if ((int32_t)labels->get_label(i) == label)
{
for (int32_t j=0; j<num_features; j++)
{
for (int32_t k=0; k<num_features; k++)
{
matrix[j*num_features+k] +=
feature_matrix[i*num_features+j]*
feature_matrix[i*num_features+k];
}
}
count++;
}
}
ASSERT(num_features>count);
int32_t info = 0;
wrap_dsyevr('V','U',num_features,matrix.matrix,num_features,1,num_features-count+1,evals,evecs,&info);
cblas_dgemm(CblasColMajor,CblasNoTrans,CblasTrans,
num_features,num_features,num_features-count-1,
1.0,evecs,num_features,
evecs,num_features,
0.0,m_classes[label].matrix,num_features);
ASSERT(!info);
*/
int32_t count = 0;
for (int32_t i=0; i<num_vectors; i++)
{
if ((int32_t)labels->get_label(i) == label)
if (labels->get_int_label(i) == label)
count++;
}

SGMatrix<float64_t> class_feature_matrix(num_features,count);
SGMatrix<float64_t> matrix(count,count);
SGMatrix<float64_t> helper_matrix(num_features,count);

count = 0;
for (int32_t i=0; i<num_vectors; i++)
{
Expand Down Expand Up @@ -162,10 +132,10 @@ bool CConjugateIndex::train(CFeatures* train_features)
0.0,m_classes[label].matrix,num_features);

SG_PROGRESS(label+1,0,m_num_classes);
helper_matrix.destroy_matrix();
class_feature_matrix.destroy_matrix();
matrix.destroy_matrix();
}
helper_matrix.destroy_matrix();
class_feature_matrix.destroy_matrix();
matrix.destroy_matrix();
SG_DONE();

return true;
Expand Down

0 comments on commit 0be4076

Please sign in to comment.