Skip to content

Commit

Permalink
Merge branch 'multiclass' of git://github.com/pluskid/shogun
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jun 11, 2012
2 parents 1532e15 + c3a67a6 commit f5163e8
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 10 deletions.
@@ -0,0 +1,95 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011 Shashwat Lal Das
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*
* This example demonstrates use of the Balanced Conditional Probability Tree multiclass learner.
*/

#include <shogun/lib/common.h>

#include <shogun/io/StreamingAsciiFile.h>
#include <shogun/features/StreamingDenseFeatures.h>
#include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h>

using namespace shogun;

int main(int argc, char **argv)
{
init_shogun_with_defaults();

const char* train_file_name = "../data/7class_example4_train.dense";
const char* test_file_name = "../data/7class_example4_test.dense";
CStreamingAsciiFile* train_file = new CStreamingAsciiFile(train_file_name);
SG_REF(train_file);

CStreamingDenseFeatures<float32_t>* train_features = new CStreamingDenseFeatures<float32_t>(train_file, true, 1024);
SG_REF(train_features);

CBalancedConditionalProbabilityTree *cpt = new CBalancedConditionalProbabilityTree();
cpt->set_num_passes(1);
cpt->set_features(train_features);

if (argc > 1)
{
float64_t alpha = 0.5;
sscanf(argv[1], "%lf", &alpha);
SG_SPRINT("Setting alpha to %.2lf\n", alpha);
cpt->set_alpha(alpha);
}

cpt->train();
cpt->print_tree();

CStreamingAsciiFile* test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
CStreamingDenseFeatures<float32_t>* test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024);
SG_REF(test_features);

CMulticlassLabels *pred = cpt->apply_multiclass(test_features);
test_features->reset_stream();
SG_SPRINT("num_labels = %d\n", pred->get_num_labels());

SG_UNREF(test_features);
SG_UNREF(test_file);
test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024);
SG_REF(test_features);

CMulticlassLabels *gnd = new CMulticlassLabels(pred->get_num_labels());
test_features->start_parser();
for (int32_t i=0; i < pred->get_num_labels(); ++i)
{
test_features->get_next_example();
gnd->set_int_label(i, test_features->get_label());
test_features->release_example();
}
test_features->end_parser();

int32_t n_correct = 0;
for (index_t i=0; i < pred->get_num_labels(); ++i)
{
if (pred->get_int_label(i) == gnd->get_int_label(i))
n_correct++;
//SG_SPRINT("%d-%d ", pred->get_int_label(i), gnd->get_int_label(i));
}
SG_SPRINT("\n");

SG_SPRINT("Multiclass Accuracy = %.2f%%\n", 100.0*n_correct / gnd->get_num_labels());

SG_UNREF(train_features);
SG_UNREF(test_features);
SG_UNREF(train_file);
SG_UNREF(test_file);
SG_UNREF(cpt);
SG_UNREF(pred);

exit_shogun();

return 0;
}
Expand Up @@ -34,6 +34,7 @@ int main()
cpt->set_num_passes(1);
cpt->set_features(train_features);
cpt->train();
cpt->print_tree();

CStreamingAsciiFile* test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
Expand All @@ -42,7 +43,7 @@ int main()

CMulticlassLabels *pred = cpt->apply_multiclass(test_features);
test_features->reset_stream();
printf("num_labels = %d\n", pred->get_num_labels());
SG_SPRINT("num_labels = %d\n", pred->get_num_labels());

SG_UNREF(test_features);
SG_UNREF(test_file);
Expand All @@ -66,11 +67,11 @@ int main()
{
if (pred->get_int_label(i) == gnd->get_int_label(i))
n_correct++;
//printf("%d-%d ", pred->get_int_label(i), gnd->get_int_label(i));
//SG_SPRINT("%d-%d ", pred->get_int_label(i), gnd->get_int_label(i));
}
printf("\n");
SG_SPRINT("\n");

printf("Multiclass Accuracy = %.2f%%\n", 100.0*n_correct / gnd->get_num_labels());
SG_SPRINT("Multiclass Accuracy = %.2f%%\n", 100.0*n_correct / gnd->get_num_labels());

SG_UNREF(train_features);
SG_UNREF(test_features);
Expand Down
54 changes: 54 additions & 0 deletions src/shogun/multiclass/tree/BalancedConditionalProbabilityTree.cpp
@@ -0,0 +1,54 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Chiyuan Zhang
* Copyright (C) 2012 Chiyuan Zhang
*/

#include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h>

using namespace shogun;

/* Default constructor: starts with alpha = 0.4 and registers m_alpha with
 * SG_ADD under the name "m_alpha". */
CBalancedConditionalProbabilityTree::CBalancedConditionalProbabilityTree()
{
	m_alpha = 0.4;
	SG_ADD(&m_alpha, "m_alpha", "Trade-off parameter of tree balance", MS_NOT_AVAILABLE);
}

/* Set the balance trade-off parameter.
 *
 * @param alpha weight of the balance term used in which_subtree(); must lie
 *        in [0,1]. Out-of-range values and NaN are reported through SG_ERROR.
 */
void CBalancedConditionalProbabilityTree::set_alpha(float64_t alpha)
{
	// Written as a negated range test: every ordered comparison with NaN is
	// false, so "alpha < 0 || alpha > 1" would have silently accepted NaN.
	if (!(alpha >= 0 && alpha <= 1))
		SG_ERROR("expect 0 <= alpha <= 1, but got %g\n", alpha);
	m_alpha = alpha;
}

/* Decide which child of an internal node an example descends into while the
 * tree structure is being trained.
 *
 * Objective:
 *   obj = (1-alpha) * 2 * (pred - 0.5) + alpha * (depth_left - depth_right)
 * where pred = predict_node(ex, node) and depth_* = tree_depth() of each
 * child (length of its leftmost path). obj > 0 sends the example right.
 *
 * @param node the node being decided
 * @param ex the example being decided
 * @return true if the example should go left, false otherwise
 */
bool CBalancedConditionalProbabilityTree::which_subtree(node_t *node, SGVector<float32_t> ex)
{
	float64_t pred = predict_node(ex, node);

	// log2(2^depth_left / 2^depth_right) == depth_left - depth_right. Using the
	// difference directly avoids pow/log2 rounding and the inf/inf = NaN
	// overflow that would otherwise occur for very deep paths.
	float64_t depth_diff = static_cast<float64_t>(tree_depth(node->left()))
		- static_cast<float64_t>(tree_depth(node->right()));

	float64_t obj_val = (1-m_alpha) * 2 * (pred-0.5) + m_alpha * depth_diff;

	if (obj_val > 0)
		return false; // go right
	return true; // go left
}

/* Count the nodes visited by repeatedly following left() starting at node
 * (0 for a NULL node). This is the length of the leftmost path, which is not
 * necessarily the maximum depth of the subtree. */
int32_t CBalancedConditionalProbabilityTree::tree_depth(node_t *node)
{
	int32_t depth;
	for (depth = 0; node != NULL; node = node->left())
		++depth;
	return depth;
}
70 changes: 70 additions & 0 deletions src/shogun/multiclass/tree/BalancedConditionalProbabilityTree.h
@@ -0,0 +1,70 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Chiyuan Zhang
* Copyright (C) 2012 Chiyuan Zhang
*/

#ifndef BALANCEDCONDITIONALPROBABILITYTREE_H__
#define BALANCEDCONDITIONALPROBABILITYTREE_H__

#include <shogun/multiclass/tree/ConditionalProbabilityTree.h>

namespace shogun
{

/**
 * Balanced Conditional Probability Tree.
 *
 * While the tree structure is built, each example is routed by trading off
 * the existing regressor's prediction against the balance (depth) of the
 * tree. The parameter alpha in [0,1] controls the trade-off:
 *
 * * when alpha = 1, best efforts are made to ensure the tree is balanced
 * * when alpha = 0, the balance of the tree is completely ignored
 *
 * A more balanced tree means better computational efficiency, but usually
 * worse performance. See the following paper for more details:
 *
 * Alina Beygelzimer, John Langford, Yuri Lifshits, Gregory Sorkin, Alex
 * Strehl. Conditional Probability Tree Estimation Analysis and Algorithms. UAI 2009.
 */
class CBalancedConditionalProbabilityTree: public CConditionalProbabilityTree
{
public:
	/** constructor */
	CBalancedConditionalProbabilityTree();

	/** destructor */
	virtual ~CBalancedConditionalProbabilityTree() {}

	/** get name
	 * @return "BalancedConditionalProbabilityTree"
	 */
	virtual const char* get_name() const { return "BalancedConditionalProbabilityTree"; }

	/** set alpha
	 * @param alpha balance trade-off, expected to lie in [0,1]
	 */
	void set_alpha(float64_t alpha);

	/** get alpha
	 * @return the current balance trade-off parameter
	 */
	float64_t get_alpha() const { return m_alpha; }

protected:
	/** decide which subtree to go, when training the tree structure.
	 * @param node the node being decided
	 * @param ex the example being decided
	 * @return true if should go left, false otherwise
	 */
	virtual bool which_subtree(node_t *node, SGVector<float32_t> ex);

private:
	/** length of the path obtained by following left children from node
	 * @param node start node (may be NULL)
	 * @return number of nodes on that path
	 */
	int32_t tree_depth(node_t *node);

	float64_t m_alpha; ///< trade-off parameter for tree balance
};

} /* shogun */

#endif /* end of include guard: BALANCEDCONDITIONALPROBABILITYTREE_H__ */

12 changes: 8 additions & 4 deletions src/shogun/multiclass/tree/ConditionalProbabilityTree.cpp
Expand Up @@ -137,11 +137,17 @@ bool CConditionalProbabilityTree::train_machine(CFeatures* data)
SG_UNREF(lll);
}

m_root->debug_print(ConditionalProbabilityTreeNodeData::print);

return true;
}

void CConditionalProbabilityTree::print_tree()
{
if (m_root)
m_root->debug_print(ConditionalProbabilityTreeNodeData::print);
else
printf("Empty Tree\n");
}

void CConditionalProbabilityTree::train_example(SGVector<float32_t> ex, int32_t label)
{
if (m_root == NULL)
Expand Down Expand Up @@ -188,14 +194,12 @@ void CConditionalProbabilityTree::train_example(SGVector<float32_t> ex, int32_t
mch->start_train();
m_machines->push_back(mch);
left_node->machine(m_machines->get_num_elements()-1);
printf(" insert %d %p\n", left_node->data.label, left_node);
m_leaves.insert(make_pair(left_node->data.label, left_node));
node->left(left_node);

node_t *right_node = new node_t();
right_node->data.label = label;
right_node->machine(create_machine(ex));
printf(" insert %d %p\n", label, right_node);
m_leaves.insert(make_pair(label, right_node));
node->right(right_node);
}
Expand Down
13 changes: 12 additions & 1 deletion src/shogun/multiclass/tree/ConditionalProbabilityTree.h
Expand Up @@ -28,10 +28,18 @@ struct ConditionalProbabilityTreeNodeData

/** print a node's label (used as the data printer for debug_print)
 * @param data node payload to print
 */
static void print(const ConditionalProbabilityTreeNodeData &data)
{
	// single output path via SG_SPRINT; a leftover printf would print the label twice
	SG_SPRINT("label=%d\n", data.label);
}
};

/**
* Conditional Probability Tree.
*
* See reference:
*
* Alina Beygelzimer, John Langford, Yuri Lifshits, Gregory Sorkin, Alex
* Strehl. Conditional Probability Tree Estimation Analysis and Algorithms. UAI 2009.
*/
class CConditionalProbabilityTree: public CTreeMachine<ConditionalProbabilityTreeNodeData>
{
public:
Expand Down Expand Up @@ -78,6 +86,9 @@ class CConditionalProbabilityTree: public CTreeMachine<ConditionalProbabilityTre
* @param ex a vector to be applied
*/
virtual int32_t apply_multiclass_example(SGVector<float32_t> ex);

/** print the tree structure for debug purpose */
void print_tree();
protected:
/** the labels will be embedded in the streaming features */
virtual bool train_require_labels() const { return false; }
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/tree/TreeMachineNode.h
Expand Up @@ -122,7 +122,7 @@ class CTreeMachineNode
static void debug_print_impl(data_print_func_t data_print_func, CTreeMachineNode<data_t> *node, int32_t depth)
{
for (int32_t i=0; i < depth; ++i)
printf(" ");
SG_SPRINT(" ");
data_print_func(node->data);
if (node->left())
debug_print_impl(data_print_func, node->left(), depth+1);
Expand Down

0 comments on commit f5163e8

Please sign in to comment.