Merge pull request #308 from frx/streaming_vw
VW application and option to save predictions to a file
Soeren Sonnenburg committed Aug 22, 2011
2 parents 728d1de + 33bac51 commit a715fc7
Showing 5 changed files with 325 additions and 24 deletions.
216 changes: 216 additions & 0 deletions applications/vw/vw.cpp
@@ -0,0 +1,216 @@
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <getopt.h>
#include <string.h>

#include <shogun/io/StreamingVwFile.h>
#include <shogun/io/StreamingVwCacheFile.h>
#include <shogun/features/StreamingVwFeatures.h>
#include <shogun/classifier/vw/VowpalWabbit.h>

using namespace shogun;

class Args_t
{
public:
    Args_t()
    {
        adaptive = false;
        exact_adaptive_norm = false;
        use_cache_input = false;
        create_cache = false;
        input_file_name = NULL;
        regressor_input_file_name = NULL;
        regressor_output_file_name = NULL;
        predictions_output_file_name = NULL;
    }

public:
    bool adaptive;
    bool exact_adaptive_norm;
    bool use_cache_input;
    bool create_cache;
    char* input_file_name;
    char* regressor_input_file_name;
    char* regressor_output_file_name;
    char* predictions_output_file_name;
} Args;

static const struct option longOpts[] = {
    { "data", required_argument, NULL, 'd' },
    { "adaptive", no_argument, NULL, 'a' },
    { "exact_adaptive_norm", no_argument, NULL, 'e' },
    { "use_cache", no_argument, NULL, 'c' },
    { "create_cache", no_argument, NULL, 'C' },
    { "predictions", required_argument, NULL, 'p' },
    { "help", no_argument, NULL, 'h' },
    { NULL, no_argument, NULL, 0 } // getopt_long requires an all-zeros terminator
};

static const char *optString = "d:aecCp:h";

void display_usage()
{
    printf("vw - Run Vowpal Wabbit.\n\n");
    printf("Supported arguments are:\n");
    printf("-d <file>\t-\tName of input file.\n");
    printf("-a\t-\tEnable adaptive learning.\n");
    printf("-e\t-\tUse exact norm during adaptive learning.\n");
    printf("-c\t-\tTry to use a cache file for input.\n");
    printf("-C\t-\tCreate a cache file from data.\n");
    printf("-p <file>\t-\tFile to write predictions to.\n");
    printf("-h\t-\tDisplay this information.\n");

    exit(1);
}

void parse_options(int argc, char** argv)
{
    int opt = 0;
    int longIndex;

    opt = getopt_long(argc, argv, optString, longOpts, &longIndex);
    while (opt != -1)
    {
        switch (opt)
        {
        case 'd':
            Args.input_file_name = optarg;
            printf("Input file is: %s.\n", Args.input_file_name);
            break;

        case 'a':
            Args.adaptive = true;
            printf("Using adaptive learning.\n");
            break;

        case 'e':
            Args.adaptive = true;
            Args.exact_adaptive_norm = true;
            printf("Using exact adaptive norm.\n");
            break;

        case 'c':
            Args.use_cache_input = true;
            printf("Treating input as a cache file.\n");
            break;

        case 'C':
            Args.create_cache = true;
            printf("Will create a cache file from the input.\n");
            break;

        case 'p':
            Args.predictions_output_file_name = optarg;
            printf("Predictions will be saved to: %s.\n", Args.predictions_output_file_name);
            break;

        case 'h':
            display_usage();
            break;

        default:
            break;
        }

        opt = getopt_long(argc, argv, optString, longOpts, &longIndex);
    }

    if (!Args.input_file_name)
    {
        printf("Data file must be specified! (use -d <file>)\n");
        exit(1);
    }

    if (Args.create_cache && Args.use_cache_input)
    {
        printf("Creating cache not supported while reading from cache input!\n");
        exit(1);
    }
}

void display_stats(CVowpalWabbit* vw)
{
    CVwEnvironment* env = vw->get_env();
    SG_REF(env);

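    // Under squared loss the best constant predictor is the weighted mean
    // label, best_constant = (weighted label sum) / (weighted labeled
    // examples); for binary 0/1 labels its loss reduces to
    // c*(1-c)^2 + (1-c)*c^2 = c*(1-c).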
    double weighted_labeled_examples = env->weighted_examples - env->weighted_unlabeled_examples;
    double best_constant = (env->weighted_labels - env->initial_t) / weighted_labeled_examples;
    double constant_loss = (best_constant*(1.0 - best_constant)*(1.0 - best_constant) + (1.0 - best_constant)*best_constant*best_constant);

    printf("\nFinished run.\n");
    printf("Number of examples = %lld.\n", env->example_number);
    printf("Weighted example sum = %f.\n", env->weighted_examples);
    printf("Weighted label sum = %f.\n", env->weighted_labels);
    printf("Average loss = %f.\n", env->sum_loss / env->weighted_examples);
    printf("Best constant = %f.\n", best_constant);

    if (env->min_label == 0. && env->max_label == 1. && best_constant < 1. && best_constant > 0.)
        printf("Best constant's loss = %f.\n", constant_loss);

    printf("Total feature number = %ld.\n", (long int) env->total_features);

    SG_UNREF(env);
}

int main(int argc, char** argv)
{
    parse_options(argc, argv);

    init_shogun_with_defaults();

    CStreamingVwFile* vw_file = NULL;
    CStreamingVwCacheFile* vw_cache_file = NULL;
    CStreamingVwFeatures* features = NULL;

    if (Args.use_cache_input)
    {
        vw_cache_file = new CStreamingVwCacheFile(Args.input_file_name);
        SG_REF(vw_cache_file);
        features = new CStreamingVwFeatures(vw_cache_file, true, 1024);
        SG_REF(features);
    }
    else
    {
        vw_file = new CStreamingVwFile(Args.input_file_name);
        SG_REF(vw_file);
        features = new CStreamingVwFeatures(vw_file, true, 1024);
        SG_REF(features);
    }

    CVowpalWabbit* vw = new CVowpalWabbit(features);
    SG_REF(vw); // take the reference once; a second SG_REF here would leak

    if (Args.adaptive)
        vw->set_adaptive(true);

    if (Args.exact_adaptive_norm)
        vw->set_exact_adaptive_norm(true);

    if (Args.create_cache)
        vw_file->set_write_to_cache(true);

    if (Args.predictions_output_file_name)
        vw->set_prediction_out(Args.predictions_output_file_name);

    vw->train();
    display_stats(vw);

    if (Args.use_cache_input)
    {
        SG_UNREF(vw_cache_file);
    }
    else
    {
        SG_UNREF(vw_file);
    }

    SG_UNREF(features);
    SG_UNREF(vw);

    exit_shogun();

    return 0;
}
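A hypothetical invocation of the new application, assuming a VW-format input file named train.vw (the file names are placeholders):

./vw -d train.vw -a -C -p predictions.txt

This trains with adaptive updates, writes a cache file from the parsed input, and saves one prediction per line to predictions.txt; a later run could swap -C for -c to stream from the cache instead.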
96 changes: 75 additions & 21 deletions src/shogun/classifier/vw/VowpalWabbit.cpp
@@ -90,6 +90,14 @@ void CVowpalWabbit::set_regressor_out(char* file_name, bool is_text)
    reg_dump_text = is_text;
}

void CVowpalWabbit::set_prediction_out(char* file_name)
{
    save_predictions = true;
    prediction_fd = open(file_name, O_CREAT|O_TRUNC|O_WRONLY, 0666);
    if (prediction_fd < 0)
        SG_SERROR("Unable to open prediction file %s for writing!\n", file_name);
}

void CVowpalWabbit::add_quadratic_pair(char* pair)
{
    env->pairs.push_back(pair);
@@ -108,7 +116,6 @@ bool CVowpalWabbit::train_machine(CFeatures* feat)

    VwExample* example = NULL;
    size_t current_pass = 0;
    float32_t dump_interval = exp(1.);

    const char* header_fmt = "%-10s %-10s %8s %8s %10s %8s %8s\n";

@@ -139,14 +146,7 @@ bool CVowpalWabbit::train_machine(CFeatures* feat)
        learner->train(example, example->eta_round);
        example->eta_round = 0.;

        if (!quiet)
        {
            if (env->weighted_examples + example->ld->weight > dump_interval)
            {
                print_update(example);
                dump_interval *= 2;
            }
        }
        output_example(example);

        features->release_example();
    }
@@ -216,8 +216,11 @@ void CVowpalWabbit::init(CStreamingVwFeatures* feat)
    SG_REF(reg);

    quiet = false;
    dump_interval = exp(1.);
    reg_name = NULL;
    reg_dump_text = true;
    save_predictions = false;
    prediction_fd = -1;

    w = reg->weight_vectors[0];
    w_dim = 1 << env->num_bits;
@@ -297,6 +300,27 @@ float32_t CVowpalWabbit::finalize_prediction(float32_t ret)
    return ret;
}

void CVowpalWabbit::output_example(VwExample* &example)
{
    if (!quiet)
    {
        if (env->weighted_examples + example->ld->weight > dump_interval)
        {
            print_update(example);
            dump_interval *= 2;
        }
    }

    if (save_predictions)
    {
        float32_t wt = 0.;
        if (reg->weight_vectors)
            wt = reg->weight_vectors[0][0];

        output_prediction(prediction_fd, example->final_prediction, wt * example->global_weight, example->tag);
    }
}

void CVowpalWabbit::print_update(VwExample* &ex)
{
    SG_SPRINT("%-10.6f %-10.6f %8lld %8.1f %8.4f %8.4f %8lu\n",
@@ -309,6 +333,37 @@ void CVowpalWabbit::print_update(VwExample* &ex)
        (long unsigned int)ex->num_features);
}


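// Emits one prediction per line in the form "<prediction>[ <tag>]\n",
// mirroring Vowpal Wabbit's predictions-file format.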
void CVowpalWabbit::output_prediction(int32_t f, float32_t res, float32_t weight, v_array<char> tag)
{
    if (f >= 0)
    {
        char temp[30];
        int32_t num = sprintf(temp, "%f", res);
        ssize_t t;
        t = write(f, temp, num);
        if (t != num)
            SG_SERROR("Write error!\n");

        if (tag.begin != tag.end)
        {
            temp[0] = ' ';
            t = write(f, temp, 1);
            if (t != 1)
                SG_SERROR("Write error!\n");

            t = write(f, tag.begin, sizeof(char)*tag.index());
            if (t != (ssize_t) (sizeof(char)*tag.index()))
                SG_SERROR("Write error!\n");
        }

        temp[0] = '\n';
        t = write(f, temp, 1);
        if (t != 1)
            SG_SERROR("Write error!\n");
    }
}

float32_t CVowpalWabbit::compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x)
{
    // We must traverse the features in _precisely_ the same order as during training.
@@ -349,16 +404,15 @@ float32_t CVowpalWabbit::compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x
float32_t CVowpalWabbit::compute_exact_norm_quad(float32_t* weights, VwFeature& page_feature, v_array<VwFeature> &offer_features,
                                                 size_t mask, float32_t g, float32_t& sum_abs_x)
{
    size_t halfhash = quadratic_constant * page_feature.weight_index;
    float32_t xGx = 0.;
    float32_t update2 = g * page_feature.x * page_feature.x;
    for (VwFeature* elem = offer_features.begin; elem != offer_features.end; elem++)
    {
        float32_t* w_vec = &weights[(halfhash + elem->weight_index) & mask];
        float32_t t = elem->x * CMath::invsqrt(w_vec[1] + update2 * elem->x * elem->x);
        xGx += t * elem->x;
        sum_abs_x += fabsf(elem->x);
    }
    return xGx;
}
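A minimal sketch of driving the new prediction-output API from library code rather than through the vw application, assuming a VW-format input file train.vw (file names are placeholders, not part of this commit):

#include <shogun/io/StreamingVwFile.h>
#include <shogun/features/StreamingVwFeatures.h>
#include <shogun/classifier/vw/VowpalWabbit.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // Stream examples straight from a VW-format text file.
    CStreamingVwFile* file = new CStreamingVwFile((char*) "train.vw");
    SG_REF(file);
    CStreamingVwFeatures* features = new CStreamingVwFeatures(file, true, 1024);
    SG_REF(features);

    CVowpalWabbit* vw = new CVowpalWabbit(features);
    SG_REF(vw);
    vw->set_prediction_out((char*) "predictions.txt"); // new in this change
    vw->train();

    SG_UNREF(vw);
    SG_UNREF(features);
    SG_UNREF(file);
    exit_shogun();
    return 0;
}

Note that set_prediction_out() opens the file with O_CREAT|O_TRUNC, so an existing predictions file is overwritten.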
