Skip to content

Commit

Permalink
fixed store_model_features, which didn't work before, causing cross-validation on…
Browse files Browse the repository at this point in the history
… multiclass kernel machines to fail badly
  • Loading branch information
karlnapf committed Jul 17, 2012
1 parent 23dd339 commit 8d97d0c
Showing 1 changed file with 47 additions and 30 deletions.
77 changes: 47 additions & 30 deletions src/shogun/machine/KernelMulticlassMachine.cpp
Expand Up @@ -5,57 +5,74 @@
* (at your option) any later version.
*
* Written (W) 2012 Chiyuan Zhang
* Written (W) 2012 Heiko Strathmann
* Copyright (C) 2012 Chiyuan Zhang
*/

#include <shogun/lib/Map.h>
#include <shogun/lib/Set.h>
#include <shogun/machine/KernelMulticlassMachine.h>

using namespace shogun;

void CKernelMulticlassMachine::store_model_features()
{
CKernel *kernel = ((CKernelMachine *)m_machine)->get_kernel();
CKernel *kernel=((CKernelMachine *)m_machine)->get_kernel();
if (!kernel)
SG_ERROR("kernel is needed to store SV features.\n");

CFeatures* lhs = kernel->get_lhs();
CFeatures* rhs = kernel->get_rhs();
CFeatures* lhs=kernel->get_lhs();
CFeatures* rhs=kernel->get_rhs();
if (!lhs)
SG_ERROR("kernel lhs is needed to store SV features.\n");
{
SG_ERROR("%s::store_model_features(): kernel lhs is needed to store "
"SV features.\n");
}

CMap<int32_t, int32_t> all_sv;
for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
{
CKernelMachine *machine = (CKernelMachine *)get_machine(i);
for (int32_t j=0; j < machine->get_num_support_vectors(); ++j)
all_sv.add(machine->get_support_vector(j), 0);
/* this map will be abused as a map */
CSet<index_t> all_sv;
for (index_t i=0; i<m_machines->get_num_elements(); ++i)
{
CKernelMachine *machine=(CKernelMachine *)get_machine(i);
for (index_t j=0; j<machine->get_num_support_vectors(); ++j)
all_sv.add(machine->get_support_vector(j));

SG_UNREF(machine);
}
SG_UNREF(machine);
}

SGVector<int32_t> sv_idx(all_sv.get_num_elements());
for (index_t i=0; i < sv_idx.vlen; ++i)
sv_idx[i] = all_sv.get_element(i);

for (index_t i=0; i < sv_idx.vlen; ++i)
*all_sv.get_element_ptr(all_sv.index_of(sv_idx[i])) = i;
/* convert map to vector of SV */
SGVector<index_t> sv_idx(all_sv.get_num_elements());
for (index_t i=0; i<sv_idx.vlen; ++i)
sv_idx[i]=*all_sv.get_element_ptr(i);

CFeatures* sv_features=lhs->copy_subset(sv_idx);

kernel->init(sv_features, rhs);
/* now, features are replaced by concatenated SV features */
kernel->init(sv_features, rhs);

/* now the old SV indices have to be mapped to the new features */

/* update SV of all machines */
for (int32_t i=0; i<m_machines->get_num_elements(); ++i)
{
CKernelMachine *machine=(CKernelMachine *)get_machine(i);

/* for each machine, replace SV by index in sv_idx array */
for (int32_t j=0; j<machine->get_num_support_vectors(); ++j)
{
/* get index of SV in old features */
index_t current_sv_idx=machine->get_support_vector(j);

for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
{
CKernelMachine *machine = (CKernelMachine *)get_machine(i);
/* the position of this old index in the map is the position of
* the SV in the new features */
index_t new_sv_idx=all_sv.index_of(current_sv_idx);

for (int32_t j=0; j < machine->get_num_support_vectors(); ++j)
machine->set_support_vector(j, all_sv.get_element(all_sv.index_of(machine->get_support_vector(j))));
machine->set_support_vector(j, new_sv_idx);
}

SG_UNREF(machine);
}
SG_UNREF(machine);
}

SG_UNREF(lhs);
SG_UNREF(rhs);
SG_UNREF(kernel);
SG_UNREF(lhs);
SG_UNREF(rhs);
SG_UNREF(kernel);
}

0 comments on commit 8d97d0c

Please sign in to comment.