Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Added rejection strategy handling to multiclass machine
  • Loading branch information
lisitsyn committed May 25, 2012
1 parent f96fdfa commit eac2e3c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 42 deletions.
18 changes: 18 additions & 0 deletions src/shogun/machine/MulticlassMachine.h
Expand Up @@ -116,6 +116,24 @@ class CMulticlassMachine : public CMachine
return m_multiclass_strategy;
}

/** Get the rejection strategy of the underlying multiclass strategy.
 *
 * The call is forwarded to m_multiclass_strategy. The strategy's getter
 * SG_REFs the object it returns, so the caller presumably ends up
 * holding one reference and should SG_UNREF it when done.
 *
 * @return rejection strategy (possibly NULL if none was set)
 */
inline CRejectionStrategy* get_rejection_strategy() const
{
	CRejectionStrategy* rejection = m_multiclass_strategy->get_rejection_strategy();
	return rejection;
}

/** Set the rejection strategy on the underlying multiclass strategy.
 *
 * Storage and reference counting of the argument are handled by the
 * multiclass strategy's own setter; this method only forwards.
 *
 * @param rejection_strategy rejection strategy to be set
 */
inline void set_rejection_strategy(CRejectionStrategy* rejection_strategy)
{
	CRejectionStrategy* const incoming = rejection_strategy;
	m_multiclass_strategy->set_rejection_strategy(incoming);
}

/** get name */
virtual const char* get_name() const
{
Expand Down
8 changes: 1 addition & 7 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.cpp
Expand Up @@ -15,16 +15,10 @@
using namespace shogun;

/** Default constructor.
 *
 * The rejection strategy is stored in, and initialised to NULL by, the
 * CMulticlassStrategy base class. Only the base needs constructing here,
 * and a derived constructor may not initialise a base-class member
 * in its own mem-initializer list.
 */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
	: CMulticlassStrategy()
{
}

/** Constructor with rejection strategy.
 *
 * Kept so that existing callers constructing the strategy with a
 * rejection strategy keep compiling. The member now lives in the
 * CMulticlassStrategy base, so it cannot be initialised here directly.
 * Instead, storage and SG_REF/SG_UNREF bookkeeping are delegated to
 * set_rejection_strategy().
 *
 * @param rejection_strategy rejection strategy to use (SG_REF'd by the setter)
 */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(CRejectionStrategy *rejection_strategy)
	: CMulticlassStrategy()
{
	set_rejection_strategy(rejection_strategy);
}

SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
{
for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i)
Expand Down
22 changes: 0 additions & 22 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.h
Expand Up @@ -16,35 +16,15 @@
namespace shogun
{

class CRejectionStrategy;

class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
{
public:
/** constructor */
CMulticlassOneVsRestStrategy();

/** constructor with rejection strategy */
CMulticlassOneVsRestStrategy(CRejectionStrategy *rejection_strategy);

/** destructor
 *
 * Empty body: releases nothing itself.
 * NOTE(review): the rejection-strategy constructor and
 * set_rejection_strategy() SG_REF the rejection strategy, but nothing
 * here SG_UNREFs it. Confirm that a base-class destructor (not visible
 * here) releases m_rejection_strategy; otherwise it leaks.
 */
virtual ~CMulticlassOneVsRestStrategy() {}

/** get rejection strategy
 *
 * Bumps the reference count of the stored strategy before returning it,
 * so the caller receives its own reference and is expected to SG_UNREF it.
 *
 * @return rejection strategy (SG_REF'd); NULL if none was set
 */
CRejectionStrategy *get_rejection_strategy()
{
	CRejectionStrategy *current = m_rejection_strategy;
	SG_REF(current);
	return current;
}

/** set rejection strategy
 *
 * Replaces the stored rejection strategy and adjusts reference counts.
 * The new object is SG_REF'd before the old one is SG_UNREF'd. Keep this
 * order: re-setting the currently stored object must not drop its last
 * reference in between.
 *
 * @param rejection_strategy new rejection strategy
 */
void set_rejection_strategy(CRejectionStrategy *rejection_strategy)
{
SG_REF(rejection_strategy);
SG_UNREF(m_rejection_strategy);
m_rejection_strategy = rejection_strategy;
}

/** start training */
virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
{
Expand Down Expand Up @@ -80,8 +60,6 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
return "MulticlassOneVsRestStrategy";
};

protected:
CRejectionStrategy *m_rejection_strategy; ///< rejection strategy
};

} // namespace shogun
Expand Down
5 changes: 3 additions & 2 deletions src/shogun/multiclass/MulticlassStrategy.cpp
Expand Up @@ -15,9 +15,10 @@ using namespace shogun;


/** Default constructor.
 *
 * Starts with no rejection strategy, no labels, iteration 0 and zero
 * classes, then registers the rejection strategy and class count as
 * parameters via SG_ADD.
 */
CMulticlassStrategy::CMulticlassStrategy()
	: m_rejection_strategy(NULL), m_train_labels(NULL), m_orig_labels(NULL), m_train_iter(0)
{
	// m_num_classes had no initialiser at all, so get_num_classes() would
	// read an indeterminate value before set_num_classes() is called.
	// Assigned in the body to avoid depending on its declaration order.
	m_num_classes = 0;

	SG_ADD((CSGObject**)&m_rejection_strategy, "rejection_strategy", "Strategy of rejection", MS_NOT_AVAILABLE);
	SG_ADD(&m_num_classes, "num_classes", "Number of classes", MS_NOT_AVAILABLE);
}

void CMulticlassStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
Expand Down
39 changes: 28 additions & 11 deletions src/shogun/multiclass/MulticlassStrategy.h
Expand Up @@ -34,17 +34,32 @@ class CMulticlassStrategy: public CSGObject
return "MulticlassStrategy";
};

/** set number of classes
 *
 * Previously each accessor here was defined twice (the pre- and
 * post-change copies), which is a redefinition error; a single
 * definition of each is kept.
 *
 * @param num_classes number of classes
 */
void set_num_classes(int32_t num_classes)
{
	m_num_classes = num_classes;
}

/** get number of classes
 *
 * @return number of classes
 */
int32_t get_num_classes() const
{
	return m_num_classes;
}

/** get rejection strategy
 *
 * Increments the reference count of the stored strategy before handing
 * it out; the caller owns the extra reference and is expected to
 * SG_UNREF it.
 *
 * @return rejection strategy (SG_REF'd); NULL if none was set
 */
CRejectionStrategy *get_rejection_strategy()
{
	CRejectionStrategy *stored = m_rejection_strategy;
	SG_REF(stored);
	return stored;
}

/** set rejection strategy
 *
 * Replaces the stored rejection strategy and adjusts reference counts.
 * The incoming object is SG_REF'd before the current one is SG_UNREF'd.
 * Keep this order so that re-setting the currently stored object does not
 * release it in between.
 *
 * @param rejection_strategy new rejection strategy
 */
void set_rejection_strategy(CRejectionStrategy *rejection_strategy)
{
SG_REF(rejection_strategy);
SG_UNREF(m_rejection_strategy);
m_rejection_strategy = rejection_strategy;
}

/** start training */
virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels);
Expand All @@ -70,6 +85,8 @@ class CMulticlassStrategy: public CSGObject
virtual int32_t get_num_machines()=0;

protected:

CRejectionStrategy* m_rejection_strategy; ///< rejection strategy
CBinaryLabels *m_train_labels; ///< labels used to train the submachines
CMulticlassLabels *m_orig_labels; ///< original multiclass labels
int32_t m_train_iter; ///< index of current iterations
Expand Down

0 comments on commit eac2e3c

Please sign in to comment.