Skip to content

Commit

Permalink
Refactored SO machine to not use its own features reference, fixed generic risk and bmrm result type
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 21, 2012
1 parent 530155d commit 45840d7
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 31 deletions.
Expand Up @@ -56,7 +56,7 @@ def get_so_labels(out):
loss = HingeLoss()

lambda_ = 1e1
sosvm = DualLibQPBMSOSVM(model, loss, labels, features, lambda_)
sosvm = DualLibQPBMSOSVM(model, loss, labels, lambda_)

sosvm.set_cleanAfter(10) # number of iterations that cutting plane has to be inactive for to be removed
sosvm.set_cleanICP(True) # enables inactive cutting plane removal feature
Expand All @@ -70,8 +70,8 @@ def get_so_labels(out):
sosvm.train()

res = sosvm.get_result()
Fps = np.array(res.hist_Fp)
Fds = np.array(res.hist_Fd)
Fps = np.array(res.get_hist_Fp())
Fds = np.array(res.get_hist_Fd())
wdists = np.array(res.hist_wdist)

plt.figure()
Expand Down
27 changes: 27 additions & 0 deletions examples/undocumented/python_modular/structure_hmsvm_bmrm.py
@@ -0,0 +1,27 @@
#!/usr/bin/env python

import numpy
import scipy

from scipy import io
from shogun.Features import RealMatrixFeatures
from shogun.Loss import HingeLoss
from shogun.Structure import HMSVMLabels, HMSVMModel, Sequence, TwoStateModel, SMT_TWO_STATE
from shogun.Evaluation import StructuredAccuracy
from shogun.Structure import DualLibQPBMSOSVM

data_dict = scipy.io.loadmat('../data/hmsvm_data_large_integer.mat')
labels_array = data_dict['label'][0]
idxs = numpy.nonzero(labels_array == -1)
labels_array[idxs] = 0
labels = HMSVMLabels(labels_array, 250, 500, 2)
features = RealMatrixFeatures(data_dict['signal'].astype(float), 250, 500)
loss = HingeLoss()
model = HMSVMModel(features, labels, SMT_TWO_STATE, 4)
sosvm = DualLibQPBMSOSVM(model, loss, labels, 5000.0)
sosvm.train()
print sosvm.get_w()
predicted = sosvm.apply()
evaluator = StructuredAccuracy()
acc = evaluator.evaluate(predicted, labels)
print('Accuracy = %.4f' % acc)
1 change: 0 additions & 1 deletion src/interfaces/modular/modshogun_ignores.i
Expand Up @@ -28,7 +28,6 @@
%ignore shogun::CKernelMeanMatching::CKernelMeanMatching(CKernel* kernel, SGVector<index_t> training_indices, SGVector<index_t> test_indices);
#endif

%ignore shogun::bmrm_return_value_T;
%ignore shogun::bmrm_ll;
%ignore shogun::TMultipleCPinfo;
%ignore refcount_t;
Expand Down
25 changes: 9 additions & 16 deletions src/shogun/machine/LinearStructuredOutputMachine.cpp
Expand Up @@ -14,39 +14,32 @@
using namespace shogun;

// Default constructor: delegates to the base default constructor and then
// registers this class's parameters for the parameter framework.
// NOTE(review): this rendering interleaves both versions of the
// initializer list. The one that also initializes m_features(NULL) is the
// pre-commit line. The bare `: CStructuredOutputMachine()` is its post-commit
// replacement, since the machine no longer owns a features member.
CLinearStructuredOutputMachine::CLinearStructuredOutputMachine()
: CStructuredOutputMachine(), m_features(NULL)
: CStructuredOutputMachine()
{
register_parameters();
}

/* Constructs the machine from a structured model, a loss function and the
 * structured training labels, then registers this class's parameters.
 *
 * After this commit the features are no longer a constructor argument; they
 * are held by (and set on) the structured model instead.
 *
 * NOTE(review): this rendering interleaves the pre- and post-commit lines:
 *  - `CStructuredLabels* labs,` / `CFeatures* features)` is the old signature
 *    tail, and `CStructuredLabels* labs)` is the new one;
 *  - the initializer with m_features(NULL) is the old one;
 *  - the `set_features(features);` call was removed with the parameter.
 */
CLinearStructuredOutputMachine::CLinearStructuredOutputMachine(
CStructuredModel* model,
CLossFunction* loss,
CStructuredLabels* labs,
CFeatures* features)
: CStructuredOutputMachine(model, loss, labs), m_features(NULL)
CStructuredLabels* labs)
: CStructuredOutputMachine(model, loss, labs)
{
set_features(features);
register_parameters();
}

// Destructor. After this commit the machine keeps no features reference of
// its own, so there is nothing left for it to release here.
// NOTE(review): the SG_UNREF(m_features) line below appears to be the removed
// pre-commit body shown by the diff rendering.
CLinearStructuredOutputMachine::~CLinearStructuredOutputMachine()
{
SG_UNREF(m_features)
}

/* Sets the features used by the machine by forwarding them to the structured
 * model (m_model->set_features). After this commit the machine stores no copy
 * of the pointer itself.
 *
 * NOTE(review): the SG_REF(f) / SG_UNREF(m_features) / `m_features = f;`
 * lines are the pre-commit reference bookkeeping shown by the diff rendering.
 * Confirm that CStructuredModel::set_features takes its own reference.
 * NOTE(review): m_model is dereferenced unconditionally. This class's default
 * constructor does not assign m_model itself, so confirm a model is always set
 * before this is called.
 */
void CLinearStructuredOutputMachine::set_features(CFeatures* f)
{
SG_REF(f);
SG_UNREF(m_features);
m_features = f;
m_model->set_features(f);
}

/* Returns the features held by the structured model (post-commit body:
 * `return m_model->get_features();`).
 *
 * NOTE(review): the SG_REF(m_features) / `return m_features;` lines are the
 * removed pre-commit body shown by the diff rendering.
 * apply_structured() releases the pointer obtained from this method with
 * SG_UNREF, which presumes CStructuredModel::get_features() hands out a new
 * reference. TODO confirm.
 */
CFeatures* CLinearStructuredOutputMachine::get_features() const
{
SG_REF(m_features);
return m_features;
return m_model->get_features();
}

SGVector< float64_t > CLinearStructuredOutputMachine::get_w() const
Expand All @@ -62,28 +55,28 @@ CStructuredLabels* CLinearStructuredOutputMachine::apply_structured(CFeatures* d
}

CStructuredLabels* out;
if ( !m_features )
CFeatures* model_features = this->get_features();
if (!model_features)
{
out = new CStructuredLabels();
}
else
{
out = new CStructuredLabels( m_features->get_num_vectors() );
for ( int32_t i = 0 ; i < m_features->get_num_vectors() ; ++i )
out = new CStructuredLabels(model_features->get_num_vectors());
for ( int32_t i = 0 ; i < model_features->get_num_vectors() ; ++i )
{
CResultSet* result = m_model->argmax(m_w, i, false);
out->add_label(result->argmax);

SG_UNREF(result);
}
}

SG_UNREF(model_features);
SG_REF(out);
return out;
}

// Registers the weight vector m_w with the parameter framework (SG_ADD).
// NOTE(review): the SG_ADD line for m_features is the pre-commit registration
// shown by the diff rendering. It was removed together with the member, since
// the features now live in the structured model.
void CLinearStructuredOutputMachine::register_parameters()
{
SG_ADD((CSGObject**)&m_features, "m_features", "Feature object", MS_NOT_AVAILABLE);
SG_ADD(&m_w, "m_w", "Weight vector", MS_NOT_AVAILABLE);
}
3 changes: 1 addition & 2 deletions src/shogun/machine/LinearStructuredOutputMachine.h
Expand Up @@ -29,9 +29,8 @@ class CLinearStructuredOutputMachine : public CStructuredOutputMachine
* @param model structured model with application specific functions
* @param loss loss function
* @param labs structured labels
* @param features features
*/
CLinearStructuredOutputMachine(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs, CFeatures* features);
CLinearStructuredOutputMachine(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs);

/** destructor */
virtual ~CLinearStructuredOutputMachine();
Expand Down
6 changes: 4 additions & 2 deletions src/shogun/structure/DualLibQPBMSOSVM.cpp
Expand Up @@ -23,10 +23,9 @@ CDualLibQPBMSOSVM::CDualLibQPBMSOSVM(
CStructuredModel* model,
CLossFunction* loss,
CStructuredLabels* labs,
CDotFeatures* features,
float64_t _lambda,
SGVector< float64_t > W)
:CLinearStructuredOutputMachine(model, loss, labs, features)
: CLinearStructuredOutputMachine(model, loss, labs)
{
set_TolRel(0.001);
set_TolAbs(0.0);
Expand Down Expand Up @@ -84,6 +83,9 @@ void CDualLibQPBMSOSVM::init()

bool CDualLibQPBMSOSVM::train_machine(CFeatures* data)
{
if (data)
set_features(data);

// call the solver
switch(m_solver)
{
Expand Down
1 change: 0 additions & 1 deletion src/shogun/structure/DualLibQPBMSOSVM.h
Expand Up @@ -62,7 +62,6 @@ class CDualLibQPBMSOSVM : public CLinearStructuredOutputMachine
CStructuredModel* model,
CLossFunction* loss,
CStructuredLabels* labs,
CDotFeatures* features,
float64_t _lambda,
SGVector< float64_t > W=0);

Expand Down
1 change: 1 addition & 0 deletions src/shogun/structure/StructuredModel.cpp
Expand Up @@ -167,6 +167,7 @@ float64_t CStructuredModel::risk(float64_t* subgrad, float64_t* W, TMultipleCPin
SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
R += result->score;
R += this->delta_loss(i, result->argmax);
SG_UNREF(result);
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/structure/libbmrm.cpp
Expand Up @@ -110,7 +110,7 @@ bmrm_return_value_T svm_bmrm_solver(
uint32_t Tmax,
bool verbose)
{
bmrm_return_value_T bmrm={0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
bmrm_return_value_T bmrm;
libqp_state_T qp_exitflag={0, 0, 0, 0};
float64_t *b, *beta, *diag_H, *prevW;
float64_t R, *subgrad, *A, QPSolverTolRel, C=1.0, wdist=0.0;
Expand Down
34 changes: 31 additions & 3 deletions src/shogun/structure/libbmrm.h
Expand Up @@ -25,14 +25,33 @@
#define LIBBMRM_MEMMOVE(x, y, z) memmove(x, y, z)
#define LIBBMRM_INDEX(ROW, COL, NUM_ROWS) ((COL)*(NUM_ROWS)+(ROW))
#define LIBBMRM_ABS(A) ((A) < 0 ? -(A) : (A))
#define IGNORE_IN_CLASSLIST

namespace shogun
{
/** BMRM result structure */
IGNORE_IN_CLASSLIST struct bmrm_return_value_T
//struct bmrm_return_value_T
struct bmrm_return_value_T
{
/** constructor */
bmrm_return_value_T()
{
nIter = 0;
nCP = 0;
nzA = 0;
Fp = 0;
Fd = 0;
qp_exitflag = 0;
exitflag = 0;
};

/** destructor */
~bmrm_return_value_T() { };

/** dummy load serializable */
bool load_serializable(CSerializableFile* file, const char* prefix="") { return false; }

/** dummy save serializable */
bool save_serializable(CSerializableFile* file, const char* prefix="") { return false; }

/** number of iterations */
uint32_t nIter;

Expand Down Expand Up @@ -66,6 +85,15 @@ IGNORE_IN_CLASSLIST struct bmrm_return_value_T

/** Track of w_dist values in individual iterations */
SGVector< float64_t > hist_wdist;

/** get hist Fp */
SGVector<float64_t> get_hist_Fp() const { return hist_Fp; }

/** get hist Fd */
SGVector<float64_t> get_hist_Fd() const { return hist_Fd; }

/** get hist wdist */
SGVector<float64_t> get_hist_wdist() const { return hist_wdist; }
};

/** Linked list for cutting planes buffer management */
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/structure/libp3bm.cpp
Expand Up @@ -49,7 +49,7 @@ bmrm_return_value_T svm_p3bm_solver(
uint32_t cp_models,
bool verbose)
{
bmrm_return_value_T p3bmrm={0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
bmrm_return_value_T p3bmrm;
libqp_state_T qp_exitflag={0, 0, 0, 0}, qp_exitflag_good={0, 0, 0, 0};
float64_t *b, *b2, *beta, *beta_good, *beta_start, *diag_H, *diag_H2;
float64_t R, *Rt, **subgrad_t, *A, QPSolverTolRel, *C=NULL;
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/structure/libppbm.cpp
Expand Up @@ -48,7 +48,7 @@ bmrm_return_value_T svm_ppbm_solver(
uint32_t Tmax,
bool verbose)
{
bmrm_return_value_T ppbmrm={0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
bmrm_return_value_T ppbmrm;
libqp_state_T qp_exitflag={0, 0, 0, 0}, qp_exitflag_good={0, 0, 0, 0};
float64_t *b, *b2, *beta, *beta_good, *beta_start, *diag_H, *diag_H2;
float64_t R, *subgrad, *A, QPSolverTolRel, C=1.0;
Expand Down

0 comments on commit 45840d7

Please sign in to comment.