Skip to content

Commit

Permalink
DualLibQPBMSOSVM interface for libbmrm
Browse files Browse the repository at this point in the history
  • Loading branch information
uricamic committed Jun 7, 2012
1 parent d8a7270 commit d6c2696
Show file tree
Hide file tree
Showing 6 changed files with 490 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/shogun/so/DualLibQPBMSOSVM.cpp
@@ -0,0 +1,65 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Michal Uricar
* Copyright (C) 2012 Michal Uricar
*/

#include <shogun/so/DualLibQPBMSOSVM.h>
#include <shogun/so/libbmrm.h>

using namespace shogun;

/** Default constructor.
 *
 * Gives every member a defined value. Without this, the solver options,
 * lambda and the risk-function pointer are indeterminate, and later reads
 * of them (e.g. in train_machine()) are undefined behavior.
 */
CDualLibQPBMSOSVM::CDualLibQPBMSOSVM()
:CLinearStructuredOutputMachine()
{
	// Same solver defaults as the standard constructor.
	set_opitons(0.001, 0.0, 100);

	// Placeholder regularization constant; call set_lambda() before training.
	set_lambda(1.0);

	// No risk function yet; train_machine() checks for NULL.
	m_risk_function=NULL;

	register_parameters();
}

/** Standard constructor.
 *
 * @param model structured model, forwarded to the base class
 * @param loss loss function, forwarded to the base class
 * @param labs structured labels, forwarded to the base class
 * @param features features, forwarded to the base class
 * @param lambda regularization constant passed to the BMRM solver
 */
CDualLibQPBMSOSVM::CDualLibQPBMSOSVM(
		CStructuredModel* model,
		CLossFunction* loss,
		CStructuredLabels* labs,
		CFeatures* features,
		float64_t lambda)
:CLinearStructuredOutputMachine(model, loss, labs, features)
{
	// Default BMRM options: TolRel=0.001, TolAbs=0.0, BufSize=100.
	set_opitons(0.001, 0.0, 100);
	set_lambda(lambda);

	// No risk function is supplied through this constructor; initialize the
	// pointer so train_machine() can detect that instead of dereferencing
	// an indeterminate value.
	m_risk_function=NULL;

	register_parameters();
}

/** Destructor. Nothing is released here: this class never takes a
 * reference on m_risk_function, and m_w cleans up after itself. */
CDualLibQPBMSOSVM::~CDualLibQPBMSOSVM() { }

/** Store the BMRM solver options; train_machine() forwards them unchanged
 * to svm_bmrm_solver().
 *
 * @param TolRel relative tolerance
 * @param TolAbs absolute tolerance
 * @param BufSize buffer size
 */
void CDualLibQPBMSOSVM::set_opitons(float64_t TolRel, float64_t TolAbs, uint32_t BufSize)
{
	m_BufSize=BufSize;
	m_TolAbs=TolAbs;
	m_TolRel=TolRel;
}

/** Train the SO-SVM by running the BMRM solver.
 *
 * @param data training data, handed (as an opaque pointer) to the risk
 *        function and to svm_bmrm_solver()
 * @return true if the solver reports convergence
 */
bool CDualLibQPBMSOSVM::train_machine(CFeatures* data)
{
	// Without a risk function there is nothing to minimize; fail loudly
	// instead of dereferencing an invalid pointer.
	if (m_risk_function==NULL)
	{
		SG_ERROR("CDualLibQPBMSOSVM::train_machine(): risk function is not set\n");
		return false;
	}

	// get dimension of w
	uint32_t nDim=m_risk_function->get_w_dim(data);

	// The solver writes its solution through m_w.vector, so the buffer must
	// exist and hold nDim entries. Start from w = 0.
	m_w=SGVector< float64_t >((int32_t) nDim);
	for (uint32_t i=0; i<nDim; ++i)
		m_w.vector[i]=0.0;

	// call the BMRM solver
	bmrm_return_value_T result=svm_bmrm_solver(data, m_w.vector, m_TolRel, m_TolAbs,
			m_lambda, m_BufSize, nDim, m_risk_function);

	// NOTE(review): assumes libbmrm's convention that positive exit flags
	// signal convergence (1 = TolRel, 2 = TolAbs satisfied) and non-positive
	// flags signal failure. The original accepted only 1, which reports a
	// TolAbs convergence as failure. Confirm against libbmrm.h.
	return result.exitflag>0;
}

void CDualLibQPBMSOSVM::register_parameters()
{
SG_ADD(&m_w, "m_w", "Weight vector", MS_NOT_AVAILABLE);
}

73 changes: 73 additions & 0 deletions src/shogun/so/DualLibQPBMSOSVM.h
@@ -0,0 +1,73 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Michal Uricar
* Copyright (C) 2012 Michal Uricar
*/

#ifndef _DUALLIBQPBMSOSVM__H__
#define _DUALLIBQPBMSOSVM__H__

#include <shogun/machine/LinearStructuredOutputMachine.h>
#include <shogun/so/RiskFunction.h>

namespace shogun
{

class CDualLibQPBMSOSVM : public CLinearStructuredOutputMachine
{
public:
/** default constructor */
CDualLibQPBMSOSVM();

/** standard constructor
*
*/
CDualLibQPBMSOSVM(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs, CFeatures* features, float64_t lambda);

/** destructor */
~CDualLibQPBMSOSVM();

/** set lambda */
inline void set_lambda(float64_t lambda) { m_lambda=lambda; }

/** set solver options */
void set_opitons(float64_t TolRel, float64_t TolAbs, uint32_t BufSize);

protected:
/** train dual SO-SVM
*
*/
bool train_machine(CFeatures* data=NULL);

private:
/** register class parameters */
void register_parameters();

private:
/** weight vector */
SGVector< float64_t > m_w;

/** lambda */
float64_t m_lambda;

/** TolRel */
float64_t m_TolRel;

/** TolAbs */
float64_t m_TolAbs;

/** BufSize */
uint32_t m_BufSize;

/** Risk function */
CRiskFunction* m_risk_function;

}; /* class CDualLibQPBMSOSVM */

} /* namespace shogun */

#endif /* _DUALLIBQPBMSOSVM__H__ */
22 changes: 22 additions & 0 deletions src/shogun/so/RiskFunction.cpp
@@ -0,0 +1,22 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Michal Uricar
* Copyright (C) 2012 Michal Uricar
*/

#include <shogun/so/RiskFunction.h>

using namespace shogun;

/** Default constructor; CRiskFunction declares no data members of its own. */
CRiskFunction::CRiskFunction() : CSGObject() { }

/** Destructor; nothing to release. */
CRiskFunction::~CRiskFunction() { }
52 changes: 52 additions & 0 deletions src/shogun/so/RiskFunction.h
@@ -0,0 +1,52 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Michal Uricar
* Copyright (C) 2012 Michal Uricar
*/

#ifndef _RISK_FUNCTION__H__
#define _RISK_FUNCTION__H__

#include <shogun/base/SGObject.h>
//#include <shogun/features/Features.h>
//#include <shogun/labels/StructuredLabels.h>
//#include <shogun/lib/SGVector.h>

namespace shogun
{

/** @brief Class CRiskFunction TODO
*
*/
/** @brief Abstract risk term minimized by the BMRM solver.
 *
 * Subclasses evaluate the risk value and a sub-gradient at a given weight
 * vector. They also report the dimension of that vector for the data they
 * are given.
 */
class CRiskFunction : public CSGObject
{
	public:
		/** default constructor */
		CRiskFunction();

		/** destructor */
		virtual ~CRiskFunction();

		/** computes the value of the risk function and a sub-gradient at a given point
		 *
		 * @param data opaque pointer to the training data (its concrete type
		 *        is defined by the implementation)
		 * @param R output: risk value at W
		 * @param subgrad output: sub-gradient at W (presumably get_w_dim(data)
		 *        entries, TODO confirm)
		 * @param W point at which the risk is evaluated
		 */
		virtual void risk(void* data, float64_t* R, float64_t* subgrad, float64_t* W) = 0;

		/** get the dimension of vector w
		 *
		 * @param data opaque pointer to the training data
		 * @return dimension of the weight vector w
		 */
		virtual uint32_t get_w_dim(void* data) = 0;

		/** @return name of SGSerializable */
		virtual const char* get_name() const { return "RiskFunction"; }

}; /* CRiskFunction */

}

#endif /* _RISK_FUNCTION__H__ */

0 comments on commit d6c2696

Please sign in to comment.