Skip to content

Commit

Permalink
implemented thresholds for regression models
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ola Spjuth committed Jun 20, 2012
1 parent 31baf5d commit 1718cba
Showing 1 changed file with 63 additions and 27 deletions.
Expand Up @@ -56,17 +56,26 @@ public class SignaturesLibSVMPrediction extends AbstractDSTest{
private static final String LOW_PERCENTILE = "lowPercentile";
private static final String TRAIN_PARAMETER = "trainFile";

private static final String REGRESSION_LOWER_THRESHOLD = "lower_threshold";
private static final String REGRESSION_UPPER_THRESHOLD = "upper_threshold";
private static final String LOW_IS_NEGATIVE = "lowIsNegative";

private static final int NR_NEAR_NEIGHBOURS = 3;


protected double lowPercentile;
protected double highPercentile;
protected String trainFilename;

protected Float regrUpperThreshold=null;
protected Float regrLowerThreshold=null;
protected boolean lowIsNegative=true;


//We need to ensure that '.' is always decimal separator in all locales
DecimalFormat formatter=new DecimalFormat("0.000");

private Vector<svm_node[]> nearNeighborData;
// private Vector<svm_node[]> nearNeighborData;


//The model file
Expand Down Expand Up @@ -114,7 +123,7 @@ public List<String> getRequiredParameters() {
ret.add( SIGNATURES_MIN_HEIGHT );
// ret.add( HIGH_PERCENTILE );
// ret.add( LOW_PERCENTILE );
ret.add( TRAIN_PARAMETER );
// ret.add( TRAIN_PARAMETER );
return ret;
}

Expand Down Expand Up @@ -152,6 +161,15 @@ public void initialize(IProgressMonitor monitor) throws DSException {

System.out.println("Low percentile is: " + lowPercentile);
System.out.println("High percentile is: " + highPercentile);

try{
regrLowerThreshold=Float.parseFloat(getParameters().get(REGRESSION_LOWER_THRESHOLD));
regrUpperThreshold=Float.parseFloat(getParameters().get(REGRESSION_UPPER_THRESHOLD));
lowIsNegative=Boolean.parseBoolean(getParameters().get(LOW_IS_NEGATIVE));
}catch (Exception e){
logger.debug(getName() + " is a regression model without thresholds");
}

}

positiveValue=getParameters().get( "positiveValue" );
Expand Down Expand Up @@ -186,22 +204,22 @@ public void initialize(IProgressMonitor monitor) throws DSException {
throw new DSException("Could not read model file '" + modelPath
+ "' due to: " + e.getMessage());
}

String p_trainFilename = getParameters().get( TRAIN_PARAMETER );

try {
trainFilename = FileUtil.getFilePath(p_trainFilename, getPluginID());
} catch (Exception e) {
e.printStackTrace();
throw new DSException("Error reading train file: " + e.getMessage());
}

// Read the train file. It will be used to retrieve near neighbors to a query from the training data.
try {
nearNeighborData = createNearNeighborData(trainFilename);
} catch (IOException e) {
returnError("could not read train file", "could not read train file");
}
//
// String p_trainFilename = getParameters().get( TRAIN_PARAMETER );
//
// try {
// trainFilename = FileUtil.getFilePath(p_trainFilename, getPluginID());
// } catch (Exception e) {
// e.printStackTrace();
// throw new DSException("Error reading train file: " + e.getMessage());
// }
//
// // Read the train file. It will be used to retrieve near neighbors to a query from the training data.
// try {
// nearNeighborData = createNearNeighborData(trainFilename);
// } catch (IOException e) {
// returnError("could not read train file", "could not read train file");
// }


if (svmModel.param.svm_type == 0) // This is a classification model.
Expand Down Expand Up @@ -438,18 +456,23 @@ protected List<? extends ITestResult> doRunTest(ICDKMolecule cdkmol,
int negIX = classLabels.indexOf(negativeValue);
int predIX = classLabels.indexOf(predictedClassLabel);

if (predIX<=posIX)
if (predIX<=posIX){
match.setClassification( ITestResult.POSITIVE );
else if (predIX>=negIX)
match.setName("Result: Positive");
}
else if (predIX>=negIX){
match.setClassification( ITestResult.NEGATIVE );
else
match.setName("Result: Negative");
}
else{
match.setClassification( ITestResult.INCONCLUSIVE );
match.setName("Result: Inconclusive");
}


if (significantSignature.length()>0){
//OK, color atoms

match.setName(significantSignature);

for (int centerAtom : centerAtoms){

Expand Down Expand Up @@ -526,15 +549,28 @@ else if (predIX>=negIX)
}
System.out.println(atomGreadientComponents.toString());


ScaledResultMatch match = new ScaledResultMatch("Result: "
+ formatter.format( prediction ),
ITestResult.POSITIVE);
ITestResult.INFORMATIVE);

//Neg prediction means green - Negative overall results for the model
if (prediction<=0)
match.setClassification(ITestResult.NEGATIVE);
int result=ITestResult.INCONCLUSIVE;
if (regrLowerThreshold!=null && regrUpperThreshold!=null){

if (prediction<regrLowerThreshold){
if (lowIsNegative)
result=ITestResult.NEGATIVE;
else
result=ITestResult.POSITIVE;
}else if (prediction>regrUpperThreshold){
if (lowIsNegative)
result=ITestResult.POSITIVE;
else
result=ITestResult.NEGATIVE;
}
match.setClassification(result);
}


//Color atoms according to accumulated gradient values
for (int currentAtomNr : atomGreadientComponents.keySet()){
Double currentDeriv = atomGreadientComponents.get(currentAtomNr);
Expand Down

0 comments on commit 1718cba

Please sign in to comment.