Skip to content

Commit

Permalink
Fix for set subfeatures weights
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Apr 29, 2012
1 parent 1cac205 commit ccbdee1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 22 deletions.
21 changes: 11 additions & 10 deletions src/shogun/features/CombinedDotFeatures.cpp
Expand Up @@ -350,34 +350,35 @@ int32_t CCombinedDotFeatures::get_nnz_features_for_vector(int32_t num)
return result;
}

void CCombinedDotFeatures::get_subfeature_weights(float64_t** weights, int32_t* num_weights)
SGVector<float64_t> CCombinedDotFeatures::get_subfeature_weights()
{
*num_weights = get_num_feature_obj();
ASSERT(*num_weights > 0);
int32_t num_weights = get_num_feature_obj();
ASSERT(num_weights > 0);

*weights=SG_MALLOC(float64_t, *num_weights);
float64_t* w = *weights;
float64_t* weights=SG_MALLOC(float64_t, num_weights);

CListElement* current = NULL;
CDotFeatures* f = get_first_feature_obj(current);

int32_t i = 0;
while (f)
{
*w++=f->get_combined_feature_weight();
weights[i] = f->get_combined_feature_weight();

SG_UNREF(f);
f = get_next_feature_obj(current);
i++;
}
return SGVector<float64_t>(weights,num_weights);
}

void CCombinedDotFeatures::set_subfeature_weights(
float64_t* weights, int32_t num_weights)
void CCombinedDotFeatures::set_subfeature_weights(const SGVector<float64_t>& weights)
{
int32_t i=0 ;
int32_t i = 0;
CListElement* current = NULL ;
CDotFeatures* f = get_first_feature_obj(current);

ASSERT(num_weights==get_num_feature_obj());
ASSERT(weights.vlen==get_num_feature_obj());

while(f)
{
Expand Down
8 changes: 2 additions & 6 deletions src/shogun/features/CombinedDotFeatures.h
Expand Up @@ -277,18 +277,14 @@ class CCombinedDotFeatures : public CDotFeatures

/** get subfeature weights
*
* @param weights subfeature weights
* @param num_weights where number of weights is stored
*/
virtual void get_subfeature_weights(float64_t** weights, int32_t* num_weights);
virtual SGVector<float64_t> get_subfeature_weights();

/** set subfeature weights
*
* @param weights new subfeature weights
* @param num_weights number of subfeature weights
*/
virtual void set_subfeature_weights(
float64_t* weights, int32_t num_weights);
virtual void set_subfeature_weights(const SGVector<float64_t>& weights);

/** @return object name */
inline virtual const char* get_name() const { return "CombinedDotFeatures"; }
Expand Down
10 changes: 4 additions & 6 deletions src/shogun/ui/SGInterface.cpp
Expand Up @@ -3762,11 +3762,9 @@ bool CSGInterface::cmd_get_dotfeature_weights_combined()
if (features->get_feature_class()!=C_COMBINED_DOT)
SG_ERROR("Only works for combined dot features.\n");

float64_t* weights=NULL;
int32_t len=0;
((CCombinedDotFeatures*) features)->get_subfeature_weights(&weights, &len);
set_vector(weights, len);
SG_FREE(weights);
SGVector<float64_t> weights = ((CCombinedDotFeatures*) features)->get_subfeature_weights();
set_vector(weights.vector, weights.vlen);
weights.destroy_vector();

return true;
}
Expand Down Expand Up @@ -3801,7 +3799,7 @@ bool CSGInterface::cmd_set_dotfeature_weights_combined()
int32_t len=0;
get_matrix(weights, dim, len);

((CCombinedDotFeatures*) features)->set_subfeature_weights(weights, len);
((CCombinedDotFeatures*) features)->set_subfeature_weights(SGVector<float64_t>(weights, len));

return true;
}
Expand Down

0 comments on commit ccbdee1

Please sign in to comment.