Commit 226d1816 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'symboltable_refactoring' into 'master'

Symboltable refactoring

See merge request !6
parents 27d93b08 aedd1e5b
Pipeline #66104 passed with stages
in 2 minutes and 11 seconds
......@@ -37,7 +37,7 @@
<properties>
<!-- .. SE-Libraries .................................................. -->
<monticore.version>4.5.4.08.11.2017</monticore.version>
<monticore.version>4.5.5</monticore.version>
<se-commons.version>1.7.7</se-commons.version>
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<SIUnit.version>0.0.10-SNAPSHOT</SIUnit.version>
......
......@@ -22,6 +22,15 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"cross_entropy"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......
......@@ -103,56 +103,99 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void endVisit(ASTNumEpochEntry node) {
NumEpochSymbol symbol = new NumEpochSymbol();
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
Integer value_as_int = node.getValue().getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
symbol.setValue(value_as_int);
addToScopeAndLinkWithNode(symbol, node);
configuration.setNumEpoch(symbol);
value.setValue(value_as_int);
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTBatchSizeEntry node) {
BatchSizeSymbol symbol = new BatchSizeSymbol();
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
Integer value_as_int = node.getValue().getNumber().getUnitNumber().get().getNumber().get().getDividend().intValue();
symbol.setValue(value_as_int);
addToScopeAndLinkWithNode(symbol, node);
configuration.setBatchSize(symbol);
value.setValue(value_as_int);
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLoadCheckpointEntry node) {
LoadCheckpointSymbol symbol = new LoadCheckpointSymbol();
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().getTRUE().isPresent()){
symbol.setValue(true);
value.setValue(true);
}
else if (node.getValue().getFALSE().isPresent()){
symbol.setValue(false);
value.setValue(false);
}
configuration.setLoadCheckpoint(symbol);
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTNormalizeEntry node) {
NormalizeSymbol symbol = new NormalizeSymbol();
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().getTRUE().isPresent()){
symbol.setValue(true);
value.setValue(true);
}
else if (node.getValue().getFALSE().isPresent()){
symbol.setValue(false);
value.setValue(false);
}
configuration.setNormalize(symbol);
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTTrainContextEntry node) {
TrainContextSymbol symbol = new TrainContextSymbol();
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().cpuIsPresent()){
symbol.setValue(node.getValue().getCpu().get());
value.setValue(Context.CPU);
}
else {
symbol.setValue(node.getValue().getGpu().get());
value.setValue(Context.CPU);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTEvalMetricEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().getAccuracy().isPresent()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getValue().getCrossEntropy().isPresent()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getValue().getF1().isPresent()){
value.setValue(EvalMetric.F1);
}
else if (node.getValue().getMae().isPresent()){
value.setValue(EvalMetric.MAE);
}
else if (node.getValue().getMse().isPresent()){
value.setValue(EvalMetric.MSE);
}
else if (node.getValue().getRmse().isPresent()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getValue().getTopKAccuracy().isPresent()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
configuration.setTrainContext(symbol);
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
......
......@@ -22,16 +22,15 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private NumEpochSymbol numEpoch;
private BatchSizeSymbol batchSize;
private LoadCheckpointSymbol loadCheckpoint;
private NormalizeSymbol normalize;
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private OptimizerSymbol optimizer;
private TrainContextSymbol trainContext;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
......@@ -39,14 +38,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
super("", KIND);
}
public NumEpochSymbol getNumEpoch() {
return numEpoch;
}
public void setNumEpoch(NumEpochSymbol numEpoch) {
this.numEpoch = numEpoch;
}
public OptimizerSymbol getOptimizer() {
return optimizer;
}
......@@ -55,36 +46,12 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.optimizer = optimizer;
}
public BatchSizeSymbol getBatchSize() {
return batchSize;
}
public void setBatchSize(BatchSizeSymbol batchSize) {
this.batchSize = batchSize;
}
public LoadCheckpointSymbol getLoadCheckpoint() {
return loadCheckpoint;
}
public void setLoadCheckpoint(LoadCheckpointSymbol loadCheckpoint) {
this.loadCheckpoint = loadCheckpoint;
}
public NormalizeSymbol getNormalize() {
return normalize;
}
public void setNormalize(NormalizeSymbol normalize) {
this.normalize = normalize;
}
public TrainContextSymbol getTrainContext() {
return trainContext;
public Map<String, EntrySymbol> getEntryMap() {
return entryMap;
}
public void setTrainContext(TrainContextSymbol trainContext) {
this.trainContext = trainContext;
public EntrySymbol getEntry(String name){
return getEntryMap().get(name);
}
}
......@@ -20,19 +20,17 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class BatchSizeSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.BatchSizeSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
public enum Context {
CPU{
@Override
public String toString() {
return "cpu";
}
},
GPU{
@Override
public String toString() {
return "gpu";
}
}
}
\ No newline at end of file
}
......@@ -22,9 +22,9 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class TrainContextKind implements SymbolKind {
public class EntryKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.TrainContextKind";
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.EntryKind";
@Override
public String getName() {
......@@ -35,4 +35,4 @@ public class TrainContextKind implements SymbolKind {
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
}
......@@ -21,30 +21,21 @@
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class EntrySymbol extends CommonSymbol {
public class NumEpochSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private ValueSymbol value;
public static final NumEpochSymbolKind KIND = new NumEpochSymbolKind();
private int value;
public NumEpochSymbol() {
super("", KIND);
public EntrySymbol(String name) {
super(name, KIND);
}
public NumEpochSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public int getValue() {
public ValueSymbol getValue() {
return value;
}
public void setValue(int value) {
public void setValue(ValueSymbol value) {
this.value = value;
}
}
......@@ -20,31 +20,47 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class BatchSizeSymbol extends CommonSymbol {
public static final BatchSizeSymbolKind KIND = new BatchSizeSymbolKind();
private int value;
public BatchSizeSymbol() {
super("", KIND);
}
public BatchSizeSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public int getValue() {
return value;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
public void setValue(int value) {
this.value = value;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class LoadCheckpointSymbol extends CommonSymbol {
public static final LoadCheckpointSymbolKind KIND = new LoadCheckpointSymbolKind();
private boolean value = false;
public LoadCheckpointSymbol() {
super("", KIND);
}
public LoadCheckpointSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public boolean getValue() {
return value;
}
public void setValue(boolean value) {
this.value = value;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LoadCheckpointSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.LoadCheckpointSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class NormalizeSymbol extends CommonSymbol {
public static final NormalizeSymbolKind KIND = new NormalizeSymbolKind();
private boolean value = false;
public NormalizeSymbol() {
super("", KIND);
}
public NormalizeSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public boolean getValue() {
return value;
}
public void setValue(boolean value) {
this.value = value;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class NormalizeSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.NormalizeSymbolKind";
@Override
public String getName() {
return NAME;
}