Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
BMRM libshogun example fix (runnable without cmd arguments)
  • Loading branch information
uricamic committed Aug 22, 2012
1 parent 33a0869 commit 3713c0c
Showing 1 changed file with 88 additions and 17 deletions.
105 changes: 88 additions & 17 deletions examples/undocumented/libshogun/so_multiclass_BMRM.cpp
Expand Up @@ -26,6 +26,13 @@

using namespace shogun;

// Dimensionality of the randomly generated features; also the default
// feat_dim when the example is run without command-line arguments.
#define DIMS 2
// NOTE(review): 10e-5 evaluates to 1e-4, not 1e-5 — not referenced in the
// visible code; confirm which tolerance was intended before relying on it.
#define EPSILON 10e-5
// Number of generated samples per class.
#define NUM_SAMPLES 100
// Number of classes in the generated data set.
#define NUM_CLASSES 10

// Path of the svmlight file written by gen_rand_data().
char FNAME[] = "data.svmlight";

/** Reads multiclass training data stored in svmlight format (i.e. label nz_idx_1:value1 nz_idx_2:value2 ... nz_idx_N:valueN)
*
* @param fname path to file with training data
Expand Down Expand Up @@ -70,6 +77,47 @@ void read_data(const char fname[], uint32_t DIM, uint32_t N, SGVector< float64_t
SG_UNREF(stream_features);
}

/** Generates random multiclass training data and stores them in svmlight format
 *
 * For each of the NUM_CLASSES classes, per-dimension Gaussian parameters are
 * drawn (means via CMath::random(-100, 100), standard deviations via
 * CMath::random(1, 5)). Then NUM_SAMPLES points are sampled with
 * CMath::normal_random. Sample s gets label s / NUM_SAMPLES, and its DIMS
 * feature values occupy feats[s*DIMS] .. feats[s*DIMS + DIMS - 1].
 *
 * The same data is also written to the file FNAME, one sample per line
 * ("label 1:x1 2:x2 ..."). If the file cannot be opened, a warning is printed
 * to stderr and only the in-memory data is produced.
 *
 * Preconditions (not checked here): labs holds at least
 * NUM_CLASSES*NUM_SAMPLES entries and feats at least
 * NUM_CLASSES*NUM_SAMPLES*DIMS entries.
 *
 * NOTE(review): labs/feats are taken by value. The caller only sees the
 * results if SGVector/SGMatrix copies share their underlying buffer
 * (shogun's reference-counted semantics). Confirm before changing the types.
 *
 * @param labs returned vector with labels
 * @param feats returned matrix with features
 */
void gen_rand_data(SGVector< float64_t > labs, SGMatrix< float64_t > feats)
{
	float64_t means[DIMS];
	float64_t stds[DIMS];

	// The file is only a side artifact; the in-memory data is what the
	// example trains on. fprintf/fclose on a NULL FILE* is undefined
	// behaviour, so on failure we warn and continue without writing.
	FILE* pfile = fopen(FNAME, "w");
	if (pfile == NULL)
		fprintf(stderr, "Warning: could not open %s for writing; generated data will not be saved\n", FNAME);

	for ( int32_t c = 0 ; c < NUM_CLASSES ; ++c )
	{
		// Per-class Gaussian parameters, one mean/std per feature dimension.
		for ( int32_t j = 0 ; j < DIMS ; ++j )
		{
			means[j] = CMath::random(-100, 100);
			stds[j] = CMath::random( 1, 5);
		}

		for ( int32_t i = 0 ; i < NUM_SAMPLES ; ++i )
		{
			const int32_t idx = c*NUM_SAMPLES + i;
			labs[idx] = c;

			if (pfile)
				fprintf(pfile, "%d", c);

			for ( int32_t j = 0 ; j < DIMS ; ++j )
			{
				const float64_t value = CMath::normal_random(means[j], stds[j]);
				feats[idx*DIMS + j] = value;

				// svmlight feature indices are 1-based.
				if (pfile)
					fprintf(pfile, " %d:%f", j+1, value);
			}

			if (pfile)
				fprintf(pfile, "\n");
		}
	}

	if (pfile)
		fclose(pfile);
}

int main(int argc, char * argv[])
{
// initialization
Expand All @@ -83,32 +131,48 @@ int main(int argc, char * argv[])

init_shogun_with_defaults();

if (argc < 8)
if (argc > 1 && argc < 8)
{
SG_SERROR("Usage: so_multiclass_BMRM <data.in> <feat_dim> <num_feat> <lambda> <icp> <epsilon> <solver> [<cp_models>]\n");
return -1;
}

SG_SPRINT("arg[1] = %s\n", argv[1]);
if (argc > 1)
{
// parse command line arguments for parameters setting

feat_dim=::atoi(argv[2]);
num_feat=::atoi(argv[3]);
lambda=::atof(argv[4]);
icp=::atoi(argv[5]);
eps=::atof(argv[6]);
SG_SPRINT("arg[1] = %s\n", argv[1]);

if (strcmp("BMRM", argv[7])==0)
solver=BMRM;
feat_dim=::atoi(argv[2]);
num_feat=::atoi(argv[3]);
lambda=::atof(argv[4]);
icp=::atoi(argv[5]);
eps=::atof(argv[6]);

if (strcmp("PPBMRM", argv[7])==0)
solver=PPBMRM;
if (strcmp("BMRM", argv[7])==0)
solver=BMRM;

if (strcmp("P3BMRM", argv[7])==0)
solver=P3BMRM;
if (strcmp("PPBMRM", argv[7])==0)
solver=PPBMRM;

if (argc > 8)
if (strcmp("P3BMRM", argv[7])==0)
solver=P3BMRM;

if (argc > 8)
{
cp_models=::atoi(argv[8]);
}
}
else
{
cp_models=::atoi(argv[8]);
// default parameters

feat_dim=DIMS;
num_feat=NUM_SAMPLES*NUM_CLASSES;
lambda=1e3;
icp=1;
eps=0.01;
solver=BMRM;
}

SGVector< float64_t >* labs=
Expand All @@ -117,8 +181,15 @@ int main(int argc, char * argv[])
SGMatrix< float64_t >* feats=
new SGMatrix< float64_t >(feat_dim, num_feat);

// read data
read_data(argv[1], feat_dim, num_feat, labs, feats);
if (argc==1)
{
gen_rand_data(*labs, *feats);
}
else
{
// read data
read_data(argv[1], feat_dim, num_feat, labs, feats);
}

// Create train labels
CMulticlassSOLabels* labels = new CMulticlassSOLabels(*labs);
Expand Down

0 comments on commit 3713c0c

Please sign in to comment.