Commit 113c2b39 authored by Svetlana Pavlitskaya's avatar Svetlana Pavlitskaya Committed by Thomas Michael Timmermanns
Browse files

Changes in grammar. Added Generator and template

parent 1e8e6590
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.1.1-SNAPSHOT</version>
<version>0.2.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -2,10 +2,11 @@ package de.monticore.lang.monticar;
grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.NumberUnit{
CNNTrainCompilationUnit = Configuration;
symbol scope Configuration = "configuration"
symbol scope CNNTrainCompilationUnit = "configuration"
name:Name&
"{"entries:ConfigEntry* "}";
Configuration;
Configuration = "{"entries:ConfigEntry* "}";
interface Entry;
ast Entry = method String getName(){}
......@@ -15,18 +16,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
interface VariableReference;
ast VariableReference = method String getName(){};
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"ce"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......
......@@ -20,6 +20,10 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.se_rwth.commons.logging.Log;
public class CNNTrainCocos {
public static CNNTrainCoCoChecker createChecker() {
......@@ -28,4 +32,10 @@ public class CNNTrainCocos {
.addCoCo(new CheckInteger());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
ASTCNNTrainNode node = (ASTCNNTrainNode) compilationUnit.getAstNode().get();
int findings = Log.getFindings().size();
createChecker().checkAll(node);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class BatchSizeSymbol extends CommonSymbol {
public static final BatchSizeSymbolKind KIND = new BatchSizeSymbolKind();
private int value;
public BatchSizeSymbol() {
super("", KIND);
}
public BatchSizeSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public int getValue() {
return value;
}
public void setValue(int value) {
this.value = value;
}
}
......@@ -22,9 +22,9 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class ConfigParameterKind implements SymbolKind {
public class BatchSizeSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.ConfigParameterKind";
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.BatchSizeSymbolKind";
@Override
public String getName() {
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public class CNNTrainCompilationUnitSymbol extends CNNTrainCompilationUnitSymbolTOP{
private ConfigurationSymbol configuration;
public CNNTrainCompilationUnitSymbol(String name) {
super(name);
}
public ConfigurationSymbol getConfiguration() {
return configuration;
}
public void setConfiguration(ConfigurationSymbol configuration) {
this.configuration = configuration;
}
public void resolve(){
;
// getConfiguration().resolve();
}
}
......@@ -39,9 +39,10 @@ public class CNNTrainLanguage extends CNNTrainLanguageTOP {
@Override
protected void initResolvingFilters() {
super.initResolvingFilters();
addResolvingFilter(new CommonResolvingFilter<Symbol>(EntrySymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(NameValueSymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(ValueSymbol.KIND));
addResolvingFilter(new CNNTrainCompilationUnitResolvingFilter());
addResolvingFilter(new CommonResolvingFilter<Symbol>(ConfigurationSymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(OptimizerSymbol.KIND));
addResolvingFilter(new CommonResolvingFilter<Symbol>(OptimizerParamSymbol.KIND));
setModelNameCalculator(new CNNTrainModelNameCalculator());
}
......
......@@ -27,10 +27,7 @@ import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.*;
public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
......@@ -39,19 +36,19 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
public CNNTrainSymbolTableCreator(final ResolvingConfiguration resolvingConfig,
final MutableScope enclosingScope) {
final MutableScope enclosingScope) {
super(resolvingConfig, enclosingScope);
}
public CNNTrainSymbolTableCreator(final ResolvingConfiguration resolvingConfig,
final Deque<MutableScope> scopeStack) {
final Deque<MutableScope> scopeStack) {
super(resolvingConfig, scopeStack);
}
@Override
public void visit(final ASTCNNTrainCompilationUnit compilationUnit) {
Log.debug("Building Symboltable for Script: " + compilationUnit.getConfiguration().getName(),
Log.debug("Building Symboltable for Script: " + compilationUnit.getName(),
CNNTrainSymbolTableCreator.class.getSimpleName());
List<ImportStatement> imports = new ArrayList<>();
......@@ -62,11 +59,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
imports);
putOnStack(artifactScope);
CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(compilationUnit.getName());
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
}
@Override
public void endVisit(ASTCNNTrainCompilationUnit ast) {
CNNTrainCompilationUnitSymbol compilationUnitSymbol = (CNNTrainCompilationUnitSymbol) ast.getSymbol().get();
compilationUnitSymbol.setConfiguration((ConfigurationSymbol) ast.getConfiguration().getSymbol().get());
setEnclosingScopeOfNodes(ast);
}
@Override
public void visit(final ASTConfiguration node){
configuration = new ConfigurationSymbol(node.getName());
configuration = new ConfigurationSymbol();
addToScopeAndLinkWithNode(configuration , node);
}
......@@ -76,72 +84,57 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
@Override
public void endVisit(ASTEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue((ValueSymbol) node.getValue().getSymbol().get());
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
public void visit(ASTOptimizerEntry node) {
OptimizerSymbol optimizer = new OptimizerSymbol(node.getValue().getName());
configuration.setOptimizer(optimizer);
addToScopeAndLinkWithNode(optimizer, node);
}
@Override
public void endVisit(ASTNumberValue node) {
ValueSymbol value = new ValueSymbol();
Double number = node.getNumber().getUnitNumber().get().getNumber().get().doubleValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
public void endVisit(ASTOptimizerEntry node) {
for (ASTEntry nodeParam : node.getValue().getParams()) {
OptimizerParamSymbol param = new OptimizerParamSymbol();
OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol) nodeParam.getValue().getSymbol().get();
param.setValue(valueSymbol);
configuration.getOptimizer().getOptimizerParamMap().put(nodeParam.getName(), param);;
}
}
@Override
public void endVisit(ASTIntegerValue node) {
ValueSymbol value = new ValueSymbol();
Integer number = node.getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
public void endVisit(ASTNumEpochEntry node) {
NumEpochSymbol symbol = new NumEpochSymbol();
Integer value_as_int = node.getValue().getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
symbol.setValue(value_as_int);
addToScopeAndLinkWithNode(symbol, node);
configuration.setNumEpoch(symbol);
}
@Override
public void endVisit(ASTBooleanValue node) {
ValueSymbol value = new ValueSymbol();
if (node.getTRUE().isPresent()){
value.setValue(true);
}
else if (node.getFALSE().isPresent()){
value.setValue(false);
}
addToScopeAndLinkWithNode(value, node);
public void endVisit(ASTBatchSizeEntry node) {
BatchSizeSymbol symbol = new BatchSizeSymbol();
Integer value_as_int = node.getValue().getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
symbol.setValue(value_as_int);
addToScopeAndLinkWithNode(symbol, node);
configuration.setBatchSize(symbol);
}
@Override
public void endVisit(ASTEvalMetricValue node) {
ValueSymbol value = new ValueSymbol();
if (node.getAccuracy().isPresent()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getCrossEntropy().isPresent()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getF1().isPresent()){
value.setValue(EvalMetric.F1);
public void endVisit(ASTLoadCheckpointEntry node) {
LoadCheckpointSymbol symbol = new LoadCheckpointSymbol();
if (node.getValue().getTRUE().isPresent()){
symbol.setValue(true);
}
else if (node.getMae().isPresent()){
value.setValue(EvalMetric.MAE);
else if (node.getValue().getFALSE().isPresent()){
symbol.setValue(false);
}
else if (node.getMse().isPresent()){
value.setValue(EvalMetric.MSE);
}
else if (node.getRmse().isPresent()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getTopKAccuracy().isPresent()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
addToScopeAndLinkWithNode(value, node);
configuration.setLoadCheckpoint(symbol);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
ValueSymbol value = new ValueSymbol();
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
if (node.getFixed().isPresent()){
value.setValue(LRPolicy.FIXED);
}
......@@ -163,15 +156,32 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
addToScopeAndLinkWithNode(value, node);
}
public void endVisit(ASTOptimizerValue node){
ValueSymbol value = new ValueSymbol();
value.setValue(node.getName());
@Override
public void endVisit(ASTNumberValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
Double number = node.getNumber().getUnitNumber().get().getNumber().get().doubleValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTDataVariable node){
NameValueSymbol variableSymbol = new NameValueSymbol(node.getName());
addToScopeAndLinkWithNode(variableSymbol, node);
public void endVisit(ASTIntegerValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
Integer number = node.getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
value.setValue(number);
addToScopeAndLinkWithNode(value, node);
}
@Override
public void endVisit(ASTBooleanValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
if (node.getTRUE().isPresent()){
value.setValue(true);
}
else if (node.getFALSE().isPresent()){
value.setValue(false);
}
addToScopeAndLinkWithNode(value, node);
}
}
......@@ -20,31 +20,53 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.jscience.mathematics.number.Rational;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.*;
import java.util.Optional;
public class ConfigurationSymbol extends ConfigurationSymbolTOP {
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private List<ConfigParameterSymbol> parameters = new ArrayList<>();
private NumEpochSymbol numEpoch;
private BatchSizeSymbol batchSize;
private LoadCheckpointSymbol loadCheckpoint;
private OptimizerSymbol optimizer;
public ConfigurationSymbol(String name) {
super(name);
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
public ConfigurationSymbol() {
super("", KIND);
}
public NumEpochSymbol getNumEpoch() {
return numEpoch;
}
public void setNumEpoch(NumEpochSymbol numEpoch) {
this.numEpoch = numEpoch;
}
public OptimizerSymbol getOptimizer() {
return optimizer;
}
public void setOptimizer(OptimizerSymbol optimizer) {
this.optimizer = optimizer;
}
public BatchSizeSymbol getBatchSize() {
return batchSize;
}
public Map<String, EntrySymbol> getEntryMap() {
return entryMap;
public void setBatchSize(BatchSizeSymbol batchSize) {
this.batchSize = batchSize;
}
public List<ConfigParameterSymbol> getParameters() {
return parameters;
public LoadCheckpointSymbol getLoadCheckpoint() {
return loadCheckpoint;
}
public void setParameters(List<ConfigParameterSymbol> parameters) {
this.parameters = parameters;
public void setLoadCheckpoint(LoadCheckpointSymbol loadCheckpoint) {
this.loadCheckpoint = loadCheckpoint;
}
}
......@@ -22,9 +22,9 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class NameValueKind extends ValueKind {
public class ConfigurationSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.NameValueKind";
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbolKind";
@Override
public String getName() {
......@@ -33,6 +33,7 @@ public class NameValueKind extends ValueKind {
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || super.isKindOf(kind);
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
}