Skip to content

Commit

Permalink
initial add of load_file_parameter method
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Dec 30, 2011
1 parent ba1f57d commit a37278b
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
104 changes: 104 additions & 0 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -15,6 +15,8 @@
#include <shogun/base/init.h>
#include <shogun/base/Version.h>
#include <shogun/base/Parameter.h>
#include <shogun/base/ParameterMap.h>
#include <shogun/base/DynArray.h>

#include <stdlib.h>
#include <stdio.h>
Expand Down Expand Up @@ -126,6 +128,7 @@ CSGObject::~CSGObject()
unset_global_objects();
delete m_parameters;
delete m_model_selection_parameters;
delete m_parameter_map;
}

#ifdef USE_REFERENCE_COUNTING
Expand Down Expand Up @@ -392,6 +395,106 @@ bool CSGObject::load_serializable(CSerializableFile* file,
return true;
}

/* Loads the parameter described by param_info from file, as it was stored
 * with parameter version file_version.
 *
 * If file_version equals the version of param_info, the parameter is read
 * from file directly and the recursion stops. Otherwise param_info is mapped
 * to the parameter info of an older version -- taken from m_parameter_map if
 * a mapping is registered, or a copy of param_info with its version decreased
 * by one if not -- and the method recurses with that info.
 *
 * The returned TParameter is created with new; presumably the caller is
 * responsible for deleting it (TODO confirm with callers). */
TParameter* CSGObject::load_file_parameter(SGParamInfo* param_info,
		int32_t file_version, CSerializableFile* file, const char* prefix)
{
	/* the recursion only walks towards older versions, so the file may never
	 * be newer than the provided parameter info */
	if (file_version>param_info->m_param_version)
		SG_SERROR("parameter version in file is more recent than provided!\n");

	TParameter* result=NULL;

	if (file_version==param_info->m_param_version)
	{
		/* recursion anchor: file version reached, load parameter from file */
		char* s=param_info->to_string();
		SG_DEBUG("reached file version, loading %s\n", s);
		SG_FREE(s);

		/* Length variables of vectors/matrices normally live in the owning
		 * class, which does not exist here, so allocate them. Only the
		 * lengths are allocated here, not the vector/matrix data itself.
		 * NOTE(review): deletion is supposed to be handled via the
		 * m_delete_data flag of TParameter, but that flag is not set in this
		 * method -- confirm these allocations are actually released. */
		index_t* len_x=NULL;
		index_t* len_y=NULL;

		switch (param_info->m_ctype)
		{
			case CT_VECTOR: case CT_SGVECTOR:
				len_y=SG_MALLOC(index_t, 1);
				break;
			case CT_MATRIX: case CT_SGMATRIX:
				len_x=SG_MALLOC(index_t, 1);
				len_y=SG_MALLOC(index_t, 1);
				break;
			case CT_NDARRAY:
				SG_NOTIMPLEMENTED;
				break;
			case CT_SCALAR:
			default:
				break;
		}

		/* create type with the freshly allocated lengths, no data yet */
		TSGDataType type(param_info->m_ctype, param_info->m_stype,
				param_info->m_ptype, len_y, len_x);
		result=new TParameter(&type, NULL, param_info->m_name, "");

		/* scalars are normally stored inside the owning class, so provide
		 * storage for the value to be loaded into */
		if (param_info->m_ctype==CT_SCALAR)
			result->m_parameter=SG_MALLOC(char, type.get_size());

		/* tell instance to load data from file; do not silently hand back a
		 * half-loaded parameter */
		if (!result->load(file, prefix))
		{
			delete result;
			result=NULL;
			SG_SERROR("could not load parameter \"%s\" from file!\n",
					param_info->m_name);
		}
	}
	else
	{
		/* file is older than param_info: map to an older version first */
		char* s=param_info->to_string();
		SG_DEBUG("try to get mapping for: %s\n", s);
		SG_FREE(s);

		SGParamInfo* old=m_parameter_map->get(param_info);
		bool free_old=false;
		if (old)
		{
			s=old->to_string();
			SG_DEBUG("found: %s\n", s);
			SG_FREE(s);
		}
		else
		{
			/* no mapping registered means nothing changed between versions:
			 * continue with a copy whose version is decreased by one */
			old=new SGParamInfo(*param_info);
			old->m_param_version--;
			free_old=true;
			s=old->to_string();
			SG_DEBUG("no mapping found, using %s\n", s);
			SG_FREE(s);
		}

		/* old is non-NULL here. NOTE(review): termination assumes mapped
		 * param infos always carry an older version than their key. */
		result=load_file_parameter(old, file_version, file, prefix);

		if (free_old)
			delete old;
	}

	return result;
}

bool CSGObject::save_parameter_version(CSerializableFile* file,
const char* prefix)
{
Expand Down Expand Up @@ -463,6 +566,7 @@ void CSGObject::init()
version = NULL;
m_parameters = new Parameter();
m_model_selection_parameters = new Parameter();
m_parameter_map=new ParameterMap();
m_generic = PT_NOT_GENERIC;
m_load_pre_called = false;
m_load_post_called = false;
Expand Down
22 changes: 22 additions & 0 deletions src/shogun/base/SGObject.h
Expand Up @@ -32,7 +32,11 @@ class IO;
class Parallel;
class Version;
class Parameter;
class ParameterMap;
class SGParamInfo;
class CSerializableFile;
struct TParameter;
template <class T> class DynArray;

// define reference counter macros
//
Expand Down Expand Up @@ -162,6 +166,21 @@ class CSGObject
virtual bool load_serializable(CSerializableFile* file,
const char* prefix="");

/** loads a specified parameter from a file with a specified version.
 * The version of the provided parameter info is recursively mapped to older
 * versions (using the parameter map if a mapping is registered, otherwise by
 * just decreasing the version) until the parameter version of the file is
 * reached; the parameter is then read from the file.
 *
 * @param param_info information of parameter in the provided (current)
 * version
 * @param file_version parameter version of the file, must be <= provided
 * parameter version
 * @param file file to load from
 * @param prefix prefix for members
 * @return new TParameter instance (allocated with new) with the attached
 * data; presumably the caller takes ownership -- TODO confirm
 */
TParameter* load_file_parameter(SGParamInfo* param_info,
int32_t file_version, CSerializableFile* file,
const char* prefix="");

/** set the io object
*
* @param io io object to use
Expand Down Expand Up @@ -299,6 +318,9 @@ class CSGObject
/** model selection parameters */
Parameter* m_model_selection_parameters;

/** map for different parameter versions */
ParameterMap* m_parameter_map;

private:

EPrimitiveType m_generic;
Expand Down

0 comments on commit a37278b

Please sign in to comment.