Skip to content

Commit

Permalink
added buffer_info structure for memory management in protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
gsomix committed Aug 10, 2012
1 parent 1604259 commit 8cebdf6
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 198 deletions.
Expand Up @@ -9,89 +9,90 @@

#ifdef SWIGPYTHON

%include "ProtoHelper.i"
%include "protocols_helper.i"

/* Numeric operators for DenseFeatures */
%define NUMERIC_DENSEFEATURES(class_name, type_name, format_str, operator_name, operator)

PyObject* class_name ## _inplace ## operator_name ## (PyObject *self, PyObject *o2)
{
CDenseFeatures< type_name > * arg1 = 0;
CDenseFeatures< type_name > * arg1=0; // self in c++ repr

void *argp1 = 0 ;
int res1 = 0;
int res2 = 0;
int res3 = 0;
void *argp1=0; // pointer to self
int res1=0; // result for self's casting
int res2=0; // result for checking buffer
int res3=0; // result for getting buffer

PyObject* resultobj = 0;
PyObject* resultobj=0;
Py_buffer view;
SGMatrix< type_name > buf; // internal buffer of self

int num_feat, num_vec;
int shape[2];

SGMatrix< type_name > temp;
int num_feat, num_vec; // shape of buffer of self
Py_ssize_t shape[2];
Py_ssize_t strides[2];

type_name *lhs;
type_name *buf;
char *rhs;

res1 = SWIG_ConvertPtr(self, &argp1, SWIG_TypeQuery("shogun::CDenseFeatures<type_name>"), 0 | 0 );
arg1 = reinterpret_cast< CDenseFeatures< type_name > * >(argp1);
res1=SWIG_ConvertPtr(self, &argp1, SWIG_TypeQuery("shogun::CDenseFeatures<type_name>"), 0 | 0 );
arg1=reinterpret_cast< CDenseFeatures< type_name > * >(argp1);

res2 = PyObject_CheckBuffer(o2);
res2=PyObject_CheckBuffer(o2);
if (!res2)
{
SWIG_exception_fail(SWIG_ArgError(res1), "this object don't support buffer protocol");
SWIG_exception_fail(SWIG_ArgError(res2), "this object don't support buffer protocol");
}

res3 = PyObject_GetBuffer(o2, &view, PyBUF_F_CONTIGUOUS | PyBUF_ND | PyBUF_STRIDES | 0);
if (res3 != 0 || view.buf==NULL)
res3=PyObject_GetBuffer(o2, &view, PyBUF_F_CONTIGUOUS | PyBUF_ND | PyBUF_STRIDES | 0);
if (res3!=0 || view.buf==NULL)
{
SWIG_exception_fail(SWIG_ArgError(res1), "bad buffer");
SWIG_exception_fail(SWIG_ArgError(res3), "bad buffer");
}

// checking that buffer is right
if (view.ndim != 2)
if (view.ndim!=2)
{
SWIG_exception_fail(SWIG_ArgError(res1), "wrong dimension");
SWIG_exception_fail(SWIG_ArgError(view.ndim), "wrong dimension");
}

if (view.itemsize != sizeof(type_name))
if (view.itemsize!=sizeof(type_name))
{
SWIG_exception_fail(SWIG_ArgError(res1), "wrong type");
SWIG_exception_fail(SWIG_ArgError(view.itemsize), "wrong type");
}

if (view.shape == NULL)
if (view.shape==NULL)
{
SWIG_exception_fail(SWIG_ArgError(res1), "wrong shape");
SWIG_exception_fail(SWIG_ArgError(0), "wrong shape");
}

shape[0] = view.shape[0];
shape[1] = view.shape[1];
if (shape[0] != arg1->get_num_features() || shape[1] != arg1->get_num_vectors())
SWIG_exception_fail(SWIG_ArgError(res1), "wrong size");
shape[0]=view.shape[0];
shape[1]=view.shape[1];
if (shape[0]!=arg1->get_num_features() || shape[1]!=arg1->get_num_vectors())
SWIG_exception_fail(SWIG_ArgError(0), "wrong size");

strides[0]=view.strides[0];
strides[1]=view.strides[1];

if (view.len != (shape[0]*shape[1])*view.itemsize)
SWIG_exception_fail(SWIG_ArgError(res1), "bad buffer");
if (view.len!=(shape[0]*shape[1])*view.itemsize)
SWIG_exception_fail(SWIG_ArgError(view.len), "bad buffer");

// result calculation
//lhs = arg1->get_feature_matrix(num_feat, num_vec);
temp=arg1->get_feature_matrix();
buf=arg1->get_feature_matrix();
num_feat=arg1->get_num_features();
num_vec=arg1->get_num_vectors();

lhs=temp.matrix;
lhs=buf.matrix;
rhs=(char*) view.buf;

// TODO strides support!
buf = (type_name*) view.buf;
for (int i = 0; i < num_vec; i++)
for (int i=0; i<num_vec; i++)
{
for (int j = 0; j < num_feat; j++)
for (int j=0; j<num_feat; j++)
{
lhs[num_feat*i + j] ## operator ## = buf[num_feat*i + j];
lhs[num_feat*i + j] ## operator ## = (*(type_name*) (rhs + strides[0]*i + strides[1]*j));
}
}

resultobj = self;
resultobj=self;
PyBuffer_Release(&view);

Py_INCREF(resultobj);
Expand Down Expand Up @@ -120,7 +121,7 @@ static int class_name ## _getbuffer(PyObject *self, Py_buffer *view, int flags)
Py_ssize_t* shape=NULL;
Py_ssize_t* strides=NULL;

SGMatrix< type_name > temp;
buffer_matrix_ ## type_name ## _info* info=NULL;

static char* format=(char *) format_str; // http://docs.python.org/dev/library/struct.html#module-struct

Expand All @@ -146,12 +147,13 @@ static int class_name ## _getbuffer(PyObject *self, Py_buffer *view, int flags)

arg1=reinterpret_cast< CDenseFeatures < type_name >* >(argp1);

//view->buf=arg1->get_feature_matrix(num_feat, num_vec);
temp=arg1->get_feature_matrix();
info=new buffer_matrix_ ## type_name ## _info;

info->buf=arg1->get_feature_matrix();
num_feat=arg1->get_num_features();
num_vec=arg1->get_num_vectors();

view->buf=temp.matrix;
view->buf=info->buf.matrix;

shape=new Py_ssize_t[2];
shape[0]=num_feat;
Expand All @@ -161,9 +163,12 @@ static int class_name ## _getbuffer(PyObject *self, Py_buffer *view, int flags)
strides[0]=sizeof( type_name );
strides[1]=sizeof( type_name ) * num_feat;

info->shape=shape;
info->strides=strides;

view->ndim=2;

view->format=format;
view->format=(char*) format_str;
view->itemsize=strides[0];

view->len=(shape[0]*shape[1])*view->itemsize;
Expand All @@ -172,7 +177,7 @@ static int class_name ## _getbuffer(PyObject *self, Py_buffer *view, int flags)

view->readonly=0;
view->suboffsets=NULL;
view->internal=NULL;
view->internal=(void*) info;

view->obj=(PyObject*) self;
Py_INCREF(self);
Expand All @@ -187,13 +192,18 @@ fail:
/* Buffer-protocol release hook for class_name; called by PyBuffer_Release.
 *
 * Frees the bookkeeping that class_name ## _getbuffer stores in
 * view->internal:
 *   - the heap-allocated shape and strides arrays, and
 *   - the buffer_matrix_<type>_info holder, including its copy of the
 *     feature matrix (info->buf).
 * The SGMatrix copy in info->buf presumably keeps the exported
 * view->buf memory alive (TODO confirm SGMatrix ref-counting semantics).
 *
 * view->shape / view->strides are NOT freed here. getbuffer presumably
 * points them at the same arrays as info->shape / info->strides (that
 * assignment is not visible here), so freeing both would be a double free.
 *
 * No Py_DECREF on view->obj is done here: PyBuffer_Release itself drops
 * the reference taken by getbuffer after this hook returns.
 */
static void class_name ## _releasebuffer(PyObject *self, Py_buffer *view)
{
	buffer_matrix_ ## type_name ## _info* info=NULL;

	// nothing to free if getbuffer never attached its bookkeeping
	if (view->obj==NULL || view->internal==NULL)
		return;

	info=(buffer_matrix_ ## type_name ## _info*) view->internal;

	// delete[] on a null pointer is a no-op, so no explicit checks are needed
	delete[] info->shape;
	delete[] info->strides;

	// drop the holder's copy of the feature matrix before freeing the holder
	info->buf=SGMatrix< type_name >();
	delete info;

	// clear the dangling pointer so a repeated release cannot free it twice
	view->internal=NULL;
}

Expand Down

0 comments on commit 8cebdf6

Please sign in to comment.