Skip to content

Commit

Permalink
~ mean and cov of lhs features to class member and their computations to init method
Browse files (browse the repository at this point in the history)
  • Loading branch information
iglesias committed Feb 27, 2012
1 parent dbfdbd8 commit 7379fa6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
13 changes: 6 additions & 7 deletions src/shogun/distance/MahalanobisDistance.cpp
Expand Up @@ -41,6 +41,11 @@ bool CMahalanobisDistance::init(CFeatures* l, CFeatures* r)
{
CRealDistance::init(l, r);

mean = ((CSimpleFeatures<float64_t>*) l)->get_mean();
icov = ((CSimpleFeatures<float64_t>*) l)->get_cov();

CMath::inverse(icov);

return true;
}

Expand All @@ -50,10 +55,6 @@ void CMahalanobisDistance::cleanup()

float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)
{

SGVector<float64_t> mean = ((CSimpleFeatures<float64_t>*) lhs)->get_mean();
SGMatrix<float64_t> cov = ((CSimpleFeatures<float64_t>*) lhs)->get_cov();

int32_t blen;
bool bfree;
float64_t* bvec = ((CSimpleFeatures<float64_t>*) rhs)->
Expand All @@ -65,11 +66,9 @@ float64_t CMahalanobisDistance::compute(int32_t idx_a, int32_t idx_b)
for (int32_t i = 0 ; i<blen ; i++)
diff[i] -= mean[i];

CMath::inverse(cov);

SGVector<float64_t> v = diff.clone();
cblas_dgemv(CblasColMajor, CblasNoTrans,
cov.num_rows, cov.num_cols, 1.0, cov.matrix,
icov.num_rows, icov.num_cols, 1.0, icov.matrix,
diff.vlen, diff.vector, 1, 0.0, v.vector, 1);

float64_t result = cblas_ddot(v.vlen, v.vector, 1, diff.vector, 1);
Expand Down
7 changes: 6 additions & 1 deletion src/shogun/distance/MahalanobisDistance.h
Expand Up @@ -99,7 +99,7 @@ class CMahalanobisDistance: public CRealDistance
/// compute Mahalanobis distance between a feature vector of the
/// rhs to the lhs distribution
/// idx_a is not used here but included because of inheritance
/// idx_b denote the index of the feature vector
/// idx_b denotes the index of the feature vector
/// in the corresponding feature object rhs
virtual float64_t compute(int32_t idx_a, int32_t idx_b);

Expand All @@ -109,6 +109,11 @@ class CMahalanobisDistance: public CRealDistance
protected:
/** if application of sqrt on matrix computation is disabled */
bool disable_sqrt;

/** vector mean of the lhs feature vectors */
SGVector<float64_t> mean;
/** inverse of the covariance matrix of lhs feature vectors */
SGMatrix<float64_t> icov;
};

} // namespace shogun
Expand Down

0 comments on commit 7379fa6

Please sign in to comment.