Unverified Commit c2bd588f authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by GitHub

Timmermanns (#5)

* Complete rework (version 0.1.0).
parent 84995f61
# CNNTrain
[![Maintainability](https://api.codeclimate.com/v1/badges/c9ee58c9b0fe15f380f5/maintainability)](https://codeclimate.com/github/EmbeddedMontiArc/CNNTrainLang/maintainability)
[![Build Status](https://travis-ci.org/EmbeddedMontiArc/CNNTrainLang.svg?branch=master)](https://travis-ci.org/EmbeddedMontiArc/CNNTrainLang)
[![Build Status](https://circleci.com/gh/EmbeddedMontiArc/CNNTrainLang/tree/master.svg?style=shield&circle-token=:circle-token)](https://circleci.com/gh/EmbeddedMontiArc/CNNTrainLang/tree/master)
[![Coverage Status](https://coveralls.io/repos/github/EmbeddedMontiArc/CNNTrainLang/badge.svg?branch=master)](https://coveralls.io/github/EmbeddedMontiArc/CNNTrainLang?branch=master)
# CNNTrain
\ No newline at end of file
......@@ -30,19 +30,18 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.0.2-SNAPSHOT</version>
<version>0.1.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<monticore.version>4.5.3-SNAPSHOT</monticore.version>
<monticore.version>4.5.4-SNAPSHOT</monticore.version>
<se-commons.version>1.7.7</se-commons.version>
<mc.grammars.assembly.version>0.0.6-SNAPSHOT</mc.grammars.assembly.version>
<SIUnit.version>0.0.6-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.3-SNAPSHOT</Common-MontiCar.version>
<Math.version>0.0.3-SNAPSHOT-REWORK</Math.version>
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<SIUnit.version>0.0.10-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.10-SNAPSHOT</Common-MontiCar.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -70,6 +69,12 @@
</properties>
<dependencies>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>4.7.1</version>
</dependency>
<dependency>
<groupId>de.se_rwth.commons</groupId>
<artifactId>se-commons-logging</artifactId>
......@@ -138,20 +143,6 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang</groupId>
<artifactId>math</artifactId>
<version>${Math.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang</groupId>
<artifactId>math</artifactId>
<version>${Math.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
......@@ -231,6 +222,23 @@
<groupId>de.monticore.mojo</groupId>
<artifactId>monticore-maven-plugin</artifactId>
<version>${monticore.plugin}</version>
<configuration>
<skip>false</skip>
<script>de/monticore/monticore_noemf.groovy</script>
</configuration>
<dependencies>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4</artifactId>
<version>4.7.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.antlr/antlr4-runtime -->
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>4.7.1</version>
</dependency>
</dependencies>
<executions>
<execution>
<goals>
......
package de.monticore.lang.monticar;
grammar CNNTrain extends de.monticore.lang.math.Math {
grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.NumberUnit{
symbol scope CNNTrainCompilationUnit = "Configuration" name:Name& "{" TrainingConfiguration "}";
CNNTrainCompilationUnit = TrainingConfiguration;
symbol scope TrainingConfiguration = "configuration" name:Name& "{" entries:ConfigEntry* "}";
TrainingConfiguration = (assignments:ParameterAssignment)*;
interface Entry;
ast Entry = method String getName(){}
method ASTConfigValue getValue(){};
interface ConfigValue;
interface ConfigEntry extends Entry;
ParameterAssignment = lhs:TrainingParameter "=" rhs:ParameterRhs;
enum TrainingParameter = DATA:"data"
| LABELS:"labels"
| EPOCHS:"epochs"
| BATCHSIZE:"batch_size"
| OPTIMIZER:"optimizer"
| LEARNINGRATE:"learning_rate";
//eval_metric
//validation_split
ContextEntry implements ConfigEntry = name:"context" ":" value:ContextValue;
ModelPathEntry implements ConfigEntry = name:"model_path" ":" value:PathValue;
TrainDataEntry implements ConfigEntry = name:"train_data" ":" value:DataValue;
TrainLabelEntry implements ConfigEntry = name:"train_label" ":" value:DataValue;
TestDataEntry implements ConfigEntry = name:"test_data" ":" value:DataValue;
TestLabelEntry implements ConfigEntry = name:"test_label" ":" value:DataValue;
LoadingModeEntry implements ConfigEntry = name:"loading_mode" ":" value:LoadingModeValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
ValidationSplitEntry implements ConfigEntry = name:"validation_split" ":" value:NumberValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
CheckpointEntry implements ConfigEntry = name:"checkpoint" ":" value:IntegerValue;
ParameterRhs = stringVal:String
| number:Number
| refOrBool:Name;
PathValue implements ConfigValue = path:StringLiteral;
ContextValue implements ConfigValue = (GPU:"gpu" | CPU:"cpu");
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"ce"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
| inv:"inv"
| poly:"poly"
| sigmoid:"sigmoid");
LoadingModeValue implements ConfigValue =(loadOnly:"load_only"
| loadAndTrain:"load_and_train"
| overwrite:"overwrite"
| noLoad:"no_load");
DataValue implements ConfigValue = (DataVariable | PathValue);
DataVariable = Name&;
IntegerValue implements ConfigValue = Number;
NumberValue implements ConfigValue = Number;
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
interface OptimizerValue extends ConfigValue;
//ast OptimizerValue = method java.util.List<? extends ASTEntry> getParams(){}
// method String getName(){};
interface SGDEntry extends Entry;
SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?;
interface AdamEntry extends Entry;
AdamOptimizer implements OptimizerValue = name:"adam" ("{" params:AdamEntry* "}")?;
interface RmsPropEntry extends Entry;
RmsPropOptimizer implements OptimizerValue = name:"rmsprop" ("{" params:RmsPropEntry* "}")?;
interface AdaGradEntry extends Entry;
AdaGradOptimizer implements OptimizerValue = name:"adagrad" ("{" params:AdaGradEntry* "}")?;
//DCASGD, SGLD, Ftrl, Adamax
NesterovOptimizer implements OptimizerValue = name:"nag" ("{" params:SGDEntry* "}")?;
interface AdaDeltaEntry extends Entry;
AdaDeltaOptimizer implements OptimizerValue = name:"adadelta" ("{" params:AdaDeltaEntry* "}")?;
interface GeneralOptimizerEntry extends SGDEntry,AdamEntry,RmsPropEntry,AdaGradEntry,AdaDeltaEntry;
LearningRateEntry implements GeneralOptimizerEntry = name:"learning_rate" ":" value:NumberValue;
WeightDecayEntry implements GeneralOptimizerEntry = name:"weight_decay" ":" value:NumberValue;
LRDecayEntry implements GeneralOptimizerEntry = name:"learning_rate_decay" ":" value:NumberValue;
LRPolicyEntry implements GeneralOptimizerEntry = name:"learning_rate_policy" ":" value:LRPolicyValue;
RescaleGradEntry implements GeneralOptimizerEntry = name:"rescale_grad" ":" value:NumberValue;
ClipGradEntry implements GeneralOptimizerEntry = name:"clip_gradient" ":" value:NumberValue;
StepSizeEntry implements GeneralOptimizerEntry = name:"step_size" ":" value:IntegerValue;
MomentumEntry implements SGDEntry = name:"momentum" ":" value:NumberValue;
Beta1Entry implements AdamEntry = name:"beta1" ":" value:NumberValue;
Beta2Entry implements AdamEntry = name:"beta2" ":" value:NumberValue;
EpsilonEntry implements AdamEntry,AdaGradEntry,RmsPropEntry,AdaDeltaEntry = name:"epsilon" ":" value:NumberValue;
Gamma1Entry implements RmsPropEntry = name:"gamma1" ":" value:NumberValue;
Gamma2Entry implements RmsPropEntry = name:"gamma2" ":" value:NumberValue;
CenteredEntry implements RmsPropEntry = name:"centered" ":" value:BooleanValue;
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry = name:"rho" ":" value:NumberValue;
}
\ No newline at end of file
......@@ -22,24 +22,9 @@ package de.monticore.lang.monticar.cnntrain._ast;
import java.util.List;
public class ASTTrainingConfiguration extends ASTTrainingConfigurationTOP {
public interface ASTOptimizerValue extends ASTOptimizerValueTOP {
public ASTTrainingConfiguration() {
}
String getName();
public ASTTrainingConfiguration(List<ASTParameterAssignment> assignments) {
super(assignments);
}
public ASTParameterRhs get(String lhsName) {
ASTParameterRhs rhs = null;
lhsName = lhsName.replace("_", "");
for (ASTParameterAssignment assignment : getAssignments()) {
String assignmentLhs = assignment.getLhs().name();
if (assignmentLhs.equalsIgnoreCase(lhsName)) {
rhs = assignment.getRhs();
}
}
return rhs;
}
List<? extends ASTEntry> getParams();
}
......@@ -24,7 +24,9 @@ public class CNNTrainCocos {
public static CNNTrainCoCoChecker createChecker() {
return new CNNTrainCoCoChecker()
.addCoCo(new DuplicatedParameterCheck());
.addCoCo(new CheckEntryRepetition())
.addCoCo(new CheckInteger())
.addCoCo(new CheckValidPath());
}
}
......@@ -20,26 +20,25 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTParameterAssignment;
import de.monticore.lang.monticar.cnntrain._ast.ASTTrainingConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class DuplicatedParameterCheck implements CNNTrainASTTrainingConfigurationCoCo {
public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
private Set<String> entryNameSet = new HashSet<>();
@Override
public void check(ASTTrainingConfiguration node) {
Set<Enum> set = new HashSet<>();
for (ASTParameterAssignment assignment : node.getAssignments()) {
if (set.contains(assignment.getLhs())) {
Log.error("0x03201 Multiple assignments of the same parameter are not allowed",
assignment.get_SourcePositionStart());
}
else {
set.add(assignment.getLhs());
}
public void check(ASTEntry node) {
if (entryNameSet.contains(node.getName())){
Log.error("0xC8853 The parameter '" + node.getName() + "' has multiple values. " +
"Multiple assignments of the same parameter are not allowed",
node.get_SourcePositionStart());
}
else {
entryNameSet.add(node.getName());
}
}
......
......@@ -18,51 +18,29 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._ast;
package de.monticore.lang.monticar.cnntrain._cocos;
import siunit.monticoresiunit.si._ast.ASTNumber;
import de.monticore.lang.monticar.cnntrain._ast.ASTIntegerValue;
import de.monticore.lang.numberunit._ast.ASTUnitNumber;
import de.se_rwth.commons.logging.Log;
import org.jscience.mathematics.number.Rational;
import java.util.Optional;
public class ASTParameterRhs extends ASTParameterRhsTOP {
private boolean containsBoolean;
public ASTParameterRhs() {
}
public ASTParameterRhs(String stringVal, ASTNumber number, String refOrBool) {
super(stringVal, number, refOrBool);
}
public class CheckInteger implements CNNTrainASTIntegerValueCoCo {
@Override
public void setRefOrBool(String refOrBool) {
if (refOrBool.equalsIgnoreCase("true")
|| refOrBool.equalsIgnoreCase("false")){
containsBoolean = true;
super.setRefOrBool(refOrBool.toLowerCase());
}
else {
containsBoolean = false;
super.setRefOrBool(refOrBool);
}
}
public Optional<String> getRef(){
if (containsBoolean) {
return Optional.empty();
}
else {
return getRefOrBool();
}
}
public Optional<String> getBooleanVal(){
if (containsBoolean) {
return getRefOrBool();
public void check(ASTIntegerValue node) {
Optional<ASTUnitNumber> unitNumber = node.getNumber().getUnitNumber();
if (unitNumber.isPresent()){
Rational number = unitNumber.get().getNumber().get();
if (number.getDivisor().intValue() != 1){
Log.error("0xC8851 Value has to be an integer."
, node.get_SourcePositionStart());
}
}
else {
return Optional.empty();
throw new IllegalStateException("integer check");
}
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTPathValue;
import de.se_rwth.commons.logging.Log;
import java.nio.file.Files;
import java.nio.file.InvalidPathException;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CheckValidPath implements CNNTrainASTPathValueCoCo {
@Override
public void check(ASTPathValue node) {
try{
Path path = Paths.get(node.getPath().getValue().replaceAll("\"", ""));
/*if (!Files.exists(path)){
Log.error("0xC8855 File with path '" + node.getPath().getValue() + "' does not exist."
, node.get_SourcePositionStart());
}*/
}
catch (InvalidPathException e){
Log.error("0xC8556 Invalid path. " + e.getMessage());
}
}
}
......@@ -20,11 +20,8 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration;
import de.monticore.symboltable.SymbolTableCreator;
import java.util.Optional;
import de.monticore.symboltable.Symbol;
import de.monticore.symboltable.resolving.CommonResolvingFilter;
public class CNNTrainLanguage extends CNNTrainLanguageTOP {
......@@ -42,6 +39,9 @@ public class CNNTrainLanguage extends CNNTrainLanguageTOP {
@Override
protected void initResolvingFilters() {
super.initResolvingFilters();
addResolvingFilter(new CommonResolvingFilter<Symbol>(EntrySymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(NameValueSymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(ValueSymbol.KIND));
setModelNameCalculator(new CNNTrainModelNameCalculator());
}
......
......@@ -20,8 +20,7 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainCompilationUnit;
import de.monticore.lang.monticar.cnntrain._ast.ASTTrainingConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.symboltable.ArtifactScope;
import de.monticore.symboltable.ImportStatement;
import de.monticore.symboltable.MutableScope;
......@@ -36,6 +35,7 @@ import java.util.Optional;
public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
private String compilationUnitPackage = "";
private TrainingConfigurationSymbol configuration;
public CNNTrainSymbolTableCreator(final ResolvingConfiguration resolvingConfig,
......@@ -51,7 +51,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(final ASTCNNTrainCompilationUnit compilationUnit) {
Log.debug("Building Symboltable for Script: " + compilationUnit.getName(),
Log.debug("Building Symboltable for Script: " + compilationUnit.getTrainingConfiguration().getName(),
CNNTrainSymbolTableCreator.class.getSimpleName());
List<ImportStatement> imports = new ArrayList<>();
......@@ -62,16 +62,165 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
imports);
putOnStack(artifactScope);
}
CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(
compilationUnit.getName()
);
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
@Override
public void visit(final ASTTrainingConfiguration node){
configuration = new TrainingConfigurationSymbol(node.getName());
addToScopeAndLinkWithNode(configuration , node);
}
@Override
public void endVisit(final ASTTrainingConfiguration trainingConfiguration) {
removeCurrentScope();
}
@Override
public void endVisit(ASTEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue((ValueSymbol) node.getValue().getSymbol().get());
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTDataValue node){
ValueSymbol value;
if (node.getDataVariable().isPresent()){
value = (ValueSymbol) node.getDataVariable().get().getSymbol().get();
}
else {
value = (ValueSymbol) node.getPathValue().get().getSymbol().get();
}
node.setSymbol(value);
}
@Override
public void endVisit(ASTPathValue node) {
ValueSymbol value = new ValueSymbol();
value.setValue(node.getPath().getValue().replaceAll("\"", ""));
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTNumberValue node) {
ValueSymbol value = new ValueSymbol();
Double number = node.getNumber().getUnitNumber().get().getNumber().get().doubleValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTIntegerValue node) {
ValueSymbol value = new ValueSymbol();
Integer number = node.getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTBooleanValue node) {
ValueSymbol value = new ValueSymbol();
if (node.getTRUE().isPresent()){
value.setValue(true);
}
else if (node.getFALSE().isPresent()){
value.setValue(false);
}
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTContextValue node) {
ValueSymbol value = new ValueSymbol();
if (node.getCPU().isPresent()){
value.setValue(Context.CPU);
}
else if (node.getGPU().isPresent()){
value.setValue(Context.GPU);
}
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTEvalMetricValue node) {
ValueSymbol value = new ValueSymbol();
if (node.getAccuracy().isPresent()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getCrossEntropy().isPresent()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getF1().isPresent()){
value.setValue(EvalMetric.F1);
}
else if (node.getMae().isPresent()){
value.setValue(EvalMetric.MAE);
}
else if (node.getMse().isPresent()){
value.setValue(EvalMetric.MSE);
}
else if (node.getRmse().isPresent()){
value.setValue(EvalMetric.RMSE);
}