Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
initial work on computing p value using gamma distribution for quadra…
…tic MMD
  • Loading branch information
karlnapf committed Jun 7, 2012
1 parent 9e5ef91 commit 4f3c76e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/shogun/statistics/QuadraticTimeMMD.cpp
Expand Up @@ -9,6 +9,7 @@

#include <shogun/statistics/QuadraticTimeMMD.h>

#include <cmath>

#include <shogun/features/Features.h>
#include <shogun/mathematics/Statistics.h>

using namespace shogun;

Expand Down Expand Up @@ -159,3 +160,67 @@ SGVector<float64_t> CQuadraticTimeMMD::sample_null_spectrum(index_t num_samples,

return null_samples;
}

/* Regularized lower incomplete gamma function P(shape, x), i.e. the CDF at x
 * of a gamma distribution with the given shape and unit scale.
 *
 * Uses the power series for x<shape+1 and a continued fraction (modified
 * Lentz evaluation) for the upper tail otherwise; both are evaluated in log
 * space via lgamma to avoid overflow for large shapes.
 *
 * shape is expected to be strictly positive; x<=0 yields 0. */
static float64_t regularized_lower_incomplete_gamma(float64_t shape,
		float64_t x)
{
	if (x<=0.0)
		return 0.0;

	const int max_iterations=10000;
	const float64_t epsilon=1e-15;
	const float64_t tiny=1e-300;

	/* log of x^shape * exp(-x) / Gamma(shape), shared by both expansions */
	const float64_t log_prefactor=shape*std::log(x)-x-std::lgamma(shape);

	if (x<shape+1.0)
	{
		/* series representation converges quickly in this regime */
		float64_t denominator=shape;
		float64_t term=1.0/shape;
		float64_t sum=term;
		for (int n=0; n<max_iterations; ++n)
		{
			denominator+=1.0;
			term*=x/denominator;
			sum+=term;
			if (std::fabs(term)<std::fabs(sum)*epsilon)
				break;
		}

		return sum*std::exp(log_prefactor);
	}

	/* continued fraction for the upper tail Q(shape, x)=1-P(shape, x) */
	float64_t b=x+1.0-shape;
	float64_t c=1.0/tiny;
	float64_t d=1.0/b;
	float64_t h=d;
	for (int n=1; n<=max_iterations; ++n)
	{
		const float64_t an=-n*(n-shape);
		b+=2.0;

		d=an*d+b;
		if (std::fabs(d)<tiny)
			d=tiny;

		c=b+an/c;
		if (std::fabs(c)<tiny)
			c=tiny;

		d=1.0/d;
		const float64_t delta=d*c;
		h*=delta;
		if (std::fabs(delta-1.0)<epsilon)
			break;
	}

	return 1.0-std::exp(log_prefactor)*h;
}

/** Approximates the null distribution of the quadratic time MMD by a gamma
 * distribution whose shape and scale are matched to the mean and variance of
 * the MMD under H0, and evaluates the CDF of that gamma at the statistic.
 *
 * Follows the MATLAB reference formulas quoted in the comments below. The
 * mean formula (1 - mean of diag(KL)) presumably assumes a kernel with
 * k(x,x)=1 -- TODO confirm against the kernels this is used with.
 *
 * Side effects: initialises m_kernel on (m_p_and_q, m_p_and_q) and prints the
 * kernel matrix and intermediate values.
 *
 * @param statistic value of the test statistic; presumably m*MMD^2 (biased),
 * matching the scaling of the fitted scale parameter -- TODO confirm
 * @return cdf('gam', statistic, a, b) (MATLAB notation, a=shape, b=scale),
 * i.e. the position of the statistic in the fitted null distribution. Note
 * that this is the lower tail; an upper-tail p-value is one minus this value.
 *
 * Calls SG_ERROR if the two sample sizes differ, if there are fewer than two
 * samples per distribution, or if the estimated moments are not positive.
 */
float64_t CQuadraticTimeMMD::compute_p_value_gamma(float64_t statistic)
{
	/* compare against the full count rather than count/2: integer division
	 * would wrongly accept an odd total such as 2+3 samples */
	if (2*m_q_start!=m_p_and_q->get_num_vectors())
	{
		/* TODO support different numbers of samples */
		SG_ERROR("%s::compute_p_value_gamma(): Currently, only equal "
				"sample sizes are supported\n", get_name());
	}

	/* the variance estimate divides by (m-1) */
	if (m_q_start<2)
	{
		SG_ERROR("%s::compute_p_value_gamma(): At least two samples per "
				"distribution are required\n", get_name());
	}

	/* imaginary matrix K=[K KL; KL' L] (MATLAB notation)
	 * K is matrix for XX, L is matrix for YY, KL is XY, LK is YX
	 * works since X and Y are concatenated here */
	m_kernel->init(m_p_and_q, m_p_and_q);
	SGMatrix<float64_t> K=m_kernel->get_kernel_matrix();
	CMath::display_matrix(K, "kernel matrix");

	/* compute mean under H0 of MMD, which is
	 * meanMMD = 2/m * ( 1 - 1/m*sum(diag(KL)) );
	 * in MATLAB */
	float64_t mean_mmd=0;
	for (index_t i=0; i<m_q_start; ++i)
	{
		/* virtual KL matrix is in upper right corner of SHOGUN K matrix
		 * so this sums the diagonal of the matrix between X and Y */
		mean_mmd+=K(i, m_q_start+i);

		/* remove diagonal of all pairs of kernel matrices on the fly */
		K(i, i)=0;
		K(m_q_start+i, m_q_start+i)=0;
		K(i, m_q_start+i)=0;
		K(m_q_start+i, i)=0;
	}
	mean_mmd=2.0/m_q_start*(1.0-1.0/m_q_start*mean_mmd);
	SG_PRINT("mean mmd: %f\n", mean_mmd);

	/* compute variance under H0 of MMD, which is
	 * varMMD = 2/m/(m-1) * 1/m/(m-1) * sum(sum( (K + L - KL - KL').^2 ));
	 * in MATLAB, so sum up all elements */
	float64_t var_mmd=0;
	for (index_t i=0; i<m_q_start; ++i)
	{
		for (index_t j=0; j<m_q_start; ++j)
		{
			const float64_t to_add=K(i, j)+K(m_q_start+i, m_q_start+j)
					-K(i, m_q_start+j)-K(m_q_start+i, j);
			var_mmd+=to_add*to_add;
		}
	}
	var_mmd*=2.0/m_q_start/(m_q_start-1)*1.0/m_q_start/(m_q_start-1);
	SG_PRINT("var mmd: %f\n", var_mmd);

	/* a gamma distribution cannot be matched to non-positive moments; the
	 * divisions below would otherwise produce inf/NaN parameters */
	if (mean_mmd<=0 || var_mmd<=0)
	{
		SG_ERROR("%s::compute_p_value_gamma(): Cannot fit gamma distribution, "
				"mean (%f) and variance (%f) of MMD under H0 have to be "
				"positive\n", get_name(), mean_mmd, var_mmd);
	}

	/* moment matching: a is the shape, b the scale of the gamma */
	float64_t a=mean_mmd*mean_mmd/var_mmd;
	float64_t b=var_mmd*m_q_start / mean_mmd;

	SG_PRINT("a=%f, b=%f\n", a,b);

	/* return: cdf('gam',statistic,al,bet) (MATLAB)
	 * which will get the position in the null distribution */
	return regularized_lower_incomplete_gamma(a, statistic/b);
}
2 changes: 2 additions & 0 deletions src/shogun/statistics/QuadraticTimeMMD.h
Expand Up @@ -53,6 +53,8 @@ class CQuadraticTimeMMD : public CKernelTwoSampleTestStatistic
SGVector<float64_t> sample_null_spectrum(index_t num_samples,
index_t num_eigenvalues=-1);

/** Fits a gamma distribution to the null distribution of the MMD by
 * matching its mean and variance under H0, both estimated from the kernel
 * matrix of the concatenated samples. Currently only equal sample sizes
 * are supported; otherwise an error is raised.
 *
 * @param statistic value of the test statistic to be located in the
 * fitted null distribution
 * @return intended to be the gamma CDF evaluated at the statistic, i.e.
 * its position in the null distribution (see the comments in the
 * implementation)
 */
float64_t compute_p_value_gamma(float64_t statistic);

private:
void init();
};
Expand Down

0 comments on commit 4f3c76e

Please sign in to comment.