Commit a8f6e66a authored by Christian Fuß's avatar Christian Fuß
Browse files

set version to 0.3.8-SNAPSHOT

parents 19552604 c2be8008
Pipeline #202179 passed with stages
in 16 minutes and 40 seconds
......@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.7-SNAPSHOT</version>
<version>0.3.8-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -24,6 +24,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
......@@ -31,17 +32,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
LossWeightsEntry implements ConfigEntry = name:"loss_weights" ":" value:DoubleVectorValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"cross_entropy"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
EvalMetricEntry implements MultiParamConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
interface EvalMetricValue extends MultiParamValue;
AccuracyEvalMetric implements EvalMetricValue = name:"accuracy";
BleuMetric implements EvalMetricValue = name:"bleu" ("{" params:BleuEntry* "}")?;
CrossEntropyEvalMetric implements EvalMetricValue = name:"cross_entropy";
F1EvalMetric implements EvalMetricValue = name:"f1";
MAEEvalMetric implements EvalMetricValue = name:"mae";
MSEEvalMetric implements EvalMetricValue = name:"mse";
PerplexityEvalMetric implements EvalMetricValue = name:"perplexity";
RMSEEvalMetric implements EvalMetricValue = name:"rmse";
TopKAccuracyEvalMetric implements EvalMetricValue = name:"top_k_accuracy";
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......@@ -131,6 +139,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
// Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
// Reinforcement Extensions
interface MultiParamValue extends ConfigValue;
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTIntegerListValue;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.monticore.numberunit._ast.ASTNumberWithUnit;
import de.se_rwth.commons.logging.Log;
public class CheckIntegerList implements CNNTrainASTIntegerListValueCoCo {
@Override
public void check(ASTIntegerListValue node) {
for (ASTNumberWithUnit element : node.getNumberList()) {
Double unitNumber = element.getNumber().get();
if ((unitNumber % 1)!= 0) {
Log.error("0" + ErrorCodes.NOT_INTEGER_CODE +" Value has to be an integer."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -38,6 +38,7 @@ class ParameterAlgorithmMapping {
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
ASTLossEntry.class,
......@@ -46,7 +47,8 @@ class ParameterAlgorithmMapping {
ASTFromLogitsEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class
ASTRhoEntry.class,
ASTSaveAttentionImage.class
);
private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList(
......
......@@ -156,32 +156,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTEvalMetricEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentAccuracy()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getValue().isPresentF1()){
value.setValue(EvalMetric.F1);
}
else if (node.getValue().isPresentMae()){
value.setValue(EvalMetric.MAE);
}
else if (node.getValue().isPresentMse()){
value.setValue(EvalMetric.MSE);
}
else if (node.getValue().isPresentRmse()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getValue().isPresentTopKAccuracy()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
processMultiParamConfigVisit(node, node.getValue().getName());
}
@Override
public void endVisit(ASTEvalMetricEntry node) {
processMultiParamConfigEndVisit(node);
}
@Override
......@@ -335,7 +315,21 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
}
private List<Integer> getIntegerListFromValue(ASTIntegerListValue value) {
return value.getNumberList().stream()
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get().intValue())
.collect(Collectors.toList());
}
@Override
public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLearningMethodEntry node) {
......@@ -598,6 +592,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
} else if (configValue instanceof ASTIntegerListValue) {
return getIntegerListFromValue((ASTIntegerListValue)configValue);
}
throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass());
}
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
}
......@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain.helper;
public class ConfigEntryNameConstants {
public static final String LEARNING_METHOD = "learning_method";
public static final String EVAL_METRIC = "eval_metric";
public static final String NUM_EPISODES = "num_episodes";
public static final String DISCOUNT_FACTOR = "discount_factor";
public static final String NUM_MAX_STEPS = "num_max_steps";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment