Commit 537212db authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Generate critic by passed architecture symbol

parent df0b9008
......@@ -29,4 +29,4 @@ public class CNNArch2GluonCli {
GenericCNNArchCli cli = new GenericCNNArchCli(generator);
cli.run(args);
}
}
}
\ No newline at end of file
......@@ -2,8 +2,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -91,13 +91,21 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
}
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
public void generate(Path modelsDirPath,
String rootModelName,
NNArchitectureSymbol trainedArchitecture,
NNArchitectureSymbol criticNetwork) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
configurationSymbol.setCriticNetwork(criticNetwork);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
generate(modelsDirPath, rootModelName, trainedArchitecture, null);
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
......@@ -119,23 +127,28 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (rlAlgorithm.equals(RLAlgorithm.DDPG)
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
if (getRootProjectModelsDir().isPresent()) {
criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get());
} else {
Log.error("No root model dir set");
if (!configuration.getCriticNetwork().isPresent()) {
Log.error("No architecture model for critic available but is required for chosen " +
"actor-critic algorithm");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get();
final String criticComponentName = genericArchitectureSymbol.getName().replace('.', '_');
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
CriticNetworkGenerationPair criticNetworkResult
= criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration);
fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap(
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName());
ftlContext.put("criticInstanceName", criticComponentName);
}
ftlContext.put("trainerName", trainerName);
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
public class CriticNetworkGenerationException extends RuntimeException {
public CriticNetworkGenerationException(String s) {
super("Generation of critic network failed: " + s);
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
import java.util.Map;
public class CriticNetworkGenerationPair {
private String criticNetworkName;
private Map<String, String> fileContent;
public CriticNetworkGenerationPair(String criticNetworkName, Map<String, String> fileContent) {
this.criticNetworkName = criticNetworkName;
this.fileContent = fileContent;
}
public String getCriticNetworkName() {
return criticNetworkName;
}
public Map<String, String> getFileContent() {
return fileContent;
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import static com.google.common.base.Preconditions.checkState;
public class CriticNetworkGenerator {
private static final String START_SEQUENCE = "implementationCritic(state,action)";
private String generationTargetPath;
private String rootModelsDir;
protected String getGenerationTargetPath() {
return generationTargetPath;
}
public void setGenerationTargetPath(String getGenerationTargetPath) {
this.generationTargetPath = getGenerationTargetPath;
}
protected String getRootModelsDir() {
return rootModelsDir;
}
public void setRootModelsDir(String rootModelsDir) {
this.rootModelsDir = rootModelsDir;
}
public CriticNetworkGenerationPair generateCriticNetworkContent(TemplateConfiguration templateConfiguration,
ConfigurationSymbol configurationSymbol) {
checkState(getRootModelsDir() != null, "Root project directory is not set");
checkState(getGenerationTargetPath() != null, "Target path is not set");
failIfArchitectureNotAvailable(configurationSymbol);
assert configurationSymbol.getTrainedArchitecture().isPresent();
NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
failIfActorHasMultipleIO(trainedArchitecture);
List<String> criticNetwork = retrieveFullNameOfCriticsNetworkFromConfiguration(configurationSymbol);
final String criticNetworkName = criticNetwork.get(criticNetwork.size()-1);
final Path pathTocriticNetworkFile = retrievePathToCriticNetworkFileFromFullName(criticNetwork, criticNetworkName);
final String networkImplementation = parseNetworkImplementationFromFile(pathTocriticNetworkFile);
Map<String, Object> context = makeTemplateContextMap(trainedArchitecture, criticNetworkName, networkImplementation);
Path directoryOfCnnArchFile = makeCnnArchFileFromContext(templateConfiguration, context);
Map<String, String> fileContentMap = generatePythonNetworkFiles(criticNetworkName, directoryOfCnnArchFile);
deleteOutputDirectory(directoryOfCnnArchFile);
return new CriticNetworkGenerationPair(criticNetworkName, fileContentMap);
}
private void deleteOutputDirectory(Path directoryOfCnnArchFile) {
try {
FileUtils.deleteDirectory(directoryOfCnnArchFile.toFile());
} catch (IOException e) {
Log.warn("Cannot delete temporary CNN arch directory: " + directoryOfCnnArchFile.toString());
}
}
private Map<String, String> generatePythonNetworkFiles(String criticNetworkName, Path directoryOfCnnArchFile) {
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(this.getGenerationTargetPath());
Map<String, String> fileContentMap = new HashMap<>();
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2GluonArchitectureSupportChecker(),
new CNNArch2GluonLayerSupportChecker());
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbolFromModelsDir(directoryOfCnnArchFile, criticNetworkName);
architectureSymbol.setComponentName(criticNetworkName);
fileContentMap.putAll(gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true));
return fileContentMap;
}
private Path makeCnnArchFileFromContext(TemplateConfiguration templateConfiguration, Map<String, Object> context) {
final String architectureContent = templateConfiguration.processTemplate(
context, "reinforcement/architecture/CriticArchitecture.ftl");
Path tmpDirectoryToArchFile = Paths.get(this.getGenerationTargetPath(), "tmp");
try {
if (!tmpDirectoryToArchFile.toFile().exists()) {
Files.createDirectories(tmpDirectoryToArchFile);
}
Files.write(
Paths.get(tmpDirectoryToArchFile.toString(), context.get("architectureName") + ".cnna"), architectureContent.getBytes());
} catch (IOException e) {
failWithMessage(e.getMessage());
}
return tmpDirectoryToArchFile;
}
private Map<String, Object> makeTemplateContextMap(NNArchitectureSymbol trainedArchitecture, String criticNetworkName, String networkImplementation) {
final String stateName = trainedArchitecture.getInputs().get(0);
final String actionName = trainedArchitecture.getOutputs().get(0);
Map<String, List<Integer>> dimensions = trainedArchitecture.getDimensions();
List<Integer> stateDimensions = dimensions.get(stateName);
List<Integer> actionDimensions = dimensions.get(actionName);
Map<String, Range> ranges = trainedArchitecture.getRanges();
Range stateRange = ranges.get(stateName);
Range actionRange = ranges.get(actionName);
Map<String, String> types = trainedArchitecture.getTypes();
final String stateType = types.get(stateName);
final String actionType = types.get(actionName);
Map<String, Object> context = new HashMap<>();
context.put("stateDimension", stateDimensions);
context.put("actionDimension", actionDimensions);
context.put("stateRange", stateRange);
context.put("actionRange", actionRange);
context.put("stateType", stateType);
context.put("actionType", actionType);
context.put("implementation", networkImplementation);
context.put("architectureName", criticNetworkName);
return context;
}
private String parseNetworkImplementationFromFile(Path criticNetworkFile) {
String criticNetworkFileContent = null;
try {
criticNetworkFileContent = new String(Files.readAllBytes(criticNetworkFile), Charset.forName("UTF-8"));
} catch (IOException e) {
failWithMessage("Cannot create critic network file:" + e.getMessage());
}
String contentWhiteSpaceRemoved = criticNetworkFileContent.replaceAll("\\s+","");
if (!contentWhiteSpaceRemoved.contains(START_SEQUENCE)
|| StringUtils.countMatches(contentWhiteSpaceRemoved, "{") != 1
|| StringUtils.countMatches(contentWhiteSpaceRemoved, "}") != 1
|| contentWhiteSpaceRemoved.charAt(contentWhiteSpaceRemoved.length() - 1) != '}') {
failWithMessage("Cannot parse critic file");
}
final int startOfNNImplementation = contentWhiteSpaceRemoved.indexOf("{") + 1;
final int endOfNNImplementation = contentWhiteSpaceRemoved.indexOf("}");
return contentWhiteSpaceRemoved.substring(startOfNNImplementation, endOfNNImplementation);
}
private Path retrievePathToCriticNetworkFileFromFullName(List<String> criticNetwork, String criticNetworkName) {
// Add file ending cnna to file name
criticNetwork.set(criticNetwork.size()-1, criticNetworkName + ".cnna");
Path root = Paths.get(this.getRootModelsDir());
Path criticNetworkFile = criticNetwork.stream().map(Paths::get).reduce(root, Path::resolve);
if (!criticNetworkFile.toFile().exists()) {
failWithMessage("Critic network file does not exist in " + criticNetworkFile.toString());
}
return criticNetworkFile;
}
private List<String> retrieveFullNameOfCriticsNetworkFromConfiguration(ConfigurationSymbol configurationSymbol) {
// Load critics network
failIfConfigurationHasNoCritic(configurationSymbol);
assert configurationSymbol.getEntry("critic").getValue().getValue() instanceof String;
List<String> criticNetwork = Lists.newArrayList((
(String)configurationSymbol.getEntry("critic").getValue().getValue()).split("\\."));
// Check if file name is upper case otherwise make it upper case
int size = criticNetwork.size();
if (Character.isLowerCase(criticNetwork.get(size-1).charAt(0))) {
String lowerCaseFileName = criticNetwork.get(size-1);
String upperCaseFileName = lowerCaseFileName.substring(0,1).toUpperCase() + lowerCaseFileName.substring(1);
criticNetwork.set(size-1, upperCaseFileName);
}
return criticNetwork;
}
private void failIfConfigurationHasNoCritic(ConfigurationSymbol configurationSymbol) {
if (!configurationSymbol.getEntryMap().containsKey("critic")) {
failWithMessage("No critic network file given, but is required for selected algorithm");
}
}
private void failIfActorHasMultipleIO(NNArchitectureSymbol trainedArchitecture) {
if (trainedArchitecture.getInputs().size() > 1 || trainedArchitecture.getOutputs().size() > 1) {
failWithMessage("Actor component with multiple inputs or outputs is not supported by this generator");
}
}
private void failIfArchitectureNotAvailable(ConfigurationSymbol configurationSymbol) {
if (!configurationSymbol.getTrainedArchitecture().isPresent()) {
failWithMessage("No architecture symbol found but is required for selected algorithm");
}
}
private void failWithMessage(final String message) {
Log.error("Critic network generation failed: " + message);
throw new CriticNetworkGenerationException(message);
}
}
\ No newline at end of file
architecture ${architectureName}() {
def input ${stateType}<#if stateRange??>(<#if stateRange.isLowerLimitInfinity()>-oo<#else>${stateRange.lowerLimit.get()}</#if>:<#if stateRange.isUpperLimitInfinity()>oo<#else>${stateRange.upperLimit.get()}</#if>)</#if>^{<#list stateDimension as d>${d}<#if d?has_next>,</#if></#list>} state
def input ${actionType}<#if actionRange??>(<#if actionRange.isLowerLimitInfinity()>-oo<#else>${actionRange.lowerLimit.get()}</#if>:<#if actionRange.isUpperLimitInfinity()>oo<#else>${actionRange.upperLimit.get()}</#if>)</#if>^{<#list actionDimension as d>${d}<#if d?has_next>,</#if></#list>} action
def output Q(-oo:oo)^{1} qvalue
${implementation}->FullyConnected(units=1)->qvalue;
}
\ No newline at end of file
......@@ -268,8 +268,10 @@ public class GenerationTest extends AbstractSymtabTest {
Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(
Paths.get("./src/test/resources/valid_tests/ddpg/comp"), "CriticNetwork");
trainGenerator.generate(modelPath, "ActorNetwork", trainedArchitecture);
trainGenerator.generate(modelPath, "ActorNetwork", trainedArchitecture, criticArchitecture);
assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
checkFilesAreEqual(
......@@ -297,8 +299,10 @@ public class GenerationTest extends AbstractSymtabTest {
Path modelPath = Paths.get("src/test/resources/valid_tests/td3");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(
Paths.get("./src/test/resources/valid_tests/td3/comp"), "CriticNetwork");
trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture);
trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture, criticArchitecture);
assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
checkFilesAreEqual(
......@@ -326,8 +330,10 @@ public class GenerationTest extends AbstractSymtabTest {
Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg-ros");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(
Paths.get("./src/test/resources/valid_tests/ddpg-ros/comp"), "RosCriticNetwork");
trainGenerator.generate(modelPath, "RosActorNetwork", trainedArchitecture);
trainGenerator.generate(modelPath, "RosActorNetwork", trainedArchitecture, criticArchitecture);
assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
checkFilesAreEqual(
......
......@@ -2,9 +2,15 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.util;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
......@@ -45,4 +51,12 @@ public class NNArchitectureMockFactory {
return trainedArchitecture;
}
public static NNArchitectureSymbol createArchitectureSymbolByCNNArchModel(final Path modelsDirPath,
final String rootModel) {
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2GluonArchitectureSupportChecker(),
new CNNArch2GluonLayerSupportChecker());
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbolFromModelsDir(modelsDirPath, rootModel);
architectureSymbol.setComponentName(rootModel);
return new ArchitectureAdapter(rootModel, architectureSymbol);
}
}
implementation Critic(state, action) {
architecture RosCriticNetwork {
def input Q(-oo:oo)^{8} state
def input Q(-1:1)^{3} action
def output Q(-oo:oo)^{1} qvalues
(state ->
FullyConnected(units=300) ->
Relu() ->
......@@ -10,5 +14,7 @@ implementation Critic(state, action) {
) ->
Add() ->
FullyConnected(units=600) ->
Relu()
Relu() ->
FullyConnected(units=1) ->
qvalues;
}
\ No newline at end of file
implementation Critic(state, action) {
architecture CriticNetwork {
def input Q(-oo:oo)^{8} state
def input Q(-1:1)^{3} action
def output Q(-oo:oo)^{1} qvalues
(state ->
FullyConnected(units=300) ->
Relu() ->
......@@ -10,5 +14,7 @@ implementation Critic(state, action) {
) ->
Add() ->
FullyConnected(units=600) ->
Relu()
Relu() ->
FullyConnected(units=1) ->
qvalues;
}
\ No newline at end of file
implementation Critic(state, action) {
architecture CriticNetwork {
def input Q(-oo:oo)^{8} state
def input Q(-1:1)^{3} action
def output Q(-oo:oo)^{1} qvalues
(state ->
FullyConnected(units=300) ->
Relu() ->
......@@ -10,5 +14,7 @@ implementation Critic(state, action) {
) ->
Add() ->
FullyConnected(units=600) ->
Relu()
Relu() ->
FullyConnected(units=1) ->
qvalues;
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment