Skip to content

Commit

Permalink
Mahalanobis distance fixes
Browse files Browse the repository at this point in the history
- use mean of all examples
- improve documentation
- serialization support
  • Loading branch information
Soeren Sonnenburg committed Mar 5, 2012
1 parent 48fcd9f commit d3f6438
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 39 deletions.
52 changes: 23 additions & 29 deletions src/shogun/distance/MahalanobisDistance.cpp
Expand Up @@ -41,15 +41,16 @@ bool CMahalanobisDistance::init(CFeatures* l, CFeatures* r)
{
CRealDistance::init(l, r);

mean = ((CSimpleFeatures<float64_t>*) l)->get_mean();

if ( ((CSimpleFeatures<float64_t>*) l)->is_equal((CSimpleFeatures<float64_t>*) r) )
if ( l == r)
{
mean = ((CSimpleFeatures<float64_t>*) l)->get_mean();
icov = ((CSimpleFeatures<float64_t>*) l)->get_cov();
}
else
{
icov = CDotFeatures::compute_cov((CDotFeatures*)lhs, (CDotFeatures*)rhs);
mean = ((CSimpleFeatures<float64_t>*) l)->get_mean((CDotFeatures*) lhs, (CDotFeatures*) rhs);
icov = CDotFeatures::compute_cov((CDotFeatures*) lhs, (CDotFeatures*) rhs);
}

CMath::inverse(icov);
Expand All @@ -63,37 +64,25 @@ void CMahalanobisDistance::cleanup()

float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)
{
int32_t alen, blen;
bool afree, bfree;
float64_t* avec;

float64_t* bvec = ((CSimpleFeatures<float64_t>*) rhs)->
get_feature_vector(idx_b, blen, bfree);
SGVector<float64_t> bvec = ((CSimpleFeatures<float64_t>*) rhs)->
get_feature_vector(idx_b);

SGVector<float64_t> c;
SGVector<float64_t> diff;
SGVector<float64_t> avec;

if (use_mean)
{
c = mean.clone();
}
diff = mean.clone();
else
{
avec = ((CSimpleFeatures<float64_t>*) lhs)->
get_feature_vector(idx_a, alen, afree);

c.resize_vector(alen);
for (int i = 0; i < alen; i++)
c[i] = avec[i];

((CSimpleFeatures<float64_t>*) lhs)->free_feature_vector(avec, idx_a, afree);
avec = ((CSimpleFeatures<float64_t>*) lhs)->get_feature_vector(idx_a);
diff=avec.clone();
}

ASSERT(blen == c.vlen);
ASSERT(diff.vlen == bvec.vlen);

SGVector<float64_t> diff;
diff.resize_vector(blen);
for (int32_t i = 0 ; i < diff.vlen ; i++)
diff[i] = bvec[i] - c[i];
for (int32_t i=0; i < diff.vlen; i++)
diff[i] = bvec.vector[i] - diff[i];

SGVector<float64_t> v = diff.clone();
cblas_dgemv(CblasColMajor, CblasNoTrans,
Expand All @@ -102,9 +91,13 @@ float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)

float64_t result = cblas_ddot(v.vlen, v.vector, 1, diff.vector, 1);

((CSimpleFeatures<float64_t>*) rhs)->free_feature_vector(bvec, idx_b, bfree);
v.destroy_vector();
c.destroy_vector();
diff.destroy_vector();

if (!use_mean)
((CSimpleFeatures<float64_t>*) lhs)->free_feature_vector(avec, idx_a);

((CSimpleFeatures<float64_t>*) rhs)->free_feature_vector(bvec, idx_b);

if (disable_sqrt)
return result;
Expand All @@ -114,10 +107,11 @@ float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)

/** Set the default option values and register them with m_parameters.
 *
 * Both flags start out as false:
 *  - disable_sqrt: when false, the square root is applied to the result
 *  - use_mean: when false, the distance is computed between a vector of lhs
 *    and a vector of rhs rather than between the mean and a vector of rhs
 */
void CMahalanobisDistance::init()
{
	// assign each default exactly once
	disable_sqrt=false;
	use_mean=false;

	// expose both options under their own names
	m_parameters->add(&disable_sqrt, "disable_sqrt", "If sqrt shall not be applied.");
	m_parameters->add(&use_mean, "use_mean", "If distance shall be computed between mean vector and vector from rhs or between lhs and rhs.");
}

#endif /* HAVE_LAPACK */
22 changes: 16 additions & 6 deletions src/shogun/distance/MahalanobisDistance.h
Expand Up @@ -26,15 +26,25 @@ namespace shogun
* mean and covariance.
*
* \f[\displaystyle
* D = \sqrt{ (x_i - \mu)' \Sigma^{-1} (x_i - \mu) }
* D = \sqrt{ (x_i - \mu)^T \Sigma^{-1} (x_i - \mu) }
* \f]
*
* The Mahalanobis Squared distance does not take the square root:
*
* \f[\displaystyle
* D = (x_i - \mu)' \Sigma^{-1} (x_i - \mu)
* D = (x_i - \mu)^T \Sigma^{-1} (x_i - \mu)
* \f]
*
* If use_mean is set to false (which it is by default) the distance is computed
* as
*
* \f[\displaystyle
* D = \sqrt{ (x_i - x_i')^T \Sigma^{-1} (x_i - x_i') }
* \f]
*
* i.e., instead of using the mean as reference, the two vectors \f$x_i\f$ and
* \f$x_i'\f$ are compared.
*
* @see <a href="http://en.wikipedia.org/wiki/Mahalanobis_distance">
* Wikipedia: Mahalanobis Distance</a>
*/
Expand Down Expand Up @@ -95,15 +105,15 @@ class CMahalanobisDistance: public CRealDistance
*/
virtual void set_disable_sqrt(bool state) { disable_sqrt=state; };

/** whether the distance is between the mean of lhs and a vector of
* rhs
/** whether the distance is computed between the mean and a vector of rhs
* or between lhs and rhs
*
* @return whether the mean is used to obtain the distance
*/
virtual bool get_use_mean() { return use_mean; };

/** whether the distance is between the mean of lhs and a vector of
* rhs
/** whether the distance is computed between the mean and a vector of rhs
* or between lhs and rhs
*
* @param state new use_mean
*/
Expand Down
30 changes: 26 additions & 4 deletions src/shogun/features/DotFeatures.cpp
Expand Up @@ -409,6 +409,31 @@ SGVector<float64_t> CDotFeatures::get_mean()
return mean;
}

/** Compute the mean over all feature vectors of two feature objects.
 *
 * Every vector of lhs, followed by every vector of rhs, is accumulated into
 * a single dense vector which is then divided by the total number of vectors.
 *
 * @param lhs first feature object; must be non-NULL and hold at least one vector
 * @param rhs second feature object; must be non-NULL, hold at least one vector
 *            and have the same feature space dimension as lhs
 * @return mean vector with lhs->get_dim_feature_space() entries
 */
SGVector<float64_t> CDotFeatures::get_mean(CDotFeatures* lhs, CDotFeatures* rhs)
{
	ASSERT(lhs && rhs);
	ASSERT(lhs->get_dim_feature_space() == rhs->get_dim_feature_space());

	const int32_t num_lhs=lhs->get_num_vectors();
	const int32_t num_rhs=rhs->get_num_vectors();
	const int32_t dim=lhs->get_dim_feature_space();
	ASSERT(num_lhs>0);
	ASSERT(num_rhs>0);
	ASSERT(dim>0);

	// zero-initialised accumulator for the running sum
	SGVector<float64_t> acc(dim);
	memset(acc.vector, 0, sizeof(float64_t)*dim);

	// sum up lhs first, then rhs
	int32_t idx=0;
	while (idx<num_lhs)
	{
		lhs->add_to_dense_vec(1, idx, acc.vector, dim);
		++idx;
	}

	idx=0;
	while (idx<num_rhs)
	{
		rhs->add_to_dense_vec(1, idx, acc.vector, dim);
		++idx;
	}

	// turn the sum into the mean
	const int32_t num_total=num_lhs+num_rhs;
	for (float64_t* entry=acc.vector; entry!=acc.vector+dim; ++entry)
		*entry /= num_total;

	return acc;
}

SGMatrix<float64_t> CDotFeatures::get_cov()
{
int32_t num=get_num_vectors();
Expand Down Expand Up @@ -477,9 +502,7 @@ SGMatrix<float64_t> CDotFeatures::compute_cov(CDotFeatures* lhs, CDotFeatures* r

memset(cov.matrix, 0, sizeof(float64_t)*dim*dim);

SGVector<float64_t> mean=lhs->get_mean();
SGVector<float64_t> meanr=rhs->get_mean();
CMath::add<float64_t>(mean.vector, 0.5, mean.vector, 0.5, meanr.vector, mean.vlen);
SGVector<float64_t> mean=get_mean(lhs,rhs);

for (int i = 0; i < 2; i++)
{
Expand Down Expand Up @@ -513,7 +536,6 @@ SGMatrix<float64_t> CDotFeatures::compute_cov(CDotFeatures* lhs, CDotFeatures* r
}

mean.destroy_vector();
meanr.destroy_vector();
return cov;
}

Expand Down
6 changes: 6 additions & 0 deletions src/shogun/features/DotFeatures.h
Expand Up @@ -210,6 +210,12 @@ class CDotFeatures : public CFeatures
*/
virtual SGVector<float64_t> get_mean();

/** get mean over all feature vectors of two CDotFeatures objects
 *
 * @param lhs first feature object
 * @param rhs second feature object
 * @return mean returned
 */
static SGVector<float64_t> get_mean(CDotFeatures* lhs, CDotFeatures* rhs);

/** get covariance
*
* @return covariance
Expand Down

0 comments on commit d3f6438

Please sign in to comment.