Merge pull request #310 from frx/streaming_vw
Some VW improvements and python examples
Soeren Sonnenburg committed Aug 24, 2011
2 parents 7e47856 + 7cf81e6 commit a90530f
Showing 12 changed files with 145 additions and 11 deletions.
@@ -0,0 +1,49 @@
from modshogun import StreamingVwFile
from modshogun import StreamingVwCacheFile
from modshogun import T_SVMLIGHT
from modshogun import StreamingVwFeatures
from modshogun import VowpalWabbit

def create_cache():
    """Creates a binary cache from an ascii data file."""

    # Open the input file as a StreamingVwFile
    input_file = StreamingVwFile("../data/fm_train_sparsereal.dat")
    # Default file name will be vw_cache.dat.cache
    input_file.set_write_to_cache(True)

    # Tell VW that the file is in SVMLight format
    # Supported types are T_DENSE, T_SVMLIGHT and T_VW
    input_file.set_parser_type(T_SVMLIGHT)

    # Create a StreamingVwFeatures object, `True' indicating the examples are labelled
    features = StreamingVwFeatures(input_file, True, 1024)

    # Create a VW object from the features
    vw = VowpalWabbit(features)
    vw.set_no_training(True)

    # Train (in this case does nothing but run over all examples)
    vw.train()

def train_from_cache():
    """Train using the generated cache file."""

    # Open the input cache file as a StreamingVwCacheFile
    input_file = StreamingVwCacheFile("vw_cache.dat.cache")

    # The rest is exactly as for normal input
    features = StreamingVwFeatures(input_file, True, 1024)
    vw = VowpalWabbit(features)
    vw.train()

if __name__ == "__main__":
    print "Creating cache..."
    create_cache()
    print "Done."
    print

    print "Training using the cache file..."
    print
    train_from_cache()
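A small variation one might use in practice (an illustrative sketch, not part of this commit): prefer the binary cache when it already exists and fall back to a cache-building pass otherwise. The helper name train_with_optional_cache is hypothetical; it reuses only the functions defined above and the default cache name vw_cache.dat.cache.

import os

def train_with_optional_cache():
    """Hypothetical helper: train from the cache if present, else build it first."""
    if os.path.exists("vw_cache.dat.cache"):
        # A cache from an earlier run exists, so skip the ASCII parsing pass
        train_from_cache()
    else:
        # No cache yet: one non-training pass over the ASCII data writes it out
        create_cache()
        train_from_cache()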

26 changes: 26 additions & 0 deletions examples/undocumented/python_modular/streaming_vw_modular.py
@@ -0,0 +1,26 @@
from modshogun import StreamingVwFile
from modshogun import T_SVMLIGHT
from modshogun import StreamingVwFeatures
from modshogun import VowpalWabbit

def run_vw():
    """Runs the VW algorithm on a toy dataset in SVMLight format."""

    # Open the input file as a StreamingVwFile
    input_file = StreamingVwFile("../data/fm_train_sparsereal.dat")

    # Tell VW that the file is in SVMLight format
    # Supported types are T_DENSE, T_SVMLIGHT and T_VW
    input_file.set_parser_type(T_SVMLIGHT)

    # Create a StreamingVwFeatures object, `True' indicating the examples are labelled
    features = StreamingVwFeatures(input_file, True, 1024)

    # Create a VW object from the features
    vw = VowpalWabbit(features)

    # Train
    vw.train()

if __name__ == "__main__":
    run_vw()
2 changes: 2 additions & 0 deletions src/interfaces/modular/IO.i
@@ -27,6 +27,7 @@
%rename(SerializableXmlFile) CSerializableXmlFile;
%rename(SimpleFile) CSimpleFile;
%rename(MemoryMappedFile) CMemoryMappedFile;
%rename(VwParser) CVwParser;

%include <shogun/io/File.h>
%include <shogun/io/StreamingFile.h>
@@ -70,6 +71,7 @@ namespace shogun

%include <shogun/io/AsciiFile.h>
%include <shogun/io/StreamingAsciiFile.h>
%include <shogun/classifier/vw/VwParser.h>
%include <shogun/io/StreamingVwFile.h>
%include <shogun/io/StreamingVwCacheFile.h>
%include <shogun/io/BinaryFile.h>
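With CVwParser now wrapped and renamed by the %rename(VwParser) CVwParser line above, the parser class should become importable from the modular Python bindings under the shorter name. A minimal, hypothetical check, assuming the modular interface has been rebuilt with this change:

# VwParser as an importable name is an assumption based on the %rename directive above
from modshogun import VwParser
print VwParser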
1 change: 1 addition & 0 deletions src/interfaces/modular/IO_includes.i
@@ -9,6 +9,7 @@
#include <shogun/io/StreamingFileFromSimpleFeatures.h>
#include <shogun/io/AsciiFile.h>
#include <shogun/io/StreamingAsciiFile.h>
#include <shogun/classifier/vw/VwParser.h>
#include <shogun/io/StreamingVwFile.h>
#include <shogun/io/StreamingVwCacheFile.h>
#include <shogun/io/BinaryFile.h>
21 changes: 13 additions & 8 deletions src/shogun/classifier/vw/VowpalWabbit.cpp
@@ -135,18 +135,22 @@ bool CVowpalWabbit::train_machine(CFeatures* feat)
 {
     example = features->get_example();

-    if (example->pass != current_pass)
+    // Check if we shouldn't train (generally used for cache creation)
+    if (!no_training)
     {
-        env->eta *= env->eta_decay_rate;
-        current_pass = example->pass;
-    }
+        if (example->pass != current_pass)
+        {
+            env->eta *= env->eta_decay_rate;
+            current_pass = example->pass;
+        }

-    predict_and_finalize(example);
+        predict_and_finalize(example);

-    learner->train(example, example->eta_round);
-    example->eta_round = 0.;
+        learner->train(example, example->eta_round);
+        example->eta_round = 0.;

-    output_example(example);
+        output_example(example);
+    }

     features->release_example();
 }
@@ -216,6 +220,7 @@ void CVowpalWabbit::init(CStreamingVwFeatures* feat)
SG_REF(reg);

quiet = false;
no_training = false;
dump_interval = exp(1.);
reg_name = NULL;
reg_dump_text = true;
13 changes: 13 additions & 0 deletions src/shogun/classifier/vw/VowpalWabbit.h
@@ -62,6 +62,16 @@ class CVowpalWabbit: public COnlineLinearMachine
     */
    void reinitialize_weights();

    /**
     * Set whether one desires to not train and only
     * make passes over all examples instead.
     *
     * This is useful if one wants to create a cache file from data.
     *
     * @param dont_train true if one doesn't want to train
     */
    void set_no_training(bool dont_train) { no_training = dont_train; }

    /**
     * Set whether learning is adaptive or not
     *
@@ -265,6 +275,9 @@ class CVowpalWabbit: public COnlineLinearMachine
/// Whether to display statistics or not
bool quiet;

/// Whether we should just run over examples without training
bool no_training;

/// Multiplication factor for number of examples to dump after
float32_t dump_interval;

2 changes: 1 addition & 1 deletion src/shogun/classifier/vw/VwParser.h
@@ -61,7 +61,7 @@ class CVwParser: public CSGObject
     /**
      * Destructor
      */
-    ~CVwParser();
+    virtual ~CVwParser();

     /**
      * Get the environment
31 changes: 31 additions & 0 deletions src/shogun/classifier/vw/cache/VwCacheReader.h
@@ -91,6 +91,37 @@ class CVwCacheReader: public CSGObject
     */
    virtual CVwEnvironment* get_env();

    /**
     * Update min and max labels seen in the environment
     *
     * @param label current label based on which to update
     */
    virtual void set_mm(float64_t label)
    {
        env->min_label = CMath::min(env->min_label, label);
        if (label != FLT_MAX)
            env->max_label = CMath::max(env->max_label, label);
    }

    /**
     * A dummy function performing no operation in case training
     * is not to be performed.
     *
     * @param label label
     */
    virtual void noop_mm(float64_t label) { }

    /**
     * Function which is actually called to update min and max labels
     * Should be set to one of the functions implemented for this.
     *
     * @param label label based on which to update
     */
    virtual void set_minmax(float64_t label)
    {
        set_mm(label);
    }

    /**
     * Function to read one example from the cache
     *
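Restated outside the class, the label-range update added above amounts to the standalone Python sketch below (names are illustrative; FLT_MAX stands in for the C constant, which the guard above presumably treats as the "unlabelled" sentinel). The asymmetry in set_mm is deliberate: a FLT_MAX label can never lower the minimum, so only the maximum needs the guard.

import sys

FLT_MAX = sys.float_info.max  # stand-in for the C FLT_MAX sentinel

def update_label_range(min_label, max_label, label):
    """Illustrative restatement of CVwCacheReader::set_mm."""
    # The minimum is updated unconditionally; FLT_MAX cannot make it smaller
    min_label = min(min_label, label)
    # An unlabelled example (label == FLT_MAX) is kept out of the maximum
    if label != FLT_MAX:
        max_label = max(max_label, label)
    return min_label, max_label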
2 changes: 2 additions & 0 deletions src/shogun/classifier/vw/cache/VwNativeCacheReader.cpp
@@ -97,6 +97,8 @@ char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
{
ld->label = *(float32_t*)c;
c += sizeof(ld->label);
set_minmax(ld->label);

ld->weight = *(float32_t*)c;
c += sizeof(ld->weight);
ld->initial = *(float32_t*)c;
7 changes: 5 additions & 2 deletions src/shogun/classifier/vw/cache/VwNativeCacheReader.h
@@ -21,13 +21,16 @@
 namespace shogun
 {

-/// Packed structure for efficient storage
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+// Packed structure for efficient storage
 struct one_float
 {
-    /// The float to store
+    // The float to store
     float32_t f;
 } __attribute__((packed));

+#endif // DOXYGEN_SHOULD_SKIP_THIS
+
 /** @brief Class CVwNativeCacheReader reads from a cache exactly as
  * that which has been produced by VW's default cache format.
  *
1 change: 1 addition & 0 deletions src/shogun/features/StreamingVwFeatures.cpp
@@ -98,6 +98,7 @@ void CStreamingVwFeatures::setup_example(VwExample* ae)
VwFeature temp = {1,constant_hash & env->mask};
ae->indices.push(constant_namespace);
ae->atomics[constant_namespace].push(temp);
ae->sum_feat_sq[constant_namespace] = 0;

if(env->stride != 1)
{
1 change: 1 addition & 0 deletions src/shogun/io/StreamingVwFile.cpp
@@ -72,4 +72,5 @@ void CStreamingVwFile::init()

set_parser_type(T_VW);
write_to_cache = false;
SG_REF(env);
}
