Commit 3953ada5 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/CNNArch2Gluon into develop
parents 97795107 d6913a43
Pipeline #205139 failed with stages
in 18 seconds
......@@ -9,15 +9,16 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.8-SNAPSHOT</version>
<version>0.2.9-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.3-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.8-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.4-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import java.util.HashSet;
import java.util.Set;
public class AllAttentionModels {
public static Set<String> getAttentionModels() {
//List of all models that use attention and should save images of the attention over time
Set models = new HashSet();
models.add("showAttendTell.Show_attend_tell");
return models;
}
}
\ No newline at end of file
......@@ -9,6 +9,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.se_rwth.commons.logging.Log;
import java.util.*;
......
......@@ -43,4 +43,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
return true;
}
@Override
protected boolean checkUnroll(ArchitectureSymbol architecture) {
return true;
}
}
......@@ -30,6 +30,16 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.LSTM_NAME);
supportedLayerList.add(AllPredefinedLayers.GRU_NAME);
supportedLayerList.add(AllPredefinedLayers.EMBEDDING_NAME);
supportedLayerList.add(AllPredefinedLayers.ARG_MAX_NAME);
supportedLayerList.add(AllPredefinedLayers.REPEAT_NAME);
supportedLayerList.add(AllPredefinedLayers.DOT_NAME);
supportedLayerList.add(AllPredefinedLayers.EXPAND_DIMS_NAME);
supportedLayerList.add(AllPredefinedLayers.SQUEEZE_NAME);
supportedLayerList.add(AllPredefinedLayers.SWAPAXES_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_MULTIPLY_NAME);
supportedLayerList.add(AllPredefinedLayers.REDUCE_SUM_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
}
}
......@@ -6,9 +6,13 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import java.io.Writer;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public static final String NET_DEFINITION_MODE_KEY = "mode";
......@@ -18,6 +22,8 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
super(architecture, templateConfiguration);
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
......@@ -41,25 +47,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else if (element.getType() == VariableSymbol.Type.LAYER) {
include(TEMPLATE_ELEMENTS_DIR_PATH, element.getLayerVariableDeclaration().getLayer().getName(), writer, netDefinitionMode);
if (element.getMember() == VariableSymbol.Member.STATE) {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer, netDefinitionMode);
} else if (element.getMember() == VariableSymbol.Member.NONE) {
include(TEMPLATE_ELEMENTS_DIR_PATH, element.getLayerVariableDeclaration().getLayer().getName(), writer, netDefinitionMode);
}
}
}
else {
include(element.getResolvedThis().get(), writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ConstantSymbol constant, Writer writer, NetDefinitionMode netDefinitionMode) {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(constant);
if (constant.isAtomic()) {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), writer, netDefinitionMode);
include((ArchitectureElementSymbol) element.getResolvedThis().get(), writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -74,7 +70,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), writer, netDefinitionMode);
include((ArchitectureElementSymbol) layer.getResolvedThis().get(), writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -99,7 +95,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include((LayerSymbol) architectureElement, writer, netDefinitionMode);
}
else if (architectureElement instanceof ConstantSymbol) {
include((ConstantSymbol) architectureElement, writer, netDefinitionMode);
}
else {
include((VariableSymbol) architectureElement, writer, netDefinitionMode);
......@@ -117,40 +113,152 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).keySet();
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
return getStreamInputs(stream, outputAsArray).keySet();
}
public List<String> getUnrollInputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
for (int i = 0; i != inputNames.size(); ++i) {
if (pairs.containsKey(inputNames.get(i))) {
inputNames.set(i, pairs.get(inputNames.get(i)));
}
}
return inputNames;
}
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).values();
return getStreamInputs(stream, false).values();
}
public String getOutputName() {
return getNameWithoutIndex(getName(getArchitectureOutputSymbols().get(0)));
}
public Set<String> getStreamOutputNames(SerialCompositeElementSymbol stream) {
public String getNameAsArray(String name) {
return name.replaceAll("([0-9]+)_$", "[$1]");
}
public String getNameWithoutIndex(String name) {
return name.replaceAll("([0-9]+)_$", "").replaceAll("\\[[^\\]]+\\]$", "");
}
public String getIndex(String name, boolean defaultToZero) {
Pattern pattern = Pattern.compile("\\[([^\\]]+)\\]$");
Matcher matcher = pattern.matcher(name);
if (matcher.find()) {
return matcher.group(1);
}
return defaultToZero ? "0" : "";
}
public Set<String> getStreamOutputNames(SerialCompositeElementSymbol stream, boolean asArray) {
Set<String> outputNames = new LinkedHashSet<>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
String name = getName(element);
if (asArray && element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.IO) {
name = getNameAsArray(name);
}
}
outputNames.add(name);
}
}
outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true).keySet());
outputNames.addAll(getStreamLayerVariableMembers(stream, true).keySet());
return outputNames;
}
public List<String> getUnrollOutputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
for (int i = 0; i != outputNames.size(); ++i) {
if (pairs.containsKey(outputNames.get(i))) {
outputNames.set(i, pairs.get(outputNames.get(i)));
}
}
return outputNames;
}
public boolean endsWithArgmax(SerialCompositeElementSymbol stream) {
List<ArchitectureElementSymbol> elements = stream.getElements();
if (elements.size() > 1) {
// Check second last element because last element is output
ArchitectureElementSymbol secondLastElement = elements.get(elements.size() - 2);
return secondLastElement.getName().equals(AllPredefinedLayers.ARG_MAX_NAME);
}
return false;
}
// Used to initialize all layer variable members which are passed through the networks
public Map<String, List<String>> getLayerVariableMembers(String batchSize) {
public Map<String, List<String>> getLayerVariableMembers() {
Map<String, List<String>> members = new LinkedHashMap<>();
for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) {
members.putAll(getStreamLayerVariableMembers(stream, batchSize, true));
for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) {
members.putAll(getStreamLayerVariableMembers(networkInstruction.getBody(), true));
}
return members;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) {
// Calculate differently named VariableSymbol elements in two streams, currently used for the UnrollInstructionSymbol
// body which is resolved with t = CONST_OFFSET and the current body of the actual timestep t
public Map<String, String> getUnrollPairs(ArchitectureElementSymbol element, ArchitectureElementSymbol current, String variable) {
Map<String, String> pairs = new HashMap<>();
if (element instanceof CompositeElementSymbol && current instanceof CompositeElementSymbol) {
List<ArchitectureElementSymbol> elements = ((CompositeElementSymbol) element).getElements();
List<ArchitectureElementSymbol> currentElements = ((CompositeElementSymbol) current).getElements();
if (elements.size() == currentElements.size()) {
for (int i = 0; i != currentElements.size(); ++i) {
String name = getName(elements.get(i));
String currentName = getName(currentElements.get(i));
if (elements.get(i).isOutput()) {
name = getNameAsArray(name);
}
if (currentElements.get(i).isOutput()) {
currentName = getNameAsArray(currentName);
}
if (elements.get(i) instanceof VariableSymbol && currentElements.get(i) instanceof VariableSymbol) {
if (name != null && currentName != null && !name.equals(currentName)) {
String newIndex = variable + "-1+" + getIndex(currentName, true);
currentName = currentName.replaceAll("\\[([0-9]+)\\]$", "[" + newIndex + "]");
pairs.put(name, currentName);
}
}
pairs.putAll(getUnrollPairs(elements.get(i), currentElements.get(i), variable));
}
}
}
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
......@@ -162,36 +270,52 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
dimensions.add(intDimension.toString());
}
// Add batch size dimension
dimensions.add(0, "1");
String name = getName(element);
inputs.put(getName(element), dimensions);
if (outputAsArray && element.isOutput() && element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.IO) {
name = getNameAsArray(name);
}
}
inputs.put(name, dimensions);
}
else if (element instanceof ConstantSymbol) {
inputs.put(getName(element), Arrays.asList("1"));
}
}
inputs.putAll(getStreamLayerVariableMembers(stream, "1", false));
inputs.putAll(getStreamLayerVariableMembers(stream, false));
return inputs;
}
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput) {
Map<String, List<String>> members = new HashMap<>();
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, boolean includeOutput) {
Map<String, List<String>> members = new LinkedHashMap<>();
List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.LAYER && variable.getMember() == VariableSymbol.Member.NONE) {
if (variable.getType() == VariableSymbol.Type.LAYER && (variable.getMember() == VariableSymbol.Member.NONE)) {
LayerVariableDeclarationSymbol layerVariableDeclaration = variable.getLayerVariableDeclaration();
if (layerVariableDeclaration.getLayer().getDeclaration().isPredefined()) {
PredefinedLayerDeclaration predefinedLayerDeclaration =
(PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration();
if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.STATE)) {
int arrayLength = predefinedLayerDeclaration.getArrayLength(VariableSymbol.Member.STATE);
for (int i = 0; i < arrayLength; ++i) {
String name = variable.getName() + "_state_";
if (arrayLength > 1) {
name += i + "_";
}
List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
layerVariableDeclaration.getLayer().getInputTypes(),
layerVariableDeclaration.getLayer(),
......@@ -204,17 +328,19 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
dimensions.add(intDimension.toString());
}
// Add batch size dimension at index 1, since RNN states in Gluon have the format
// (layers, batch_size, units)
dimensions.add(1, batchSize);
members.put(name, dimensions);
}
if (includeOutput) {
if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.OUTPUT)) {
arrayLength = predefinedLayerDeclaration.getArrayLength(VariableSymbol.Member.OUTPUT);
for (int i = 0; i < arrayLength; ++i) {
String name = variable.getName() + "_output_";
if (arrayLength > 1) {
name += i + "_";
}
List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
layerVariableDeclaration.getLayer().getInputTypes(),
layerVariableDeclaration.getLayer(),
......@@ -227,9 +353,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
dimensions.add(intDimension.toString());
}
// Add batch size dimension at index 0, since we use NTC format for RNN output in Gluon
dimensions.add(0, batchSize);
members.put(name, dimensions);
}
}
......@@ -237,8 +360,39 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
}
return members;
}
// cuts
public List<String> cutDimensions(List<String> dimensions) {
while (dimensions.size() > 1 && dimensions.get(dimensions.size() - 1).equals("1")) {
dimensions.remove(dimensions.size() - 1);
}
return dimensions;
}
public boolean hasUnrollInstructions() {
for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) {
if (networkInstruction.isUnroll()) {
return true;
}
}
return false;
}
public boolean isAttentionNetwork(){
return AllAttentionModels.getAttentionModels().contains(getComponentName());
}
public int getBeamSearchMaxLength(UnrollInstructionSymbol unroll){
return unroll.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get();
}
public int getBeamSearchWidth(UnrollInstructionSymbol unroll){
// Beam search with width 1 is greedy search
return unroll.getIntValue(AllPredefinedLayers.WIDTH_NAME).orElse(1);
}
}
......@@ -192,7 +192,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
if (configuration.getTrainedArchitecture().isPresent()) {
generateRewardFunction(configuration.getTrainedArchitecture().get(),
configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
} else {
Log.error("No architecture model for the trained neural network but is required for " +
"reinforcement learning configuration.");
}
}
ftlContext.put("trainerName", trainerName);
......@@ -208,7 +215,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return fileContentMap;
}
private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
......@@ -241,7 +249,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
new FunctionParameterChecker().check(functionParameter, trainedArchitecture);
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
......
......@@ -315,29 +315,6 @@ public class GluonConfigurationData extends ConfigurationData {
return environmentParameters.containsKey(ENVIRONMENT_REWARD_TOPIC);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
}
Map<String, Object> resultView = new HashMap<>();
MultiParamValueSymbol multiParamValue = (MultiParamValueSymbol)this.getConfiguration().getEntryMap()
.get(key).getValue();
resultView.put(valueName, multiParamValue.getValue());
resultView.putAll(multiParamValue.getParameters());
return resultView;
}
private Boolean configurationContainsKey(final String key) {
return this.getConfiguration().getEntryMap().containsKey(key);
}
private Object retrieveConfigurationEntryValueByKey(final String key) {
return this.getConfiguration().getEntry(key).getValue().getValue();
}
private Map<String, Object> getInputParameterWithName(final String parameterName) {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).isPresent()
......
......@@ -6,9 +6,7 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
*/
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
FORWARD_FUNCTION,
PYTHON_INLINE,
CPP_INLINE;
FORWARD_FUNCTION;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
switch(netDefinitionMode) {
......@@ -16,10 +14,6 @@ public enum NetDefinitionMode {
return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION":
return FORWARD_FUNCTION;
case "PYTHON_INLINE":
return PYTHON_INLINE;
case "CPP_INLINE":
return CPP_INLINE;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
/**
*
*/
......@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
private String inputTerminalParameterName;
private String outputParameterName;
private RewardFunctionParameterAdapter rewardFunctionParameter;
private NNArchitectureSymbol trainedArchitecture;