Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Merge pull request #603 from puffin444/master
Gradient Search Framework.
  • Loading branch information
Soeren Sonnenburg committed Jun 27, 2012
2 parents c9ae118 + 90d49c8 commit 60e54ed
Show file tree
Hide file tree
Showing 51 changed files with 2,001 additions and 422 deletions.
Expand Up @@ -108,10 +108,15 @@ void test_cross_validation()
cross->set_conf_int_alpha(0.05);

/* actual evaluation */
CrossValidationResult result=cross->evaluate();
result.print_result();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

result->print_result();

/* clean up */
SG_UNREF(result);
SG_UNREF(cross);
SG_UNREF(features);
}
Expand Down
Expand Up @@ -103,6 +103,7 @@ void test_cross_validation()
cross->set_num_runs(5);
cross->set_conf_int_alpha(0.05);

CrossValidationResult* tmp;
/* no locking */
index_t repetitions=5;
SG_SPRINT("unlocked x-val\n");
Expand All @@ -111,25 +112,39 @@ void test_cross_validation()
CTime time;
time.start();
for (index_t i=0; i<repetitions; ++i)
cross->evaluate();
{
tmp = (CrossValidationResult*)cross->evaluate();
SG_UNREF(tmp);
}

time.stop();
SG_SPRINT("%f sec\n", time.cur_time_diff());

/* auto-locking in every iteration of this loop (better than no locking, but not as nice as locking once beforehand) */
SG_SPRINT("locked in every iteration x-val\n");
cross->set_autolock(true);
time.start();

for (index_t i=0; i<repetitions; ++i)
cross->evaluate();
{
tmp = (CrossValidationResult*)cross->evaluate();
SG_UNREF(tmp);
}

time.stop();
SG_SPRINT("%f sec\n", time.cur_time_diff());

/* lock once before, (no locking/unlocking in this loop) */
svm->data_lock(labels, features);
SG_SPRINT("locked x-val\n");
time.start();
for (index_t i=0; i<repetitions; ++i)
cross->evaluate();

for (index_t i=0; i<repetitions; ++i)
{
tmp = (CrossValidationResult*)cross->evaluate();
SG_UNREF(tmp);
}

time.stop();
SG_SPRINT("%f sec\n", time.cur_time_diff());

Expand Down
Expand Up @@ -122,10 +122,15 @@ void test_cross_validation()
cross->set_conf_int_alpha(0.05);

/* actual evaluation */
CrossValidationResult result=cross->evaluate();
result.print_result();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

result->print_result();

/* clean up */
SG_UNREF(result);
SG_UNREF(cross);
SG_UNREF(features);
SG_UNREF(labels);
Expand Down
Expand Up @@ -95,14 +95,19 @@ void test_cross_validation()
cross->set_conf_int_alpha(0.05);

/* actual evaluation */
CrossValidationResult result=cross->evaluate();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("cross_validation estimate:\n");
result.print_result();
result->print_result();

/* same crude assertion as for above evaluation */
ASSERT(result.mean<2);
ASSERT(result->mean<2);

/* clean up */
SG_UNREF(result);
SG_UNREF(cross);
SG_UNREF(features);
}
Expand Down
Expand Up @@ -153,19 +153,29 @@ int main(int argc, char **argv)
/* larger number of runs to have tighter confidence intervals */
cross->set_num_runs(10);
cross->set_conf_int_alpha(0.01);
CrossValidationResult result=cross->evaluate();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("result: ");
result.print_result();
result->print_result();

/* now again but unlocked */
SG_UNREF(best_combination);
cross->set_autolock(true);
best_combination=grid_search->select_model(print_state);
best_combination->apply_to_machine(classifier);
result=cross->evaluate();
SG_UNREF(result);
result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("result (unlocked): ");

/* clean up destroy result parameter */
SG_UNREF(result);
SG_UNREF(best_combination);
SG_UNREF(grid_search);

Expand Down
Expand Up @@ -139,13 +139,18 @@ void test_cross_validation()
/* larger number of runs to have tighter confidence intervals */
cross->set_num_runs(10);
cross->set_conf_int_alpha(0.01);
CrossValidationResult result=cross->evaluate();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("result: ");
result.print_result();
result->print_result();

/* clean up */
SG_UNREF(features);
SG_UNREF(best_combination);
SG_UNREF(result);
SG_UNREF(grid_search);
}

Expand Down
Expand Up @@ -98,10 +98,15 @@ int main(int argc, char **argv)
best_combination->print_tree();

best_combination->apply_to_machine(classifier);
CrossValidationResult result=cross->evaluate();
result.print_result();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

result->print_result();

/* clean up */
SG_UNREF(result);
SG_UNREF(best_combination);
SG_UNREF(grid_search);
#endif // HAVE_LAPACK
Expand Down
Expand Up @@ -138,7 +138,11 @@ int main(int argc, char **argv)
/* larger number of runs to have tighter confidence intervals */
cross->set_num_runs(10);
cross->set_conf_int_alpha(0.01);
CrossValidationResult result=cross->evaluate();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_ERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("result: ");
result.print_result();

Expand Down
Expand Up @@ -146,11 +146,16 @@ int main(int argc, char **argv)
cross->set_num_runs(10);
cross->set_conf_int_alpha(0.01);
classifier->data_lock(labels, features);
CrossValidationResult result=cross->evaluate();
CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();

if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");

SG_SPRINT("result: ");
result.print_result();
result->print_result();

/* clean up */
SG_UNREF(result);
SG_UNREF(best_combination);
SG_UNREF(grid_search);

Expand Down
70 changes: 67 additions & 3 deletions examples/undocumented/libshogun/regression_gaussian_process.cpp
Expand Up @@ -16,7 +16,11 @@
#include <shogun/regression/gp/GaussianLikelihood.h>
#include <shogun/regression/gp/ZeroMean.h>
#include <shogun/regression/GaussianProcessRegression.h>

#include <shogun/evaluation/GradientEvaluation.h>
#include <shogun/modelselection/GradientModelSelection.h>
#include <shogun/modelselection/ModelSelectionParameters.h>
#include <shogun/modelselection/ParameterCombination.h>
#include <shogun/evaluation/GradientCriterion.h>

using namespace shogun;

Expand Down Expand Up @@ -75,15 +79,71 @@ int main(int argc, char **argv)

SG_REF(labels);
CGaussianKernel* test_kernel = new CGaussianKernel(10, 2);

test_kernel->init(features, features);

CZeroMean* mean = new CZeroMean();
CGaussianLikelihood* lik = new CGaussianLikelihood();
lik->set_sigma(0.01);
CExactInferenceMethod* inf = new CExactInferenceMethod(test_kernel, features, mean, labels, lik);
SG_REF(inf);

CGaussianProcessRegression* gp = new CGaussianProcessRegression(inf, features, labels);

CModelSelectionParameters* root=new CModelSelectionParameters();

CModelSelectionParameters* c2=new CModelSelectionParameters("Inference Method", inf);
root->append_child(c2);

CModelSelectionParameters* c3=new CModelSelectionParameters("Likelihood Model", lik);
c2->append_child(c3);

CModelSelectionParameters* c1=new CModelSelectionParameters("sigma");
c3->append_child(c1);
c1->build_values(-10.0, 2.0, R_EXP);

CModelSelectionParameters* c4=new CModelSelectionParameters("Kernel", test_kernel);
c2->append_child(c4);

CModelSelectionParameters* c5=new CModelSelectionParameters("width");
c4->append_child(c5);
c5->build_values(-10.0, 2.0, R_EXP);

/* gradient evaluation class used for evaluation in model selection */
SG_REF(gp);

CGradientCriterion* crit = new CGradientCriterion();

CGradientEvaluation* grad=new CGradientEvaluation(gp, features, labels,
crit);

grad->set_function(inf);

gp->print_modsel_params();

root->print_tree();

/* handles all of the above structures in memory */
CGradientModelSelection* grad_search=new CGradientModelSelection(
root, grad);

/* set autolocking to false to get rid of warnings */
grad->set_autolock(false);

CParameterCombination* best_combination=grad_search->select_model(true);

SG_SPRINT("best parameter(s):\n");
best_combination->print_tree();

best_combination->apply_to_machine(gp);
CGradientResult* result=(CGradientResult*)grad->evaluate();

if(result->get_result_type() != GRADIENTEVALUATION_RESULT)
SG_SERROR("Evaluation result not a GradientEvaluationResult!");

result->print_result();


SGVector<float64_t> alpha = inf->get_alpha();
SGVector<float64_t> labe = labels->get_labels();
SGVector<float64_t> diagonal = inf->get_diagonal_vector();
Expand All @@ -105,8 +165,12 @@ int main(int argc, char **argv)
SG_UNREF(features2);
SG_UNREF(predictions);
SG_UNREF(labels);
SG_UNREF(inf);
SG_UNREF(gp);

SG_UNREF(grad_search);
SG_UNREF(best_combination);
SG_UNREF(result);

exit_shogun();

return 0;
Expand Down

0 comments on commit 60e54ed

Please sign in to comment.