Skip to content

Commit

Permalink
Implemented significant signatures liblinear regression
Browse files Browse the repository at this point in the history
  • Loading branch information
olas committed Jun 16, 2015
1 parent 0bf9282 commit a9c79c8
Showing 1 changed file with 86 additions and 1 deletion.
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;

import libsvm.svm;
import libsvm.svm_node;
import net.bioclipse.cdk.domain.ICDKMolecule;
import net.bioclipse.core.business.BioclipseException;
Expand Down Expand Up @@ -330,10 +331,94 @@ else if (prediction<regrLowerThreshold && !lowIsNegative)
match.setName("Result: " + formatter.format(prediction));
// return results;
}

Map<Integer, Double> atomGradientComponents = getAtomGradientComponents(model, instance, predModel);

match = new ScaledResultMatch("Result: "
+ formatter.format( prediction ),
ITestResult.INFORMATIVE);
results.clear();
results.add(match);

int result=ITestResult.INCONCLUSIVE;
if (regrLowerThreshold!=null && regrUpperThreshold!=null){

if (prediction<regrLowerThreshold){
if (lowIsNegative)
result=ITestResult.NEGATIVE;
else
result=ITestResult.POSITIVE;
}else if (prediction>regrUpperThreshold){
if (lowIsNegative)
result=ITestResult.POSITIVE;
else
result=ITestResult.NEGATIVE;
}
match.setClassification(result);
}


System.out.println("Scaled results:");
//Color atoms according to accumulated gradient values
for (int currentAtomNr : atomGradientComponents.keySet()){
Double currentDeriv = atomGradientComponents.get(currentAtomNr);

double scaledDeriv = scaleDerivative(currentDeriv);
match.putAtomResult( currentAtomNr, scaledDeriv );
System.out.println("Atom: " + currentAtomNr + " has deriv=" + currentDeriv +" scaled=" + scaledDeriv );

}

//No extracted signatures yet, return
return results;
}

/**
 * Computes, for each atom of the predicted molecule, the sum of the finite-difference gradient
 * components of the liblinear decision function(s) over all signature features that involve
 * that atom.
 *
 * <p>For every feature in {@code instance} the feature value is temporarily increased by 1.0.
 * The decision values are then re-evaluated with {@link Linear#predictValues}. The change,
 * summed over all decision values, is taken as that feature's gradient component. Each
 * component is then added to every atom that the corresponding model signature maps to in
 * {@code predModel}.
 *
 * <p>{@code instance} is modified during the computation but every feature value is restored
 * exactly, even if prediction throws. The resulting map is also printed to stdout.
 *
 * @param model     the trained liblinear model used to evaluate decision values
 * @param instance  the feature vector of the molecule; values are restored before returning
 * @param predModel supplies the mapping from signatures to atom numbers of the molecule
 * @return map from atom number to the accumulated gradient component for that atom; atoms not
 *         covered by any feature signature are absent
 */
public Map<Integer, Double> getAtomGradientComponents(Model model, Feature[] instance, PredictionModel predModel){
    // liblinear fills at most nr_class decision values: one per one-vs-rest classifier, and a
    // single value for binary classification and regression. Sizing by nr_class is therefore
    // always sufficient. Slots that are never written stay 0.0 in both arrays and contribute
    // nothing to the sum. (The previous C(k,2) count is libsvm's one-vs-one convention, and
    // its int factorials overflow for k >= 13.)
    int nrDecisionValues = Math.max(1, model.getNrClass());

    double[] baseValues = new double[nrDecisionValues];
    Linear.predictValues(model, instance, baseValues);

    // Phase 1: forward-difference gradient component (step 1.0) for every feature.
    double[] shiftedValues = new double[nrDecisionValues];
    List<Double> gradientComponents = new ArrayList<Double>(instance.length);
    for (int element = 0; element < instance.length; element++){
        Feature feature = instance[element];
        double originalValue = feature.getValue();
        feature.setValue(originalValue + 1.00);
        try {
            Linear.predictValues(model, instance, shiftedValues);
        } finally {
            // Restore the exact original value. "(v + 1.0) - 1.0" is not always v in
            // floating point, and the instance must not stay modified if prediction throws.
            feature.setValue(originalValue);
        }

        double gradComponentValue = 0.0;
        for (int curDecisionFunc = 0; curDecisionFunc < nrDecisionValues; curDecisionFunc++) {
            gradComponentValue += shiftedValues[curDecisionFunc] - baseValues[curDecisionFunc];
        }
        gradientComponents.add(gradComponentValue);
    }

    // Phase 2: spread each feature's component onto the atoms of its signature.
    Map<Integer, Double> atomGradientComponents = new HashMap<Integer, Double>();
    for (int element = 0; element < instance.length; element++){
        double componentVal = gradientComponents.get(element);
        // Feature indices appear to be 1-based, hence the -1 into the 0-based signature list.
        List<Integer> atomNrList = predModel.getMoleculeSignaturesAtomNr()
                .get(signLibLinearModel.getModelSignatures().get(instance[element].getIndex() - 1));
        if (atomNrList == null) {
            // NOTE(review): the signature has no atom mapping for this molecule; presumably
            // unexpected, but skipping avoids an NPE and drops nothing that could be attributed.
            continue;
        }
        for (int atomNr : atomNrList) {
            Double accumulated = atomGradientComponents.get(atomNr);
            atomGradientComponents.put(atomNr,
                    accumulated == null ? componentVal : accumulated + componentVal);
        }
    }
    System.out.println(atomGradientComponents.toString());

    return atomGradientComponents;
}


// //End here if we should not extract significant signatures
// if (!extractSignificantSignatures)
Expand Down Expand Up @@ -566,7 +651,7 @@ else if (prediction<regrLowerThreshold && !lowIsNegative)
// }
//
// return results;
}
// }

/**
* A prediction model takes a molecule as input, sets up all signatures needed,
Expand Down

0 comments on commit a9c79c8

Please sign in to comment.