Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'multiclass' of git://github.com/pluskid/shogun
- Loading branch information
Showing
7 changed files
with
245 additions
and
10 deletions.
There are no files selected for viewing
95 changes: 95 additions & 0 deletions
95
examples/undocumented/libshogun/balanced_conditional_probability_tree.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2011 Shashwat Lal Das | ||
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society | ||
* | ||
* This example demonstrates use of the Balanced Conditional Probability Tree multiclass learner. | ||
*/ | ||
|
||
#include <shogun/lib/common.h> | ||
|
||
#include <shogun/io/StreamingAsciiFile.h> | ||
#include <shogun/features/StreamingDenseFeatures.h> | ||
#include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h> | ||
|
||
using namespace shogun; | ||
|
||
int main(int argc, char **argv) | ||
{ | ||
init_shogun_with_defaults(); | ||
|
||
const char* train_file_name = "../data/7class_example4_train.dense"; | ||
const char* test_file_name = "../data/7class_example4_test.dense"; | ||
CStreamingAsciiFile* train_file = new CStreamingAsciiFile(train_file_name); | ||
SG_REF(train_file); | ||
|
||
CStreamingDenseFeatures<float32_t>* train_features = new CStreamingDenseFeatures<float32_t>(train_file, true, 1024); | ||
SG_REF(train_features); | ||
|
||
CBalancedConditionalProbabilityTree *cpt = new CBalancedConditionalProbabilityTree(); | ||
cpt->set_num_passes(1); | ||
cpt->set_features(train_features); | ||
|
||
if (argc > 1) | ||
{ | ||
float64_t alpha = 0.5; | ||
sscanf(argv[1], "%lf", &alpha); | ||
SG_SPRINT("Setting alpha to %.2lf\n", alpha); | ||
cpt->set_alpha(alpha); | ||
} | ||
|
||
cpt->train(); | ||
cpt->print_tree(); | ||
|
||
CStreamingAsciiFile* test_file = new CStreamingAsciiFile(test_file_name); | ||
SG_REF(test_file); | ||
CStreamingDenseFeatures<float32_t>* test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024); | ||
SG_REF(test_features); | ||
|
||
CMulticlassLabels *pred = cpt->apply_multiclass(test_features); | ||
test_features->reset_stream(); | ||
SG_SPRINT("num_labels = %d\n", pred->get_num_labels()); | ||
|
||
SG_UNREF(test_features); | ||
SG_UNREF(test_file); | ||
test_file = new CStreamingAsciiFile(test_file_name); | ||
SG_REF(test_file); | ||
test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024); | ||
SG_REF(test_features); | ||
|
||
CMulticlassLabels *gnd = new CMulticlassLabels(pred->get_num_labels()); | ||
test_features->start_parser(); | ||
for (int32_t i=0; i < pred->get_num_labels(); ++i) | ||
{ | ||
test_features->get_next_example(); | ||
gnd->set_int_label(i, test_features->get_label()); | ||
test_features->release_example(); | ||
} | ||
test_features->end_parser(); | ||
|
||
int32_t n_correct = 0; | ||
for (index_t i=0; i < pred->get_num_labels(); ++i) | ||
{ | ||
if (pred->get_int_label(i) == gnd->get_int_label(i)) | ||
n_correct++; | ||
//SG_SPRINT("%d-%d ", pred->get_int_label(i), gnd->get_int_label(i)); | ||
} | ||
SG_SPRINT("\n"); | ||
|
||
SG_SPRINT("Multiclass Accuracy = %.2f%%\n", 100.0*n_correct / gnd->get_num_labels()); | ||
|
||
SG_UNREF(train_features); | ||
SG_UNREF(test_features); | ||
SG_UNREF(train_file); | ||
SG_UNREF(test_file); | ||
SG_UNREF(cpt); | ||
SG_UNREF(pred); | ||
|
||
exit_shogun(); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
src/shogun/multiclass/tree/BalancedConditionalProbabilityTree.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Chiyuan Zhang | ||
* Copyright (C) 2012 Chiyuan Zhang | ||
*/ | ||
|
||
#include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h> | ||
|
||
using namespace shogun; | ||
|
||
/** Default constructor: starts with alpha = 0.4 and registers the parameter. */
CBalancedConditionalProbabilityTree::CBalancedConditionalProbabilityTree()
	: m_alpha(0.4)
{
	SG_ADD(&m_alpha, "m_alpha", "Trade-off parameter of tree balance", MS_NOT_AVAILABLE);
}
|
||
/**
 * Set the balance trade-off parameter.
 *
 * @param alpha new trade-off value; must lie in [0, 1]. Values outside that
 *        range, including NaN, are reported through SG_ERROR.
 */
void CBalancedConditionalProbabilityTree::set_alpha(float64_t alpha)
{
	// Written as a negated in-range test: NaN fails every comparison, so the
	// form "alpha < 0 || alpha > 1" would let it through unnoticed.
	if (!(alpha >= 0 && alpha <= 1))
		SG_ERROR("expect 0 <= alpha <= 1, but got %g\n", alpha);
	m_alpha = alpha;
}
|
||
/**
 * Decide which subtree an example should descend into while the tree
 * structure is being trained.
 *
 * The decision combines the value returned by predict_node() for this node
 * with a balance term derived from the depths of the two subtrees, weighted
 * by m_alpha.
 *
 * @param node the node being decided
 * @param ex the example being decided
 * @return true if the example should go left, false if it should go right
 */
bool CBalancedConditionalProbabilityTree::which_subtree(node_t *node, SGVector<float32_t> ex)
{
	float64_t pred = predict_node(ex, node);

	// The balance term is log2(2^depth_left / 2^depth_right), which is simply
	// depth_left - depth_right. Computing the difference directly avoids the
	// inf/inf = NaN that the pow/divide/log2 round-trip produces for very
	// deep trees, and any rounding introduced by log2.
	int32_t depth_left = tree_depth(node->left());
	int32_t depth_right = tree_depth(node->right());
	float64_t balance = depth_left - depth_right;

	float64_t obj_val = (1-m_alpha) * 2 * (pred-0.5) + m_alpha * balance;

	if (obj_val > 0)
		return false; // go right
	return true; // go left
}
|
||
/**
 * Count the nodes on the chain of left children starting at node.
 *
 * @param node first node of the chain, may be NULL
 * @return number of nodes visited; 0 when node is NULL
 */
int32_t CBalancedConditionalProbabilityTree::tree_depth(node_t *node)
{
	int32_t levels = 0;

	// Only left links are followed; right children are never inspected.
	for (; node != NULL; node = node->left())
		++levels;

	return levels;
}
70 changes: 70 additions & 0 deletions
70
src/shogun/multiclass/tree/BalancedConditionalProbabilityTree.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* This program is free software; you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation; either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* Written (W) 2012 Chiyuan Zhang | ||
* Copyright (C) 2012 Chiyuan Zhang | ||
*/ | ||
|
||
#ifndef BALANCEDCONDITIONALPROBABILITYTREE_H__ | ||
#define BALANCEDCONDITIONALPROBABILITYTREE_H__ | ||
|
||
#include <shogun/multiclass/tree/ConditionalProbabilityTree.h> | ||
|
||
namespace shogun | ||
{ | ||
|
||
/** | ||
* Balanced Conditional Probability Tree. | ||
* | ||
* The tree is constructed to trade-off the existing regressor's prediction | ||
* and the balance (depth) of the tree. The parameter alpha in [0,1] | ||
* control the trade-off. | ||
* | ||
* * when alpha = 1, best efforts are made to ensure the tree is balanced | ||
* * when alpha = 0, the balance of tree is complete ignored | ||
* | ||
* more balanced tree means better computational efficiency, but usually worse | ||
* performance. See the following paper for more details: | ||
* | ||
* Alina Beygelzimer, John Langford, Yuri Lifshits, Gregory Sorkin, Alex | ||
* Strehl. Conditional Probability Tree Estimation Analysis and Algorithms. UAI 2009. | ||
*/ | ||
class CBalancedConditionalProbabilityTree: public CConditionalProbabilityTree | ||
{ | ||
public: | ||
/** constructor */ | ||
CBalancedConditionalProbabilityTree(); | ||
|
||
/** destructor */ | ||
virtual ~CBalancedConditionalProbabilityTree() {} | ||
|
||
/** get name */ | ||
virtual const char* get_name() const { return "BalancedConditionalProbabilityTree"; } | ||
|
||
/** set alpha */ | ||
void set_alpha(float64_t alpha); | ||
|
||
/** get alpha */ | ||
float64_t get_alpha() const { return m_alpha; } | ||
|
||
protected: | ||
/** decide which subtree to go, when training the tree structure. | ||
* @param node the node being decided | ||
* @param ex the example being decided | ||
* @return true if should go left, false otherwise | ||
*/ | ||
virtual bool which_subtree(node_t *node, SGVector<float32_t> ex); | ||
|
||
private: | ||
int32_t tree_depth(node_t *node); | ||
|
||
float64_t m_alpha; ///< trade-off parameter for tree balance | ||
}; | ||
|
||
} /* shogun */ | ||
|
||
#endif /* end of include guard: BALANCEDCONDITIONALPROBABILITYTREE_H__ */ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters