Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
initial sketch for new statistical testing framework
  • Loading branch information
karlnapf committed May 5, 2012
1 parent 04e2696 commit 49cccd0
Show file tree
Hide file tree
Showing 7 changed files with 444 additions and 0 deletions.
158 changes: 158 additions & 0 deletions src/shogun/statistics/LinearTimeMMD.cpp
@@ -0,0 +1,158 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#include <shogun/statistics/LinearTimeMMD.h>
#include <shogun/features/Features.h>

using namespace shogun;

/* Default constructor: no kernel and no features are set, all other members
 * get their default values via init() */
CLinearTimeMMD::CLinearTimeMMD()
	: CTwoSampleTestStatistic()
{
	init();
}

/* Constructor.
 *
 * @param kernel kernel used to compute the statistic (is SG_REF'ed)
 * @param p_and_q samples of p followed by samples of q in one features object
 * @param q_start index of the first sample of q
 *
 * Currently both samples must have the same size, i.e. the features must
 * contain exactly 2*q_start vectors. */
CLinearTimeMMD::CLinearTimeMMD(CKernel* kernel, CFeatures* p_and_q,
		index_t q_start) :CTwoSampleTestStatistic(p_and_q, q_start)
{
	init();

	if (!p_and_q)
	{
		SG_ERROR("CLinearTimeMMD: No features given\n");
	}

	/* comparing against num_vectors/2 would wrongly accept an odd number of
	 * vectors, where q has one sample more than p */
	if (2*q_start!=p_and_q->get_num_vectors())
	{
		SG_ERROR("CLinearTimeMMD: Only features with equal number of vectors "
				"are currently possible\n");
	}

	m_kernel=kernel;
	SG_REF(kernel);
}

/* Destructor: drops the reference on the kernel that the constructor took */
CLinearTimeMMD::~CLinearTimeMMD()
{
	SG_UNREF(m_kernel);
}

/* Sets all members to their default values: no kernel, threshold computed
 * by bootstrapping with 100 iterations */
void CLinearTimeMMD::init()
{
	/* TODO register parameters*/

	m_bootstrap_iterations=100;
	m_threshold_method=MMD_BOOTSTRAP;
	m_kernel=NULL;
}

/* Computes the linear time MMD estimate of Gretton et al., "A Kernel
 * Two-Sample Test", JMLR 2012.
 *
 * Samples x of p occupy indices [0, m), samples y of q occupy [m, 2m), where
 * m=m_q_start. Each sample is split into two halves of size m_2=m/2, and the
 * statistic is the mean over i<m_2 of
 *   h_i = k(x_i,x_i') + k(y_i,y_i') - k(x_i,y_i') - k(x_i',y_i)
 * with x_i'=x_{m_2+i} and y_i'=y_{m_2+i}. If m is odd, the last sample of
 * each distribution is not used.
 *
 * @return linear time MMD estimate */
float64_t CLinearTimeMMD::compute_statistic()
{
	/* TODO maybe add parallelized kernel matrix trace method to CKernel? */
	/* TODO features with a different number of vectors should be allowed */

	if (!m_kernel || !m_p_and_q)
	{
		SG_ERROR("%s::compute_statistic(): Kernel and features have to be "
				"set!\n", get_name());
	}

	/* m is number of samples from each distribution, m_2 is half of it
	 * using names from JLMR paper (see class documentation) */
	index_t m=m_q_start;
	index_t m_2=m/2;

	/* with less than two samples per distribution there is no pair to
	 * average over (would divide by zero below) */
	if (m_2<1)
	{
		SG_ERROR("%s::compute_statistic(): At least two samples of each "
				"distribution are needed!\n", get_name());
	}

	m_kernel->init(m_p_and_q, m_p_and_q);

	/* sums of the four kernel terms of h over all pairs; accumulating
	 * directly avoids allocating temporary vectors */
	float64_t pp=0;
	float64_t qq=0;
	float64_t pq=0;
	float64_t qp=0;

	for (index_t i=0; i<m_2; ++i)
	{
		/* k(x_i, x_i') */
		pp+=m_kernel->kernel(i, m_2+i);
		/* k(y_i, y_i') */
		qq+=m_kernel->kernel(m+i, m+m_2+i);
		/* k(x_i, y_i') */
		pq+=m_kernel->kernel(i, m+m_2+i);
		/* k(x_i', y_i) */
		qp+=m_kernel->kernel(m_2+i, m+i);
	}

	/* cross terms are subtracted: MMD^2 = E k(x,x') + E k(y,y') - 2 E k(x,y) */
	return (pp+qq-pq-qp)/m_2;
}

/* Computes the threshold for the test statistic with the method selected in
 * m_threshold_method.
 *
 * @param confidence forwarded to the selected threshold method
 * @return threshold, 0 if the method is unknown and SG_ERROR returns */
float64_t CLinearTimeMMD::compute_threshold(float64_t confidence)
{
	float64_t result=0;

	switch (m_threshold_method)
	{
	case MMD_BOOTSTRAP:
		result=bootstrap_threshold(confidence);
		break;

	default:
		/* the "%s" needs a matching argument; without one the formatting
		 * reads an argument that was never passed (undefined behavior) */
		SG_ERROR("%s::compute_threshold(): Unknown method to compute"
				" threshold!\n", get_name());
		break;
	}

	return result;
}

/* Approximates the threshold by bootstrapping: the samples of p and q are
 * merged and randomly permuted m_bootstrap_iterations times, the statistic is
 * computed for every permutation, and the value at relative position
 * (1-confidence) of the sorted results is returned.
 *
 * NOTE(review): this treats confidence like a significance level (e.g. 0.05
 * selects the upper 95% quantile) — confirm this is the intended semantics.
 *
 * @param confidence determines which quantile of the bootstrapped statistics
 * is returned
 * @return bootstrapped threshold */
float64_t CLinearTimeMMD::bootstrap_threshold(float64_t confidence)
{
	if (!m_p_and_q)
	{
		SG_ERROR("%s::bootstrap_threshold(): No features set!\n", get_name());
	}

	if (m_bootstrap_iterations<1)
	{
		SG_ERROR("%s::bootstrap_threshold(): Number of bootstrap iterations "
				"has to be positive!\n", get_name());
	}

	/* statistic of every bootstrap iteration */
	SGVector<float64_t> results(m_bootstrap_iterations);

	/* memory for index permutations, allocated once since allocating it in
	 * every iteration would slow down the loop */
	SGVector<index_t> ind_permutation(m_p_and_q->get_num_vectors());
	ind_permutation.range_fill();

	for (index_t i=0; i<m_bootstrap_iterations; ++i)
	{
		/* idea: merge features of p and q, shuffle, and compute statistic.
		 * This is done using subsets here */

		/* create index permutation and add as subset. This will mix samples
		 * from p and q */
		CMath::permute_vector(ind_permutation);
		m_p_and_q->add_subset(ind_permutation);

		/* compute statistic for this permutation of mixed samples */
		results.vector[i]=compute_statistic();

		/* clean up */
		m_p_and_q->remove_subset();
	}

	/* clean up */
	ind_permutation.destroy_vector();

	/* sort elements and return the one that corresponds to the confidence
	 * niveau; the index is clamped since round((1-confidence)*vlen) equals
	 * vlen for confidence close to 0 (or is out of range for confidence
	 * outside [0,1]) */
	CMath::qsort(results.vector, results.vlen);
	index_t result_idx=CMath::round((1-confidence)*results.vlen);
	if (result_idx<0)
		result_idx=0;
	if (result_idx>=results.vlen)
		result_idx=results.vlen-1;

	float64_t result=results.vector[result_idx];

	/* clean up and return */
	results.destroy_vector();
	return result;
}
58 changes: 58 additions & 0 deletions src/shogun/statistics/LinearTimeMMD.h
@@ -0,0 +1,58 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#ifndef __LINEARTIMEMMD_H_
#define __LINEARTIMEMMD_H_

#include <shogun/statistics/TwoSampleTestStatistic.h>
#include <shogun/kernel/Kernel.h>

namespace shogun
{

class CFeatures;

/** methods to compute the threshold of an MMD based test statistic */
enum EMMDThreshold
{
	/** threshold from statistics computed on randomly permuted samples */
	MMD_BOOTSTRAP=0
};

/** @brief Linear time Maximum Mean Discrepancy (MMD) statistic for a kernel
 * two-sample test, see Gretton et al., "A Kernel Two-Sample Test", JMLR 2012.
 *
 * The samples of p and q are stored in a single CFeatures object, where the
 * samples of q start at index q_start. Currently both samples must have the
 * same size.
 */
class CLinearTimeMMD : public CTwoSampleTestStatistic
{
public:
	/** default constructor, no kernel and no features */
	CLinearTimeMMD();

	/** constructor
	 *
	 * @param kernel kernel to compute the statistic with (is SG_REF'ed)
	 * @param p_and_q samples of p followed by samples of q
	 * @param q_start index of the first sample of q
	 */
	CLinearTimeMMD(CKernel* kernel, CFeatures* p_and_q, index_t q_start);

	/** destructor, SG_UNREFs the kernel */
	virtual ~CLinearTimeMMD();

	/** @return linear time MMD estimate for the stored samples */
	virtual float64_t compute_statistic();

	/** computes the test threshold with the method in m_threshold_method
	 *
	 * @param confidence forwarded to the threshold method
	 * @return threshold
	 */
	virtual float64_t compute_threshold(float64_t confidence);

	/** @return name of the class */
	virtual const char* get_name() const
	{
		return "LinearTimeMMD";
	}

protected:
	/** threshold from statistics of randomly permuted samples
	 *
	 * @param confidence determines which quantile is returned
	 * @return bootstrapped threshold
	 */
	float64_t bootstrap_threshold(float64_t confidence);

private:
	/** sets all members to default values */
	void init();

protected:
	/** kernel that is used to compute the statistic */
	CKernel* m_kernel;

	/** method that is used to compute the threshold */
	EMMDThreshold m_threshold_method;

	/** number of iterations for bootstrapping */
	index_t m_bootstrap_iterations;
};

}

#endif /* __LINEARTIMEMMD_H_ */
57 changes: 57 additions & 0 deletions src/shogun/statistics/StatisticalTest.cpp
@@ -0,0 +1,57 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#include <shogun/statistics/StatisticalTest.h>
#include <shogun/statistics/TestStatistic.h>

using namespace shogun;

/* Default constructor: no test statistic, confidence 0 (see init()) */
CStatisticalTest::CStatisticalTest()
	: CSGObject()
{
	init();
}

/* Constructor.
 *
 * @param statistic object that computes statistic and threshold (is SG_REF'ed)
 * @param confidence value passed to the threshold computation */
CStatisticalTest::CStatisticalTest(CTestStatistic* statistic,
		float64_t confidence)
	: CSGObject()
{
	init();

	m_confidence=confidence;

	SG_REF(statistic);
	m_statistic=statistic;
}

/* Destructor: drops the reference on the test statistic taken by the
 * constructor */
CStatisticalTest::~CStatisticalTest()
{
	SG_UNREF(m_statistic);
}

/* Performs the test: computes the test statistic and the threshold (using
 * m_confidence) and compares them.
 *
 * @return true if the NULL-hypothesis is rejected, i.e. the statistic is
 * greater than the threshold */
bool CStatisticalTest::perform_test()
{
	if (!m_statistic)
	{
		SG_ERROR("CStatisticalTest::perform_test(): No object to compute test "
				"statistic!\n");
	}

	float64_t statistic=m_statistic->compute_statistic();
	float64_t threshold=m_statistic->compute_threshold(m_confidence);

	/* reject null-hypothesis if statistic is greater than threshold
	 * (the comparison was previously inverted) */
	return statistic>threshold;
}

/* Sets members to default values: confidence 0 and no test statistic */
void CStatisticalTest::init()
{
	/* TODO register parameters*/

	m_confidence=0;
	m_statistic=NULL;
}
47 changes: 47 additions & 0 deletions src/shogun/statistics/StatisticalTest.h
@@ -0,0 +1,47 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#ifndef __STATISTICALTEST_H_
#define __STATISTICALTEST_H_

#include <shogun/base/SGObject.h>

namespace shogun
{

class CTestStatistic;

/** @brief Statistical hypothesis test based on a CTestStatistic.
 *
 * The stored CTestStatistic computes both the test statistic and the
 * threshold it is compared against.
 */
class CStatisticalTest : public CSGObject
{
public:
	/** default constructor, no test statistic set */
	CStatisticalTest();

	/** constructor
	 *
	 * @param statistic object that computes statistic and threshold
	 * @param confidence value passed to the threshold computation
	 */
	CStatisticalTest(CTestStatistic* statistic, float64_t confidence);

	/** destructor */
	virtual ~CStatisticalTest();

	/** performs the test by comparing test statistic and threshold
	 *
	 * @return true if the NULL-hypothesis is rejected */
	virtual bool perform_test();

	/** @return name of the class */
	virtual const char* get_name() const
	{
		return "StatisticalTest";
	}

private:
	/** sets members to default values */
	void init();

protected:
	/** confidence niveau of the test, passed to
	 * CTestStatistic::compute_threshold().
	 * NOTE(review): described as "test correct with (1-m_confidence)", i.e.
	 * it seems to act like a significance level — confirm */
	float64_t m_confidence;

	/** object that computes test statistic and threshold */
	CTestStatistic* m_statistic;
};

}

#endif /* __STATISTICALTEST_H_ */
42 changes: 42 additions & 0 deletions src/shogun/statistics/TestStatistic.h
@@ -0,0 +1,42 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#ifndef __TESTSTATISTIC_H_
#define __TESTSTATISTIC_H_

#include <shogun/base/SGObject.h>

namespace shogun
{

/** @brief Base class for test statistics used by CStatisticalTest.
 *
 * Subclasses override compute_statistic() and compute_threshold(); the
 * default implementations raise an error.
 */
class CTestStatistic : public CSGObject
{
public:
	/** default constructor */
	CTestStatistic() {}

	/** destructor */
	virtual ~CTestStatistic() {}

	/** @return value of the test statistic (not implemented here) */
	virtual float64_t compute_statistic()
	{
		/* "%s" needs the class name as argument, otherwise the formatting
		 * reads an argument that was never passed (undefined behavior) */
		SG_ERROR("%s::compute_statistic() is not implemented!\n", get_name());
		return 0.0;
	}

	/** @param confidence confidence niveau of the test
	 * @return threshold for the test statistic (not implemented here) */
	virtual float64_t compute_threshold(float64_t confidence)
	{
		SG_ERROR("%s::compute_threshold() is not implemented!\n", get_name());
		return 0.0;
	}

	/** @return name of the class */
	virtual const char* get_name() const=0;
};

}

#endif /* __TESTSTATISTIC_H_ */

0 comments on commit 49cccd0

Please sign in to comment.