Skip to content

Commit

Permalink
Fixed csharp crash by making sequence data protected
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 26, 2012
1 parent 1468c7e commit 0e7695a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
11 changes: 7 additions & 4 deletions src/shogun/evaluation/StructuredAccuracy.cpp
Expand Up @@ -98,15 +98,18 @@ float64_t CStructuredAccuracy::evaluate_sequence(CStructuredLabels* predicted,
CSequence* pred_seq =
CSequence::obtain_from_generic(predicted->get_label(i));

REQUIRE(true_seq->data.vlen == pred_seq->data.vlen, "Corresponding ground "
SGVector<int32_t> true_seq_data = true_seq->get_data();
SGVector<int32_t> pred_seq_data = pred_seq->get_data();

REQUIRE(true_seq_data.size() == pred_seq_data.size(), "Corresponding ground "
"truth and predicted sequences must be equally long\n");

num_equal = 0;
// Count the number of elements that are equal in both sequences
for ( int32_t j = 0 ; j < true_seq->data.vlen ; ++j )
num_equal += true_seq->data[j] == pred_seq->data[j];
for ( int32_t j = 0 ; j < true_seq_data.size() ; ++j )
num_equal += true_seq_data[j] == pred_seq_data[j];

accuracies[i] = (1.0*num_equal) / true_seq->data.vlen;
accuracies[i] = (1.0*num_equal) / true_seq_data.size();

SG_UNREF(true_seq);
SG_UNREF(pred_seq);
Expand Down
12 changes: 8 additions & 4 deletions src/shogun/structure/HMSVMLabels.h
Expand Up @@ -23,16 +23,17 @@ class CHMSVMLabels;

/** @brief Class CSequence to be used in the application of Structured Output
* (SO) learning to Hidden Markov Support Vector Machines (HM-SVM). */
struct CSequence : public CStructuredData
class CSequence : public CStructuredData
{
public:
/** data type */
STRUCTURED_DATA_TYPE(SDT_SEQUENCE);

/** constructor
*
* @param seq data sequence
*/
CSequence(SGVector< int32_t > seq) : CStructuredData(), data(seq) { }
CSequence(SGVector< int32_t > seq = SGVector<int32_t>()) : CStructuredData(), data(seq) { }

/** destructor */
~CSequence() { }
Expand All @@ -54,11 +55,14 @@ struct CSequence : public CStructuredData
/** @return name of SGSerializable */
virtual const char* get_name() const { return "Sequence"; }

/** returns data */
SGVector<int32_t> get_data() const { return data; }

protected:

/** data sequence */
SGVector< int32_t > data;

/** returns data */
SGVector<int32_t> get_data() const { return data; }
};

/** @brief Class CHMSVMLabels to be used in the application of Structured Output
Expand Down
9 changes: 5 additions & 4 deletions src/shogun/structure/HMSVMModel.cpp
Expand Up @@ -154,9 +154,9 @@ CResultSet* CHMSVMModel::argmax(
CSequence* ytrue =
CSequence::obtain_from_generic(m_labels->get_label(feat_idx));

REQUIRE(ytrue->data.vlen == T, "T, the length of the feature "
REQUIRE(ytrue->get_data().size() == T, "T, the length of the feature "
"x^i (%d) and the length of its corresponding label y^i "
"(%d) must be the same.\n", T, ytrue->data.vlen);
"(%d) must be the same.\n", T, ytrue->get_data().size());

SGMatrix< float64_t > loss_matrix = m_state_model->loss_matrix(ytrue);

Expand Down Expand Up @@ -368,9 +368,10 @@ bool CHMSVMModel::check_training_setup() const
{
seq = CSequence::obtain_from_generic(hmsvm_labels->get_label(i));

for ( int32_t j = 0 ; j < seq->data.vlen ; ++j )
SGVector<int32_t> seq_data = seq->get_data();
for ( int32_t j = 0 ; j < seq_data.size() ; ++j )
{
state = seq->data[j];
state = seq_data[j];

if ( state < 0 || state >= hmsvm_labels->get_num_states() )
{
Expand Down
5 changes: 3 additions & 2 deletions src/shogun/structure/TwoStateModel.cpp
Expand Up @@ -78,11 +78,12 @@ SGVector< int32_t > CTwoStateModel::labels_to_states(CSequence* label_seq) const
// 2 -> negative state (label == 0)
// 3 -> positive state (label == 1)

SGVector< int32_t > state_seq(label_seq->data.vlen);
SGVector< int32_t > seq_data = label_seq->get_data();
SGVector< int32_t > state_seq(seq_data.size());
for ( int32_t i = 1 ; i < state_seq.vlen-1 ; ++i )
{
//FIXME make independent of values 0-1 in labels
state_seq[i] = label_seq->data[i] + 2;
state_seq[i] = seq_data[i] + 2;
}

// The first element is always start state
Expand Down

0 comments on commit 0e7695a

Please sign in to comment.