Skip to content

Commit

Permalink
Merge pull request #711 from vigsterkr/utest
Browse files Browse the repository at this point in the history
Fix generate_gaussians in DataGenerator
  • Loading branch information
karlnapf committed Aug 13, 2012
2 parents be9da8e + 6cbc4ee commit f1dc712
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
13 changes: 9 additions & 4 deletions src/shogun/features/DataGenerator.cpp
Expand Up @@ -84,19 +84,24 @@ SGMatrix<float64_t> CDataGenerator::generate_gaussians(index_t m, index_t n, ind
SGMatrix<float64_t> result =
SGMatrix<float64_t>::get_allocated_matrix(dim, n*m);


float64_t grid_distance = 5.0;
for (index_t i = 0; i < n; ++i)
{
SGVector<float64_t> mean(dim);
SGMatrix<float64_t> cov = SGMatrix<float64_t>::create_identity_matrix(dim, 1.0);

mean.random(0.0, 10.0);

mean.zero();
for (index_t k = 0; k < dim; ++k)
{
mean[k] = (i+1)*grid_distance;
if (k % (i+1) == 0)
mean[k] *= -1;
}
CGaussian* g = new CGaussian(mean, cov, DIAG);
for (index_t j = 0; j < m; ++j)
{
SGVector<float64_t> v = g->sample();
memcpy(result.matrix+((i+j)*result.num_rows), v.vector, dim*sizeof(float64_t));
memcpy((result.matrix+j*result.num_rows+i*m*dim), v.vector, dim*sizeof(float64_t));
}

SG_UNREF(g);
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/base/main_unittest.cc
Expand Up @@ -9,7 +9,7 @@ using ::testing::Test;
int main(int argc, char** argv)
{
::testing::InitGoogleMock(&argc, argv);
init_shogun();
init_shogun_with_defaults();
int ret = RUN_ALL_TESTS();
exit_shogun();

Expand Down
7 changes: 1 addition & 6 deletions tests/unit/classifier/svm/SVMOcas_unittest.cc
@@ -1,18 +1,13 @@
#include <shogun/classifier/svm/SVMOcas.h>
#include <shogun/features/DataGenerator.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/evaluation/ROCEvaluation.h>
#include <gtest/gtest.h>

#include <iostream>

using namespace shogun;

TEST(SVMOcasTest,train)
{
index_t num_samples = 50;
float64_t dist = 1.0, angle = 0.0;
CMath::init_random(5);
SGMatrix<float64_t> data =
CDataGenerator::generate_gaussians(num_samples, 2, 2);
CDenseFeatures<float64_t> features(data);
Expand All @@ -39,7 +34,7 @@ TEST(SVMOcasTest,train)

CLabels* pred = ocas->apply(test_feats);
for (int i = 0; i < num_samples; ++i)
EXPECT_EQ(((CBinaryLabels*)pred)->get_int_label(i), ((CBinaryLabels)labels).get_int_label(i));
EXPECT_EQ(ground_truth->get_int_label(i), ((CBinaryLabels*)pred)->get_int_label(i));

SG_UNREF(ocas);
SG_UNREF(train_feats);
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/multiclass/MulticlassOCAS_unittest.cc
@@ -0,0 +1,50 @@
#include <shogun/multiclass/MulticlassOCAS.h>
#include <shogun/features/DataGenerator.h>
#include <shogun/features/DenseFeatures.h>
#include <gtest/gtest.h>

using namespace shogun;

/* Trains CMulticlassOCAS on num_gauss Gaussian blobs produced by
 * CDataGenerator::generate_gaussians and checks that every held-out sample
 * is predicted with the label of the blob it was drawn from.
 *
 * Assumes generate_gaussians stores the samples of blob g contiguously in
 * columns [g*num_samples, (g+1)*num_samples) -- TODO confirm against
 * DataGenerator.cpp. Even columns form the training set and odd columns the
 * test set. Columns 2k and 2k+1 lie in the same blob while num_samples is
 * even, so one label vector serves both as training labels and as ground
 * truth for the test predictions.
 */
TEST(MulticlassOCASTest,train)
{
	float64_t C = 1.0;
	index_t num_samples = 50, num_gauss = 3, dim = 3;
	CMath::init_random(5);
	SGMatrix<float64_t> data =
		CDataGenerator::generate_gaussians(num_samples, num_gauss, dim);
	CDenseFeatures<float64_t> features(data);

	// Split columns pairwise into train (2k) and test (2k+1). If num_cols is
	// odd, the trailing unpaired column is dropped rather than written past
	// the end of the index/label vectors.
	index_t set_size = data.num_cols/2;
	index_t samples_per_gauss = data.num_cols/num_gauss;
	SGVector<index_t> train_idx(set_size), test_idx(set_size);
	SGVector<float64_t> labels(set_size);
	for (index_t k = 0; k < set_size; ++k)
	{
		train_idx[k] = 2*k;
		test_idx[k] = 2*k+1;
		// Label = index of the blob the test column belongs to; this works
		// for any num_gauss instead of a hard-coded three-way branch.
		labels[k] = (2*k+1) / samples_per_gauss;
	}

	CDenseFeatures<float64_t>* train_feats = (CDenseFeatures<float64_t>*)features.copy_subset(train_idx);
	CDenseFeatures<float64_t>* test_feats = (CDenseFeatures<float64_t>*)features.copy_subset(test_idx);
	// The test releases both feature objects explicitly at the end. It must
	// own a reference to each, otherwise mocas (which is given train_feats and
	// is applied to test_feats) can free them first and the SG_UNREFs below
	// would touch released objects.
	SG_REF(train_feats);
	SG_REF(test_feats);

	// NOTE: ground_truth is not released explicitly; it relies on mocas
	// holding and releasing it, so it must not be used after SG_UNREF(mocas).
	CMulticlassLabels* ground_truth = new CMulticlassLabels(labels);
	CMulticlassOCAS* mocas = new CMulticlassOCAS(C, train_feats, ground_truth);
	mocas->train();

	CLabels* pred = mocas->apply(test_feats);
	for (index_t i = 0; i < set_size; ++i)
		EXPECT_EQ(ground_truth->get_label(i), ((CMulticlassLabels*)pred)->get_label(i));

	SG_UNREF(mocas);
	SG_UNREF(train_feats);
	SG_UNREF(test_feats);
	SG_UNREF(pred);
}

0 comments on commit f1dc712

Please sign in to comment.