Commit 7d7a5e0e authored by Sebastian N.'s avatar Sebastian N.
Browse files

Added exclude parameter to eval metric

parent 4ab0962b
Pipeline #195375 passed with stages
in 9 minutes and 56 seconds
......@@ -24,6 +24,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
......@@ -31,19 +32,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
LossWeightsEntry implements ConfigEntry = name:"loss_weights" ":" value:DoubleVectorValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| bleu:"bleu"
| crossEntropy:"cross_entropy"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| perplexity:"perplexity"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
EvalMetricEntry implements MultiParamConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
interface EvalMetricValue extends MultiParamValue;
AccuracyEvalMetric implements EvalMetricValue = name:"accuracy";
BleuMetric implements EvalMetricValue = name:"bleu" ("{" params:BleuEntry* "}")?;
CrossEntropyEvalMetric implements EvalMetricValue = name:"cross_entropy";
F1EvalMetric implements EvalMetricValue = name:"f1";
MAEEvalMetric implements EvalMetricValue = name:"mae";
MSEEvalMetric implements EvalMetricValue = name:"mse";
PerplexityEvalMetric implements EvalMetricValue = name:"perplexity";
RMSEEvalMetric implements EvalMetricValue = name:"rmse";
TopKAccuracyEvalMetric implements EvalMetricValue = name:"top_k_accuracy";
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTIntegerListValue;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.monticore.numberunit._ast.ASTNumberWithUnit;
import de.se_rwth.commons.logging.Log;
public class CheckIntegerList implements CNNTrainASTIntegerListValueCoCo {
@Override
public void check(ASTIntegerListValue node) {
for (ASTNumberWithUnit element : node.getNumberList()) {
Double unitNumber = element.getNumber().get();
if ((unitNumber % 1)!= 0) {
Log.error("0" + ErrorCodes.NOT_INTEGER_CODE +" Value has to be an integer."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -156,32 +156,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTEvalMetricEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentAccuracy()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getValue().isPresentF1()){
value.setValue(EvalMetric.F1);
}
else if (node.getValue().isPresentMae()){
value.setValue(EvalMetric.MAE);
}
else if (node.getValue().isPresentMse()){
value.setValue(EvalMetric.MSE);
}
else if (node.getValue().isPresentRmse()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getValue().isPresentTopKAccuracy()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
processMultiParamConfigVisit(node, node.getValue().getName());
}
@Override
public void endVisit(ASTEvalMetricEntry node) {
processMultiParamConfigEndVisit(node);
}
@Override
......@@ -335,7 +315,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
}
private List<Integer> getIntegerListFromValue(ASTIntegerListValue value) {
return value.getNumberList().stream()
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get().intValue())
.collect(Collectors.toList());
}
@Override
public void visit(ASTLearningMethodEntry node) {
......@@ -598,6 +584,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
} else if (configValue instanceof ASTIntegerListValue) {
return getIntegerListFromValue((ASTIntegerListValue)configValue);
}
throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass());
}
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
BLEU{
@Override
public String toString() {
return "bleu";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
PERPLEXITY{
@Override
public String toString() {
return "perplexity";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment