Commit b14b8540 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Added Unroll-related work and support for new layers

See merge request !7
parents bbe7e4a4 2ed7b07c
Pipeline #214600 passed with stages
in 4 minutes and 53 seconds
......@@ -9,15 +9,16 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.0.4-SNAPSHOT</version>
<version>0.0.5-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.3-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -79,12 +79,6 @@ public class ArchitectureElementData {
}
public int getConstValue() {
assert getElement() instanceof ConstantSymbol;
return ((ConstantSymbol) getElement()).getExpression().getIntValue().get();
}
public List<Integer> getKernel(){
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
}
......@@ -141,6 +135,22 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
public int getRepeats(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPEATS_NAME).get();
}
public int getAxis(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.AXIS_NAME).get();
}
public List<Integer> getAxes(){
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.AXES_NAME).get();
}
public List<Integer> getShape(){
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.SHAPE_NAME).get();
}
public int getLayers(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
}
......@@ -161,10 +171,6 @@ public class ArchitectureElementData {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.FLATTEN_PARAMETER_NAME).get();
}
public List<Integer> getShape() {
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.SHAPE_NAME).get();
}
@Nullable
public String getPoolType(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
......@@ -16,7 +17,7 @@ public abstract class ArchitectureSupportChecker {
// Overload functions returning always true to enable the features
protected boolean checkMultipleStreams(ArchitectureSymbol architecture) {
if (architecture.getStreams().size() != 1) {
if (architecture.getNetworkInstructions().size() != 1) {
Log.error("This cnn architecture has multiple instructions, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
......@@ -66,7 +67,7 @@ public abstract class ArchitectureSupportChecker {
}
private boolean hasConstant(ArchitectureElementSymbol element) {
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) {
List<ArchitectureElementSymbol> constructedElements = ((CompositeElementSymbol) resolvedElement).getElements();
......@@ -85,8 +86,8 @@ public abstract class ArchitectureSupportChecker {
}
protected boolean checkConstants(ArchitectureSymbol architecture) {
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getElements()) {
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
for (ArchitectureElementSymbol element : networkInstruction.getBody().getElements()) {
if (hasConstant(element)) {
Log.error("This cnn architecture has a constant, which is currently not supported by the code generator."
, architecture.getSourcePosition());
......@@ -109,8 +110,8 @@ public abstract class ArchitectureSupportChecker {
}
protected boolean checkOutputAsInput(ArchitectureSymbol architecture) {
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
for (ArchitectureElementSymbol element : networkInstruction.getBody().getFirstAtomicElements()) {
if (element.isOutput()) {
Log.error("This cnn architecture uses an output as an input, which is currently not supported by the code generator."
, architecture.getSourcePosition());
......@@ -122,6 +123,18 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean checkUnroll(ArchitectureSymbol architecture) {
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
if (networkInstruction.isUnroll()) {
Log.error("This cnn architecture uses unrolls, which are currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
......@@ -129,6 +142,7 @@ public abstract class ArchitectureSupportChecker {
&& checkMultiDimensionalOutput(architecture)
&& checkConstants(architecture)
&& checkLayerVariables(architecture)
&& checkOutputAsInput(architecture);
&& checkOutputAsInput(architecture)
&& checkUnroll(architecture);
}
}
......@@ -138,17 +138,50 @@ public abstract class CNNArchTemplateController {
for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(element));
}
list.removeAll(Collections.singleton(null));
return list;
}
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (VariableSymbol element : getArchitecture().getOutputs()){
list.add(nameManager.getName(element));
if(nameManager.getName(element) != null && !list.contains(nameManager.getName(element))) {
list.add(nameManager.getName(element));
}
}
return list;
}
public List<VariableSymbol> getArchitectureInputSymbols(){
Set<String> names = new HashSet();
List<VariableSymbol> noDuplicates = new ArrayList();
for (VariableSymbol inputs : getArchitecture().getInputs()) {
if (getName(inputs) != null && !names.contains(getName(inputs))) {
names.add(getName(inputs));
noDuplicates.add(inputs);
}
}
return noDuplicates;
}
public List<VariableSymbol> getArchitectureOutputSymbols(){
Set<String> names = new HashSet();
List<VariableSymbol> noDuplicates = new ArrayList();
for (VariableSymbol output : getArchitecture().getOutputs()) {
if (getName(output) != null && !names.contains(getName(output))) {
names.add(getName(output));
noDuplicates.add(output);
}
}
return noDuplicates;
}
public String getComponentName(){
return getArchitecture().getComponentName();
}
......
......@@ -2,6 +2,7 @@
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.ArrayList;
import java.util.HashMap;
......@@ -61,11 +62,8 @@ public class ConfigurationData {
return getConfiguration().getEntry("context").getValue().toString();
}
public String getEvalMetric() {
if (!getConfiguration().getEntryMap().containsKey("eval_metric")) {
return null;
}
return getConfiguration().getEntry("eval_metric").getValue().toString();
public Map<String, Object> getEvalMetric() {
return getMultiParamEntry(EVAL_METRIC, "name");
}
public String getLossName() {
......@@ -130,4 +128,47 @@ public class ConfigurationData {
}
return mapToStrings;
}
public Boolean getSaveAttentionImage() {
if (!getConfiguration().getEntryMap().containsKey("save_attention_image")) {
return null;
}
return (Boolean) getConfiguration().getEntry("save_attention_image").getValue().getValue();
}
public Boolean getUseTeacherForcing() {
if (!getConfiguration().getEntryMap().containsKey("use_teacher_forcing")) {
return null;
}
return (Boolean) getConfiguration().getEntry("use_teacher_forcing").getValue().getValue();
}
protected Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
}
Map<String, Object> resultView = new HashMap<>();
ValueSymbol value = this.getConfiguration().getEntryMap().get(key).getValue();
if (value instanceof MultiParamValueSymbol) {
MultiParamValueSymbol multiParamValue = (MultiParamValueSymbol) value;
resultView.put(valueName, multiParamValue.getValue());
resultView.putAll(multiParamValue.getParameters());
}
else {
resultView.put(valueName, value.getValue());
}
return resultView;
}
protected Boolean configurationContainsKey(final String key) {
return this.getConfiguration().getEntryMap().containsKey(key);
}
protected Object retrieveConfigurationEntryValueByKey(final String key) {
return this.getConfiguration().getEntry(key).getValue().getValue();
}
}
......@@ -11,17 +11,21 @@ import java.util.*;
public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
private Set<String> names = new HashSet<>();
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
stage = name(stream, stage, new ArrayList<>());
}
}
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
stage = name(networkInstruction.getBody(), stage, new ArrayList<>());
if (networkInstruction.isUnroll()) {
UnrollInstructionSymbol unroll = (UnrollInstructionSymbol) networkInstruction;
public ArchitectureElementSymbol getArchitectureElement(String name){
return nameToElement.get(name);
for (SerialCompositeElementSymbol body : unroll.getResolvedBodies()) {
stage = name(body, stage, new ArrayList<>());
}
}
}
}
public String getName(ArchitectureElementSymbol architectureElement){
......@@ -31,17 +35,17 @@ public class LayerNameCreator {
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof SerialCompositeElementSymbol) {
return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices);
} else if (architectureElement instanceof ParallelCompositeElementSymbol){
} else if (architectureElement instanceof ParallelCompositeElementSymbol) {
return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
} else{
if (architectureElement.isAtomic()){
} else {
if (architectureElement.isAtomic()) {
if (architectureElement.getMaxSerialLength().get() > 0){
return add(architectureElement, stage, streamIndices);
} else {
return stage;
}
} else {
ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
}
}
......@@ -75,24 +79,15 @@ public class LayerNameCreator {
if (!elementToName.containsKey(architectureElement)) {
String name = createName(architectureElement, endStage, streamIndices);
while (nameToElement.containsKey(name)) {
endStage++;
name = createName(architectureElement, endStage, streamIndices);
if (!(architectureElement instanceof VariableSymbol)) {
while (names.contains(name)) {
endStage++;
name = createName(architectureElement, endStage, streamIndices);
}
}
elementToName.put(architectureElement, name);
boolean isLayerVariable = false;
if (architectureElement instanceof VariableSymbol) {
isLayerVariable = ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.LAYER;
}
// Do not map names of layer variables to their respective element since the names are not unique
// for now the name to element mapping is not used anywhere so it doesn't matter
if (!isLayerVariable) {
nameToElement.put(name, architectureElement);
}
names.add(name);
}
return endStage;
}
......@@ -101,23 +96,21 @@ public class LayerNameCreator {
if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
String name = createBaseName(architectureElement);
String name = createBaseName(architectureElement) + "_";
if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else {
name = name + "_";
}
} else if (element.getType() == VariableSymbol.Type.LAYER) {
if (element.getType() == VariableSymbol.Type.LAYER) {
if (element.getMember() == VariableSymbol.Member.STATE) {
name = name + "_state_";
name = name + "state_";
} else {
name = name + "_output_";
name = name + "output_";
}
}
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + arrayAccess + "_";
}
return name;
} else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
......@@ -153,4 +146,3 @@ public class LayerNameCreator {
return stringBuilder.toString();
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.se_rwth.commons.logging.Log;
......@@ -19,11 +18,10 @@ public abstract class LayerSupportChecker {
protected List<String> supportedLayerList = new ArrayList<>();
private boolean isSupportedLayer(ArchitectureElementSymbol element){
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
List<ArchitectureElementSymbol> constructLayerElemList;
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements();
List<ArchitectureElementSymbol> constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement)) {
return false;
......@@ -63,8 +61,8 @@ public abstract class LayerSupportChecker {
}
public boolean check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getElements()) {
for (NetworkInstructionSymbol networkInstructions : architecture.getNetworkInstructions()) {
for (ArchitectureElementSymbol element : networkInstructions.getBody().getElements()) {
if (!isSupportedLayer(element)) {
return false;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment