Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Introduced rejection strategies for multiclass (MC) machines
  • Loading branch information
lisitsyn committed Mar 3, 2012
1 parent b8e87c3 commit b3f83a0
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 19 deletions.
3 changes: 3 additions & 0 deletions src/interfaces/modular/Classifier.i
Expand Up @@ -65,6 +65,8 @@
%rename(KernelMulticlassMachine) CKernelMulticlassMachine;
%rename(LinearMulticlassMachine) CLinearMulticlassMachine;
%rename(MulticlassLibLinear) CMulticlassLibLinear;
%rename(RejectionStrategy) CRejectionStrategy;
%rename(ThresholdReject) CThresholdReject;

/* These functions return new Objects */
%newobject apply();
Expand Down Expand Up @@ -118,6 +120,7 @@
%include <shogun/machine/multiclass/MulticlassMachine.h>
%include <shogun/machine/multiclass/KernelMulticlassMachine.h>
%include <shogun/machine/multiclass/LinearMulticlassMachine.h>
%include <shogun/machine/multiclass/RejectionStrategy.h>
%include <shogun/classifier/svm/MulticlassLibLinear.h>

#ifdef USE_SVMLIGHT
Expand Down
1 change: 1 addition & 0 deletions src/interfaces/modular/Classifier_includes.i
Expand Up @@ -43,6 +43,7 @@
#include <shogun/machine/multiclass/KernelMulticlassMachine.h>
#include <shogun/machine/multiclass/LinearMulticlassMachine.h>
#include <shogun/classifier/svm/MulticlassLibLinear.h>
#include <shogun/machine/multiclass/RejectionStrategy.h>
#ifdef USE_SVMLIGHT
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/classifier/svm/SVMLightOneClass.h>
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/features/Labels.h
Expand Up @@ -220,6 +220,10 @@ class CLabels : public CSGObject
*/
index_t subset_idx_conversion(index_t idx) const;

public:

/** label value written for an example that a rejection strategy declined
 * to classify (used by CMulticlassMachine::classify_one_vs_rest instead
 * of a class index)
 */
static const int32_t REJECTION_LABEL = -2;

private:
void init();

Expand Down
42 changes: 26 additions & 16 deletions src/shogun/machine/multiclass/MulticlassMachine.cpp
Expand Up @@ -10,19 +10,20 @@

#include <shogun/machine/multiclass/MulticlassMachine.h>
#include <shogun/base/Parameter.h>
#include <shogun/features/Labels.h>

using namespace shogun;

/** default constructor
 *
 * Uses the one-vs-rest strategy, no submachine and no rejection strategy
 * (so every example is assigned a class).
 */
CMulticlassMachine::CMulticlassMachine()
: CMachine(), m_multiclass_strategy(ONE_VS_REST_STRATEGY), m_machine(NULL),
	m_rejection_strategy(NULL)
{
	init();
}

CMulticlassMachine::CMulticlassMachine(
EMulticlassStrategy strategy,
CMachine* machine, CLabels* labs)
: CMachine(), m_multiclass_strategy(strategy), m_machine(machine)
: CMachine(), m_multiclass_strategy(strategy), m_machine(machine), m_rejection_strategy(NULL)
{
set_labels(labs);
SG_REF(machine);
Expand Down Expand Up @@ -118,9 +119,9 @@ bool CMulticlassMachine::train_one_vs_rest()

CLabels* CMulticlassMachine::classify_one_vs_rest()
{
int32_t m_num_classes = labels->get_num_classes();
int32_t m_num_machines = get_num_machines();
ASSERT(m_num_machines==m_num_classes);
int32_t num_classes = labels->get_num_classes();
int32_t num_machines = get_num_machines();
ASSERT(num_machines==num_classes);
CLabels* result=NULL;

if (is_ready())
Expand All @@ -131,33 +132,42 @@ CLabels* CMulticlassMachine::classify_one_vs_rest()
SG_REF(result);

ASSERT(num_vectors==result->get_num_labels());
CLabels** outputs=SG_MALLOC(CLabels*, m_num_machines);
CLabels** outputs=SG_MALLOC(CLabels*, num_machines);

for (int32_t i=0; i<m_num_machines; i++)
for (int32_t i=0; i<num_machines; i++)
{
ASSERT(m_machines[i]);
outputs[i]=m_machines[i]->apply();
}

SGVector<float64_t> outputs_for_i(num_machines);
for (int32_t i=0; i<num_vectors; i++)
{
int32_t winner=0;
float64_t max_out=outputs[0]->get_label(i);
int32_t winner = 0;
float64_t max_out = outputs[0]->get_label(i);

for (int32_t j=1; j<m_num_machines; j++)
{
float64_t out=outputs[j]->get_label(i);
for (int32_t j=0; j<num_machines; j++)
outputs_for_i[j] = outputs[j]->get_label(i);

if (out>max_out)
if (m_rejection_strategy && m_rejection_strategy->reject(outputs_for_i))
{
winner=result->REJECTION_LABEL;
}
else
{
for (int32_t j=1; j<num_machines; j++)
{
winner=j;
max_out=out;
if (outputs_for_i[j]>max_out)
{
max_out = outputs_for_i[j];
winner = j;
}
}
}
result->set_label(i, winner);
}

for (int32_t i=0; i<m_num_machines; i++)
for (int32_t i=0; i<num_machines; i++)
SG_UNREF(outputs[i]);

SG_FREE(outputs);
Expand Down
19 changes: 16 additions & 3 deletions src/shogun/machine/multiclass/MulticlassMachine.h
Expand Up @@ -12,12 +12,14 @@
#define _MULTICLASSMACHINE_H___

#include <shogun/machine/Machine.h>
#include <shogun/machine/multiclass/RejectionStrategy.h>

namespace shogun
{

class CFeatures;
class CLabels;
class CRejectionStrategy;

#ifndef DOXYGEN_SHOULD_SKIP_THIS
enum EMulticlassStrategy
Expand Down Expand Up @@ -117,13 +119,24 @@ class CMulticlassMachine : public CMachine
*/
inline EMulticlassStrategy get_multiclass_strategy() const
{
	return m_multiclass_strategy;
}

/** get rejection strategy
 *
 * @return the rejection strategy currently in use, or NULL when no
 * rejection is performed; the returned object's reference count is
 * not increased
 */
inline CRejectionStrategy* get_rejection_strategy() const { return m_rejection_strategy; }
/** set rejection strategy
 *
 * Takes a reference on the new strategy and releases the previously set
 * one, so the strategy stays alive for as long as the machine uses it.
 * Passing NULL disables rejection.
 *
 * NOTE(review): the destructor (not visible here) should SG_UNREF
 * m_rejection_strategy to balance the reference taken by this setter.
 *
 * @param rejection_strategy rejection strategy to use (may be NULL)
 */
inline void set_rejection_strategy(CRejectionStrategy* rejection_strategy)
{
	// ref the new object before unref'ing the old one, so re-setting the
	// same strategy cannot free it in between
	SG_REF(rejection_strategy);
	SG_UNREF(m_rejection_strategy);
	m_rejection_strategy = rejection_strategy;
}

/** get name
 *
 * @return class name "MulticlassMachine"
 */
virtual const char* get_name() const
{
	return "MulticlassMachine";
}

protected:
Expand Down Expand Up @@ -162,7 +175,7 @@ class CMulticlassMachine : public CMachine
SGVector<CMachine*> m_machines;

/** rejection strategy applied to the submachine outputs of each example;
 * NULL means no example is rejected
 */
CRejectionStrategy* m_rejection_strategy;
};
}
#endif
76 changes: 76 additions & 0 deletions src/shogun/machine/multiclass/RejectionStrategy.h
@@ -0,0 +1,76 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Sergey Lisitsyn
* Copyright (C) 2012 Sergey Lisitsyn
*/

#ifndef _REJECTIONSTRATEGY_H___
#define _REJECTIONSTRATEGY_H___

namespace shogun
{

/** @brief rejection strategy */
class CRejectionStrategy : public CSGObject
{
public:
/** default constructor */
CRejectionStrategy() { };

/** destructor */
virtual ~CRejectionStrategy() { };

/** get name */
virtual const char* get_name() const
{
return "RejectionStrategy";
};

/** returns true if given output set leads to rejection */
virtual bool reject(SGVector<float64_t> outputs) const = 0;

};

class CThresholdReject : public CRejectionStrategy
{
public:

/** constructor */
CThresholdReject() :
CRejectionStrategy(), m_threshold(0.0) { };

/** constructor */
CThresholdReject(float64_t threshold) :
CRejectionStrategy(), m_threshold(threshold) { };

virtual ~CThresholdReject() {};

/** get name */
virtual const char* get_name() const
{
return "AllNegativesMulticlassReject";
}

/** returns true if given output set leads to rejection */
virtual bool reject(SGVector<float64_t> outputs) const
{
for (int32_t i=0; i<outputs.vlen; i++)
{
if (outputs[i]>m_threshold)
return false;
}
return true;
}

protected:

float64_t m_threshold;


};
}
#endif

0 comments on commit b3f83a0

Please sign in to comment.