Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
[WIP] restructure latent solvers
- Loading branch information
Showing
10 changed files
with
397 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Viktor Gal | ||
* Copyright (C) 2012 Viktor Gal | ||
*/ | ||
|
||
#include <shogun/latent/LatentRiskFunction.h> | ||
|
||
using namespace shogun; | ||
|
||
/** Default constructor; it only forwards to the CRiskFunction base constructor. */
CLatentRiskFunction::CLatentRiskFunction() : CRiskFunction()
{
}
|
||
/** Destructor; the body is intentionally empty. */
CLatentRiskFunction::~CLatentRiskFunction()
{
}
|
||
void CLatentRiskFunction::risk(void* data, float64_t* R, float64_t* subgrad, float64_t* W) | ||
{ | ||
ASSERT(data != NULL); | ||
ASSERT(R != NULL); | ||
ASSERT(subgrad != NULL); | ||
ASSERT(W != NULL); | ||
|
||
*R = 0; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Viktor Gal | ||
* Copyright (C) 2012 Viktor Gal | ||
*/ | ||
|
||
#ifndef __LATENTRISK_FUNCTION_H__ | ||
#define __LATENTRISK_FUNCTION_H__ | ||
|
||
#include <shogun/structure/RiskFunction.h> | ||
|
||
namespace shogun | ||
{ | ||
/** @brief: Calculates the risk function for Latent Structural SVM | ||
* | ||
* \sum_{i=1}^n \max_{(\hat{y},\hat{h}) \in Y \times H} \left[ \mathbf{w} \cdot \Psi(x_i, \hat{y}, \hat{h}) + \Delta(y_i, \hat{y}, \hat{h}) \right] | ||
* - \sum_{i=1}^n \mathbf{w} \cdot \Psi(x_i, y_i, h^*_i) | ||
* | ||
* For more details see [1] | ||
* [1] C.-N. J. Yu and T. Joachims, | ||
* “Learning structural SVMs with latent variables,” | ||
* presented at the Proceedings of the 26th Annual International Conference on Machine Learning, | ||
* New York, NY, USA, 2009, pp. 1169–1176. | ||
* http://www.cs.cornell.edu/~cnyu/papers/icml09_latentssvm.pdf | ||
* | ||
*/ | ||
class CLatentRiskFunction: public CRiskFunction | ||
{ | ||
public: | ||
/** default constructor */ | ||
CLatentRiskFunction(); | ||
|
||
/** destructor */ | ||
virtual ~CLatentRiskFunction(); | ||
|
||
/** computes the value of the risk function and sub-gradient at given point | ||
* | ||
* @param data | ||
* @param R | ||
* @param subgrad | ||
* @param w | ||
*/ | ||
virtual void risk(void* data, float64_t* R, float64_t* subgrad, float64_t* W); | ||
|
||
/** @return name of SGSerializable */ | ||
virtual const char* get_name() const { return "LatentRiskFunction"; } | ||
}; | ||
} | ||
|
||
#endif /* __LATENTRISK_FUNCTION_H__ */ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Viktor Gal | ||
* Copyright (C) 2012 Viktor Gal | ||
*/ | ||
|
||
#include <shogun/latent/LatentSOSVM.h> | ||
#include <shogun/latent/LatentRiskFunction.h> | ||
#include <shogun/structure/DualLibQPBMSOSVM.h> | ||
|
||
using namespace shogun; | ||
|
||
/** Default constructor.
 *
 * m_so_solver is set to NULL before it is registered as a parameter, so
 * that the SG_UNREF in the destructor never acts on an indeterminate pointer.
 */
CLatentSOSVM::CLatentSOSVM()
	: CLinearLatentMachine(), m_so_solver(NULL)
{
	register_parameters();
}
|
||
/** Constructs the machine from a latent model, a structured output solver
 * and the constant C.
 *
 * @param model latent model, forwarded to CLinearLatentMachine
 * @param so_solver structured output solver; referenced via set_so_solver()
 * @param C forwarded to CLinearLatentMachine
 */
CLatentSOSVM::CLatentSOSVM(CLatentModel* model, CLinearStructuredOutputMachine* so_solver, float64_t C)
	: CLinearLatentMachine(model, C), m_so_solver(NULL)
{
	/* m_so_solver must be NULL here: set_so_solver() unrefs the previous
	 * value, which would otherwise be an uninitialised pointer */
	register_parameters();
	set_so_solver(so_solver);
}
|
||
/** Destructor; drops this object's reference to the structured output solver. */
CLatentSOSVM::~CLatentSOSVM()
{
	SG_UNREF(m_so_solver);
}
|
||
/** Applies the machine to data.
 *
 * Placeholder implementation: the argument is ignored and no labels are
 * produced.
 *
 * @param data currently unused
 * @return always NULL
 */
CLatentLabels* CLatentSOSVM::apply(CFeatures* data)
{
	CLatentLabels* result = NULL;
	return result;
}
|
||
/** Replaces the structured output solver.
 *
 * The new solver is referenced before the old one is released, so passing
 * the solver that is already set cannot destroy it in between.
 *
 * @param so new structured output solver
 */
void CLatentSOSVM::set_so_solver(CLinearStructuredOutputMachine* so)
{
	SG_REF(so);
	SG_UNREF(m_so_solver);
	m_so_solver = so;
}
|
||
/** Runs one inner optimisation: trains a freshly created CDualLibQPBMSOSVM
 * with a CLatentRiskFunction, copies its weight vector into w and returns
 * the primal objective reported by the solver.
 *
 * NOTE(review): neither cooling_eps nor m_so_solver is used here, and the
 * solver is constructed with NULL model/loss/labels/features - looks like
 * work in progress; confirm.
 *
 * @param cooling_eps currently unused
 * @return primal objective value (Fp of the BMRM result)
 */
float64_t CLatentSOSVM::do_inner_loop(float64_t cooling_eps)
{
	float64_t lambda = 1/m_C;
	CLatentRiskFunction* risk = new CLatentRiskFunction();
	CDualLibQPBMSOSVM* so = new CDualLibQPBMSOSVM(NULL, NULL, NULL, NULL, lambda, risk);
	so->train();

	/* copy the resulting w; refuse to write past the end of our own buffer */
	SGVector<float64_t> cur_w = so->get_w();
	ASSERT(cur_w.vlen <= w.vlen);
	memcpy(w.vector, cur_w.vector, cur_w.vlen*sizeof(float64_t));

	/* get the primal objective value */
	float64_t po = so->get_bmrm_result().Fp;

	/* NOTE(review): whether the solver takes ownership of risk is not
	 * visible here - confirm that releasing both is not a double release */
	SG_UNREF(risk);
	SG_UNREF(so);

	return po;
}
|
||
void CLatentSOSVM::register_parameters() | ||
{ | ||
m_parameters->add((CSGObject**)&m_so_solver, "so_solver", "Structured Output Solver."); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Viktor Gal | ||
* Copyright (C) 2012 Viktor Gal | ||
*/ | ||
|
||
#ifndef __LATENTSOSVM_H__ | ||
#define __LATENTSOSVM_H__ | ||
|
||
#include <shogun/machine/LinearLatentMachine.h> | ||
#include <shogun/machine/LinearStructuredOutputMachine.h> | ||
|
||
namespace shogun | ||
{ | ||
/** | ||
* @brief TODO | ||
* | ||
*/ | ||
class CLatentSOSVM: public CLinearLatentMachine | ||
{ | ||
public: | ||
/** default ctor*/ | ||
CLatentSOSVM(); | ||
|
||
/** | ||
* | ||
* @param model | ||
* @param so_solver | ||
* @param C | ||
*/ | ||
CLatentSOSVM(CLatentModel* model, CLinearStructuredOutputMachine* so_solver, float64_t C); | ||
|
||
virtual ~CLatentSOSVM(); | ||
|
||
/** apply linear machine to data | ||
* | ||
* @param data (test)data to be classified | ||
* @return classified labels | ||
*/ | ||
virtual CLatentLabels* apply(CFeatures* data); | ||
|
||
void set_so_solver(CLinearStructuredOutputMachine* so); | ||
|
||
/** Returns the name of the SGSerializable instance. | ||
* | ||
* @return name of the SGSerializable | ||
*/ | ||
virtual const char* get_name() const { return "LatentSOSVM"; } | ||
|
||
protected: | ||
virtual float64_t do_inner_loop(float64_t cooling_eps); | ||
|
||
private: | ||
void register_parameters(); | ||
|
||
private: | ||
/** Linear Structured Solver */ | ||
CLinearStructuredOutputMachine* m_so_solver; | ||
}; | ||
} | ||
|
||
#endif /* __LATENTSOSVM_H__ */ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Viktor Gal | ||
* Copyright (C) 2012 Viktor Gal | ||
*/ | ||
|
||
#include <typeinfo> | ||
|
||
#include <shogun/classifier/svm/SVMOcas.h> | ||
#include <shogun/latent/LatentSVM.h> | ||
|
||
using namespace shogun; | ||
|
||
/** Default constructor; it only forwards to the CLinearLatentMachine base constructor. */
CLatentSVM::CLatentSVM() : CLinearLatentMachine()
{
}
|
||
/** Constructor; both arguments are handed straight to CLinearLatentMachine.
 *
 * @param model latent model
 * @param C constant forwarded to the base class
 */
CLatentSVM::CLatentSVM(CLatentModel* model, float64_t C) : CLinearLatentMachine(model, C)
{
}
|
||
/** Destructor; the body is intentionally empty. */
CLatentSVM::~CLatentSVM()
{
}
|
||
/** Runs one inner optimisation: trains a CSVMOcas on the current features
 * and the labels obtained from the latent model, copies the resulting weight
 * vector into w and returns the primal objective.
 *
 * @param cooling_eps epsilon passed to the OCAS solver via set_epsilon()
 * @return primal objective computed by the OCAS solver
 */
float64_t CLatentSVM::do_inner_loop(float64_t cooling_eps)
{
	CLatentLabels* labels = m_model->get_labels();
	CSVMOcas svm(m_C, features, labels);
	svm.set_epsilon(cooling_eps);
	svm.train();
	/* release the reference handed out by get_labels() */
	SG_UNREF(labels);

	/* copy the resulting w; refuse to write past the end of our own buffer */
	SGVector<float64_t> cur_w = svm.get_w();
	ASSERT(cur_w.vlen <= w.vlen);
	memcpy(w.vector, cur_w.vector, cur_w.vlen*sizeof(float64_t));

	return svm.compute_primal_objective();
}
|
Oops, something went wrong.