Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
added python protocol for SGVector
- Loading branch information
Showing 6 changed files with 556 additions and 73 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
examples/undocumented/python_modular/features_director_dot_modular.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import numpy | ||
try: | ||
from shogun.Features import DirectorDotFeatures | ||
from shogun.Library import RealVector | ||
except ImportError: | ||
print "recompile shogun with --enable-swig-directors" | ||
import sys | ||
sys.exit(0) | ||
|
||
from tools.load import LoadMatrix | ||
lm = LoadMatrix()

# Real-valued train/test feature matrices and the two-class training labels.
traindat = lm.load_numbers('../data/fm_train_real.dat')
testdat = lm.load_numbers('../data/fm_test_real.dat')
label_traindat = lm.load_labels('../data/label_train_twoclass.dat')

# One argument list per run of features_director_dot_modular:
# [train features, test features, train labels, C, epsilon]
parameter_list = [
    [traindat, testdat, label_traindat, 0.9, 1e-3],
    [traindat, testdat, label_traindat, 0.8, 1e-2],
]
|
||
class NumpyFeatures(DirectorDotFeatures):
    """Dot features implemented in Python on top of a numpy matrix.

    The overridden methods below are the callbacks the DirectorDotFeatures
    base class dispatches to.  Columns of ``data`` are feature vectors:
    ``data[:, i]`` is vector ``i``, ``data.shape[0]`` is the feature-space
    dimension and ``data.shape[1]`` is the number of vectors.
    """

    # Class-level placeholder so ``self.data`` resolves even before
    # ``__init__`` assigns the real matrix.
    # NOTE(review): presumably guards against callbacks issued while the base
    # constructor runs -- confirm before removing.
    data = numpy.empty((1, 1))

    def __init__(self, d):
        """Wrap matrix ``d`` (assumed 2-D, shaped (dim, num_vectors))."""
        DirectorDotFeatures.__init__(self)
        self.data = d

    def add_to_dense_sgvec(self, alpha, vec_idx1, vec2, abs):
        """Add ``alpha`` times feature vector ``vec_idx1`` to ``vec2`` in place.

        Absolute values of the vector's entries are used only when ``abs`` is
        true; the signed vector is added otherwise.  (Previously the flag was
        ignored and absolute values were always added.)  The parameter name
        shadows the builtin but is kept so existing callers still work.
        """
        vec = self.data[:, vec_idx1]
        if abs:
            vec = numpy.abs(vec)
        vec2 += alpha * vec

    def dot(self, vec_idx1, df, vec_idx2):
        """Return the dot product of own vector ``vec_idx1`` with ``df``'s vector ``vec_idx2``."""
        # NOTE(review): this materialises df's entire feature matrix on every
        # call, which is costly when invoked once per pair during training.
        other = df.get_computed_dot_feature_matrix()
        return numpy.dot(self.data[:, vec_idx1], other[:, vec_idx2])

    def get_num_vectors(self):
        """Return the number of feature vectors (columns of ``data``)."""
        return self.data.shape[1]

    def get_dim_feature_space(self):
        """Return the feature-space dimension (rows of ``data``)."""
        return self.data.shape[0]

    # Arithmetic operators combine the underlying matrices element-wise and
    # always return a new NumpyFeatures object.

    def __add__(self, other):
        return NumpyFeatures(self.data + other.data)

    def __sub__(self, other):
        return NumpyFeatures(self.data - other.data)

    def __iadd__(self, other):
        # Not in-place: rebinding the name to a new object is the result.
        return NumpyFeatures(self.data + other.data)

    def __isub__(self, other):
        # Not in-place: rebinding the name to a new object is the result.
        return NumpyFeatures(self.data - other.data)
|
||
def _train_liblinear(C, epsilon, feats_train, labels, feats_test):
    """Train a biased L2R_L2LOSS_SVC_DUAL LibLinear and apply it to ``feats_test``.

    Returns ``(svm, predictions)``.
    """
    from shogun.Classifier import LibLinear, L2R_L2LOSS_SVC_DUAL

    svm = LibLinear(C, feats_train, labels)
    svm.set_liblinear_solver_type(L2R_L2LOSS_SVC_DUAL)
    svm.set_epsilon(epsilon)
    svm.set_bias_enabled(True)
    svm.train()

    svm.set_features(feats_test)
    # apply() once; the old code also called apply().get_labels() and
    # discarded the result, doubling the prediction work.
    return svm, svm.apply()


def features_director_dot_modular(fm_train_real, fm_test_real,
                                  label_train_twoclass, C, epsilon):
    """Train LibLinear on RealFeatures and on director-backed NumpyFeatures.

    Both SVMs are trained on the same data with the same settings.  Only the
    RealFeatures-based results are returned.

    Parameters:
        fm_train_real, fm_test_real: train/test feature matrices.
        label_train_twoclass: training labels passed to BinaryLabels.
        C: LibLinear regularisation constant.
        epsilon: LibLinear stopping tolerance (passed to set_epsilon).

    Returns:
        (predictions, svm, predictions.get_labels()) for the RealFeatures SVM.
    """
    from shogun.Features import RealFeatures, BinaryLabels
    from shogun.Mathematics import Math_init_random
    Math_init_random(17)

    feats_train = RealFeatures(fm_train_real)
    feats_test = RealFeatures(fm_test_real)
    labels = BinaryLabels(label_train_twoclass)

    dfeats_train = NumpyFeatures(fm_train_real)
    dfeats_test = NumpyFeatures(fm_test_real)

    # Both feature objects should expose the same dense matrix.
    print(feats_train.get_computed_dot_feature_matrix())
    print(dfeats_train.get_computed_dot_feature_matrix())

    svm, predictions = _train_liblinear(C, epsilon, feats_train, labels, feats_test)

    # Exercise the director-backed path; its results are not part of the
    # return value (kept unchanged for existing callers).
    _train_liblinear(C, epsilon, dfeats_train, labels, dfeats_test)

    return predictions, svm, predictions.get_labels()
|
||
if __name__ == '__main__':
    # Run the example once, using the first parameter set.
    print('DirectorLinear')
    first_args = parameter_list[0]
    features_director_dot_modular(*first_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* Helper functions */ | ||
%wrapper | ||
%{ | ||
/* Clamp a [*ilow, *ihigh) slice in place so that 0 <= *ilow <= max_idx and
 * *ilow <= *ihigh <= max_idx.  The lower bound is written back before the
 * upper bound is examined, because the upper bound is clamped against it. */
void get_slice_in_bounds(Py_ssize_t* ilow, Py_ssize_t* ihigh, Py_ssize_t max_idx)
{
	Py_ssize_t lo;
	Py_ssize_t hi;

	lo = *ilow;
	if (lo < 0)
		lo = 0;
	else if (lo > max_idx)
		lo = max_idx;
	*ilow = lo;

	hi = *ihigh;
	if (hi < lo)
		hi = lo;
	else if (hi > max_idx)
		hi = max_idx;
	*ihigh = hi;
}
|
||
/* Normalise a Python-style index into [0, max_idx).
 * Negative indices count from the end.  Returns the normalised index, or -1
 * with a Python IndexError set when idx lies outside [-max_idx, max_idx). */
Py_ssize_t get_idx_in_bounds(Py_ssize_t idx, Py_ssize_t max_idx)
{
	int out_of_range = (idx < -max_idx) || (idx >= max_idx);

	if (out_of_range)
	{
		PyErr_SetString(PyExc_IndexError, "index out of bounds");
		return -1;
	}

	return (idx < 0) ? idx + max_idx : idx;
}
|
||
/* Interpret one subscript item against a dimension of size `length`.
 *
 * Returns:
 *   2 - item is a slice; *ilow, *ihigh, *step, *slicelength are filled in
 *       and *ilow / *ihigh are clamped to [0, length].
 *   1 - item is an integer index; *ilow = normalised index, *ihigh = *ilow + 1.
 *   0 - item is neither, OR converting it failed.  In the failure case a
 *       Python exception (e.g. IndexError, TypeError) is already set; callers
 *       can tell the cases apart with PyErr_Occurred().
 */
int parse_tuple_item(PyObject* item, Py_ssize_t length,
		Py_ssize_t* ilow, Py_ssize_t* ihigh,
		Py_ssize_t* step, Py_ssize_t* slicelength)
{
	if (PySlice_Check(item))
	{
		/* PySlice_GetIndicesEx returns -1 with an exception set on failure
		 * (e.g. a zero step); do not use the outputs in that case. */
		if (PySlice_GetIndicesEx((PySliceObject*) item, length,
				ilow, ihigh, step, slicelength) < 0)
			return 0;

		/* NOTE(review): for negative steps the stop index may legitimately be
		 * below start (e.g. -1 for [::-1]); clamping here alters it -- confirm
		 * callers only rely on *step and *slicelength in that case. */
		get_slice_in_bounds(ilow, ihigh, length);

		return 2;
	}
	else if (PyInt_Check(item) || PyArray_IsScalar(item, Integer) ||
		PyLong_Check(item) || (PyIndex_Check(item) && !PySequence_Check(item)))
	{
		npy_intp idx = PyArray_PyIntAsIntp(item);

		/* -1 is also a valid index, so only treat it as an error when an
		 * exception is pending. */
		if (idx == -1 && PyErr_Occurred())
			return 0;

		idx = get_idx_in_bounds(idx, length);

		/* A normalised index is never negative; -1 means IndexError is set. */
		if (idx < 0)
			return 0;

		*ilow = idx;
		*ihigh = idx + 1;

		return 1;
	}

	return 0;
}
|
||
%} |
Oops, something went wrong.