Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Introduced CQuadraticTimeMMD and another level in the class hierarchy:
CKernelTwoSampleTestStatistic.
Minor fixes
  • Loading branch information
karlnapf committed May 23, 2012
1 parent ba81d43 commit 9725b6d
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 41 deletions.
41 changes: 41 additions & 0 deletions src/shogun/statistics/KernelTwoSampleTestStatistic.cpp
@@ -0,0 +1,41 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#include <shogun/statistics/KernelTwoSampleTestStatistic.h>
#include <shogun/features/Features.h>
#include <shogun/kernel/Kernel.h>

using namespace shogun;

/** Default constructor. Default-constructs the CTwoSampleTestStatistic base
 * and then puts the members of this class into their default state via
 * init(). */
CKernelTwoSampleTestStatistic::CKernelTwoSampleTestStatistic()
	: CTwoSampleTestStatistic()
{
	init();
}

/** Constructor taking a kernel and the sample data.
 *
 * @param kernel kernel to store; a reference to it is taken with SG_REF
 * @param p_and_q features, forwarded unchanged to the
 * CTwoSampleTestStatistic constructor
 * @param q_start index, forwarded unchanged to the CTwoSampleTestStatistic
 * constructor
 */
CKernelTwoSampleTestStatistic::CKernelTwoSampleTestStatistic(CKernel* kernel,
		CFeatures* p_and_q, index_t q_start)
	: CTwoSampleTestStatistic(p_and_q, q_start)
{
	/* start from default member values */
	init();

	/* acquire the reference, then remember the kernel */
	SG_REF(kernel);
	m_kernel=kernel;
}

/** Destructor. Drops the reference that this object holds on its kernel. */
CKernelTwoSampleTestStatistic::~CKernelTwoSampleTestStatistic()
{
	SG_UNREF(m_kernel);
}

/** Puts the members declared by this class into their default state. Called
 * by both constructors before any member is assigned. */
void CKernelTwoSampleTestStatistic::init()
{
	/* TODO register params */

	/* no kernel is held until the kernel-taking constructor stores one */
	m_kernel=NULL;
}
41 changes: 41 additions & 0 deletions src/shogun/statistics/KernelTwoSampleTestStatistic.h
@@ -0,0 +1,41 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#ifndef __KERNELTWOSAMPLETESTSTATISTIC_H_
#define __KERNELTWOSAMPLETESTSTATISTIC_H_

#include <shogun/statistics/TwoSampleTestStatistic.h>

namespace shogun
{

class CFeatures;
class CKernel;

/** Intermediate base class for two-sample test statistics that work with a
 * kernel. It extends CTwoSampleTestStatistic by a kernel member that
 * subclasses can access. The class is abstract since get_name() is pure
 * virtual. */
class CKernelTwoSampleTestStatistic : public CTwoSampleTestStatistic
{
	public:
		/** default constructor, no kernel is set */
		CKernelTwoSampleTestStatistic();

		/** constructor
		 *
		 * @param kernel kernel to use
		 * @param p_and_q features, handed on to CTwoSampleTestStatistic
		 * @param q_start index, handed on to CTwoSampleTestStatistic
		 */
		CKernelTwoSampleTestStatistic(CKernel* kernel, CFeatures* p_and_q,
				index_t q_start);

		/** destructor */
		virtual ~CKernelTwoSampleTestStatistic();

		/** @return name of the object; has to be provided by subclasses */
		virtual const char* get_name() const=0;

	private:
		/** sets the members of this class to their default values */
		void init();

	protected:
		/** kernel used by the statistic, NULL if none was set */
		CKernel* m_kernel;
};

}

#endif /* __KERNELTWOSAMPLETESTSTATISTIC_H_ */
24 changes: 7 additions & 17 deletions src/shogun/statistics/LinearTimeMMD.cpp
Expand Up @@ -12,13 +12,14 @@

using namespace shogun;

CLinearTimeMMD::CLinearTimeMMD() : CTwoSampleTestStatistic()
CLinearTimeMMD::CLinearTimeMMD() : CKernelTwoSampleTestStatistic()
{
init();
}

CLinearTimeMMD::CLinearTimeMMD(CKernel* kernel, CFeatures* p_and_q,
index_t q_start) :CTwoSampleTestStatistic(p_and_q, q_start)
index_t q_start) :
CKernelTwoSampleTestStatistic(kernel, p_and_q, q_start)
{
init();

Expand All @@ -27,22 +28,16 @@ CLinearTimeMMD::CLinearTimeMMD(CKernel* kernel, CFeatures* p_and_q,
SG_ERROR("CLinearTimeMMD: Only features with equal number of vectors "
"are currently possible\n");
}

m_kernel=kernel;
SG_REF(kernel);
}

CLinearTimeMMD::~CLinearTimeMMD()
{
SG_UNREF(m_kernel);

}

void CLinearTimeMMD::init()
{
/* TODO register parameters*/

m_kernel=NULL;
m_threshold_method=MMD_NONE;
}

float64_t CLinearTimeMMD::compute_statistic()
Expand All @@ -52,6 +47,7 @@ float64_t CLinearTimeMMD::compute_statistic()

/* m is number of samples from each distribution, m_2 is half of it
* using names from JMLR paper (see class documentation) */
// TODO here is something wrong! (possibly)
index_t m=m_q_start;
index_t m_2=m/2;

Expand Down Expand Up @@ -99,16 +95,10 @@ float64_t CLinearTimeMMD::compute_p_value(float64_t statistic)

switch (m_threshold_method)
{
case MMD_NONE:
/* use super-class method for bootstrapping */
result=CTwoSampleTestStatistic::compute_p_value(statistic);
break;

/* TODO implement new null distribution approximations here */
default:
SG_ERROR("%s::compute_threshold(): Unknown method to compute"
" threshold!\n");
result=CKernelTwoSampleTestStatistic::compute_p_value(statistic);
break;

}

return result;
Expand Down
18 changes: 2 additions & 16 deletions src/shogun/statistics/LinearTimeMMD.h
Expand Up @@ -10,23 +10,16 @@
#ifndef __LINEARTIMEMMD_H_
#define __LINEARTIMEMMD_H_

#include <shogun/statistics/TwoSampleTestStatistic.h>
#include <shogun/statistics/KernelTwoSampleTestStatistic.h>
#include <shogun/kernel/Kernel.h>

namespace shogun
{

class CFeatures;

/** enum for different method to compute p-value of test, MMD_NONE will result
* in calling CTwoSampleTestStatistic::compute_p_value, where bootstrapping
* is implemented */
enum EMMDThreshold
{
MMD_NONE
};

class CLinearTimeMMD : public CTwoSampleTestStatistic
class CLinearTimeMMD : public CKernelTwoSampleTestStatistic
{
public:
CLinearTimeMMD();
Expand All @@ -44,13 +37,6 @@ class CLinearTimeMMD : public CTwoSampleTestStatistic

private:
void init();

protected:
CKernel* m_kernel;

EMMDThreshold m_threshold_method;
index_t m_bootstrap_iterations;

};

}
Expand Down
108 changes: 108 additions & 0 deletions src/shogun/statistics/QuadraticTimeMMD.cpp
@@ -0,0 +1,108 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#include <shogun/statistics/QuadraticTimeMMD.h>
#include <shogun/features/Features.h>

using namespace shogun;

/** Default constructor. Default-constructs the CKernelTwoSampleTestStatistic
 * base and then calls init(). */
CQuadraticTimeMMD::CQuadraticTimeMMD()
	: CKernelTwoSampleTestStatistic()
{
	init();
}

/** Constructor taking a kernel and the sample data.
 *
 * All arguments are forwarded to the CKernelTwoSampleTestStatistic
 * constructor. Afterwards an error is raised via SG_ERROR unless q_start
 * equals half the number of vectors in p_and_q.
 *
 * @param kernel kernel to use
 * @param p_and_q features; its number of vectors is checked against q_start
 * @param q_start index that has to equal p_and_q->get_num_vectors()/2
 */
CQuadraticTimeMMD::CQuadraticTimeMMD(CKernel* kernel, CFeatures* p_and_q,
		index_t q_start)
	: CKernelTwoSampleTestStatistic(kernel, p_and_q, q_start)
{
	init();

	/* NOTE(review): the halving below is an integer division, so an odd
	 * number of vectors also passes with q_start==floor(num/2) - confirm
	 * whether that is intended */
	const index_t half_of_all=p_and_q->get_num_vectors()/2;
	if (half_of_all!=q_start)
	{
		SG_ERROR("CQuadraticTimeMMD: Only features with equal number of vectors "
				"are currently possible\n");
	}
}

/** Destructor. This class declares no members of its own that would need to
 * be released here. */
CQuadraticTimeMMD::~CQuadraticTimeMMD()
{
}

/** Initialisation hook called by both constructors. Currently empty. */
void CQuadraticTimeMMD::init()
{
	/* TODO register parameters*/
}

/** Computes the unbiased quadratic-time estimate of the squared MMD.
 *
 * The vectors with indices [0,m_q_start) of m_p_and_q are treated as the
 * sample from p (m vectors), the remaining ones as the sample from q
 * (n vectors). The returned value is
 *
 *   1/(m(m-1)) * sum_{i!=j} k(x_i,x_j)
 * + 1/(n(n-1)) * sum_{i!=j} k(y_i,y_j)
 * - 2/(mn)     * sum_{i,j}  k(x_i,y_j)
 *
 * (three terms as in the JMLR paper, see class documentation).
 *
 * The kernel is (re-)initialised on m_p_and_q for both sides. Both samples
 * need at least two vectors, otherwise a normaliser is zero.
 *
 * @return value of the quadratic-time MMD statistic
 */
float64_t CQuadraticTimeMMD::compute_statistic()
{
	/* m is the number of samples from p, n the number of samples from q.
	 * Note that n is NOT the total number of vectors: the q sample only
	 * occupies the indices [m_q_start, num_total) */
	const index_t m=m_q_start;
	const index_t num_total=m_p_and_q->get_num_vectors();
	const index_t n=num_total-m;

	/* init kernel with features */
	m_kernel->init(m_p_and_q, m_p_and_q);

	/* first term: kernel within the p sample, diagonal excluded */
	float64_t first=0;
	for (index_t i=0; i<m; ++i)
	{
		for (index_t j=0; j<m; ++j)
		{
			/* ensure i!=j */
			if (i==j)
				continue;

			first+=m_kernel->kernel(i,j);
		}
	}
	/* normalisers are formed in floating point so that the product cannot
	 * overflow index_t for large samples */
	first/=static_cast<float64_t>(m)*(m-1);

	/* second term: kernel within the q sample, diagonal excluded */
	float64_t second=0;
	for (index_t i=m_q_start; i<num_total; ++i)
	{
		for (index_t j=m_q_start; j<num_total; ++j)
		{
			/* ensure i!=j */
			if (i==j)
				continue;

			second+=m_kernel->kernel(i,j);
		}
	}
	second/=static_cast<float64_t>(n)*(n-1);

	/* third term: kernel between the p and the q sample. The factor is kept
	 * positive here because the term is subtracted in the return statement */
	float64_t third=0;
	for (index_t i=0; i<m; ++i)
	{
		for (index_t j=m_q_start; j<num_total; ++j)
			third+=m_kernel->kernel(i,j);
	}
	third*=2.0/(static_cast<float64_t>(m)*n);

	return first+second-third;
}

/** Computes a p-value for the given statistic.
 *
 * No method specific to this class is implemented yet: whatever the value
 * of m_threshold_method, the call is passed on to the inherited
 * compute_p_value().
 *
 * @param statistic statistic value to compute the p-value for
 * @return p-value as returned by the inherited implementation
 */
float64_t CQuadraticTimeMMD::compute_p_value(float64_t statistic)
{
	switch (m_threshold_method)
	{
		/* TODO implement new null distribution approximations here */
		default:
			return CKernelTwoSampleTestStatistic::compute_p_value(statistic);
	}
}

43 changes: 43 additions & 0 deletions src/shogun/statistics/QuadraticTimeMMD.h
@@ -0,0 +1,43 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Heiko Strathmann
*/

#ifndef __QUADRACTIMEMMD_H_
#define __QUADRACTIMEMMD_H_

#include <shogun/statistics/KernelTwoSampleTestStatistic.h>
#include <shogun/kernel/Kernel.h>

namespace shogun
{

class CFeatures;

/** Two-sample test statistic based on a kernel, derived from
 * CKernelTwoSampleTestStatistic. It provides compute_statistic() and
 * compute_p_value(). */
class CQuadraticTimeMMD : public CKernelTwoSampleTestStatistic
{
	public:
		/** default constructor */
		CQuadraticTimeMMD();

		/** constructor
		 *
		 * @param kernel kernel to use
		 * @param p_and_q features, handed on to the base class
		 * @param q_start index, handed on to the base class
		 */
		CQuadraticTimeMMD(CKernel* kernel, CFeatures* p_and_q, index_t q_start);

		/** destructor */
		virtual ~CQuadraticTimeMMD();

		/** @return value of the test statistic */
		virtual float64_t compute_statistic();

		/** @param statistic statistic value to compute the p-value for
		 * @return p-value for the given statistic
		 */
		virtual float64_t compute_p_value(float64_t statistic);

		/** @return the string "QuadraticTimeMMD" */
		virtual const char* get_name() const
		{
			return "QuadraticTimeMMD";
		}

	private:
		/** initialisation hook called by the constructors */
		void init();
};

}

#endif /* __QUADRACTIMEMMD_H_ */
7 changes: 7 additions & 0 deletions src/shogun/statistics/StatisticalTest.cpp
Expand Up @@ -48,3 +48,10 @@ void CStatisticalTest::init()

m_statistic=NULL;
}

/** Sets a new test statistic, replacing the old one.
 *
 * A reference to the new statistic is taken and the reference to the
 * previously held one is released.
 *
 * @param statistic new test statistic, may be NULL
 */
void CStatisticalTest::set_statistic(CTestStatistic* statistic)
{
	/* take the new reference BEFORE releasing the old one: if the caller
	 * passes the statistic that is already set and this object holds its
	 * only reference, releasing first would destroy it and leave a dangling
	 * pointer behind */
	SG_REF(statistic);
	SG_UNREF(m_statistic);
	m_statistic=statistic;
}
3 changes: 3 additions & 0 deletions src/shogun/statistics/StatisticalTest.h
Expand Up @@ -32,6 +32,9 @@ class CStatisticalTest : public CSGObject
* @return p-value of test result */
virtual float64_t perform_test();

/** sets a new test statistic, replacing the old one. The reference to the
 * previously held statistic is released and a reference to the new one is
 * taken.
 *
 * @param statistic new test statistic */
void set_statistic(CTestStatistic* statistic);

/** @return the string "StatisticalTest" */
inline virtual const char* get_name() const { return "StatisticalTest"; }

private:
Expand Down

0 comments on commit 9725b6d

Please sign in to comment.