Skip to content

Commit

Permalink
Added getter for support vector indices for MC liblinear
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Apr 1, 2012
1 parent b43c2ab commit 067287a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/shogun/multiclass/MulticlassLibLinear.cpp
Expand Up @@ -13,6 +13,7 @@
#include <shogun/multiclass/MulticlassLibLinear.h>
#include <shogun/classifier/svm/SVM_linear.h>
#include <shogun/mathematics/Math.h>
#include <shogun/lib/v_array.h>

using namespace shogun;

Expand Down Expand Up @@ -53,6 +54,33 @@ CMulticlassLibLinear::~CMulticlassLibLinear()
reset_train_state();
}

/** Collect the indices of support vectors found by the last training run.
 *
 * Vector i is reported as a support vector when at least one of its
 * per-class dual coefficients alpha[i*num_classes+y] has absolute value
 * above a small tolerance. The dual coefficients are only available if
 * the machine was trained with the save_train_state option enabled.
 *
 * NOTE(review): the loop bound comes from the *currently* set features;
 * this assumes they are the same features the stored train state
 * belongs to (otherwise alpha may be indexed out of range) - confirm.
 *
 * @return indices of support vectors, in increasing order
 */
SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const
{
	if (!m_train_state)
		SG_ERROR("Please enable save_train_state option and train machine.\n");

	// Features and labels are dereferenced below; report a proper error
	// instead of crashing if they were cleared after training.
	if (!m_features || !m_labels)
		SG_ERROR("Features and labels have to be set to compute support vectors.\n");

	// dual coefficients with magnitude at or below this are treated as zero
	const float64_t alpha_eps = 1e-6;

	int32_t num_vectors = m_features->get_num_vectors();
	int32_t num_classes = m_labels->get_num_classes();

	v_array<int32_t> nz_idxs;
	nz_idxs.reserve(num_vectors);

	for (int32_t i=0; i<num_vectors; i++)
	{
		// alpha is laid out vector-major: num_classes entries per vector;
		// one non-zero entry is enough to make vector i a support vector
		for (int32_t y=0; y<num_classes; y++)
		{
			if (CMath::abs(m_train_state->alpha[i*num_classes+y])>alpha_eps)
			{
				nz_idxs.push(i);
				break;
			}
		}
	}
	int32_t num_nz = nz_idxs.index();
	// shrink the buffer to the number of collected indices; the returned
	// vector wraps nz_idxs.begin directly (no copy).
	// NOTE(review): confirm v_array does not free this buffer on
	// destruction and that SGVector's ownership semantics release it.
	nz_idxs.reserve(num_nz);
	return SGVector<int32_t>(nz_idxs.begin,num_nz);
}

bool CMulticlassLibLinear::train_machine(CFeatures* data)
{
if (data)
Expand Down
5 changes: 5 additions & 0 deletions src/shogun/multiclass/MulticlassLibLinear.h
Expand Up @@ -135,6 +135,11 @@ class CMulticlassLibLinear : public CLinearMulticlassMachine
}
}

/** get indices of support vectors
 *
 * A vector is reported as a support vector if any of its per-class
 * dual coefficients (alpha) has absolute value greater than 1e-6.
 * Requires that the machine was trained with the save_train_state
 * option enabled; otherwise an error is raised.
 *
 * NOTE(review): indices are enumerated over the currently set features,
 * which are assumed to be the features used for training.
 *
 * @return support vector indices, in increasing order
 */
SGVector<int32_t> get_support_vectors() const;

protected:

/** train machine */
Expand Down

0 comments on commit 067287a

Please sign in to comment.