diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java index d16e68cd28e7b04f77fff79a06f67da9ca929531..b16a14766cb32edc42a4a7dbc9c69747d98094dd 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java @@ -29,4 +29,4 @@ public class CNNArch2GluonCli { GenericCNNArchCli cli = new GenericCNNArchCli(generator); cli.run(args); } -} +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java index bec19e8ded17b8d0764a1906ecfcabbb9a60323a..86845b9f52a31d343c9f3e8fc2c5c86ea8744b67 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java @@ -2,8 +2,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator; import com.google.common.collect.Maps; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol; -import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair; -import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter; import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker; import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter; import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator; @@ -91,13 +91,21 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { } } - public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) { + public void generate(Path modelsDirPath, + String rootModelName, + NNArchitectureSymbol trainedArchitecture, + NNArchitectureSymbol criticNetwork) { ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName); configurationSymbol.setTrainedArchitecture(trainedArchitecture); + configurationSymbol.setCriticNetwork(criticNetwork); this.setRootProjectModelsDir(modelsDirPath.toString()); generateFilesFromConfigurationSymbol(configurationSymbol); } + public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) { + generate(modelsDirPath, rootModelName, trainedArchitecture, null); + } + @Override public Map<String, String> generateStrings(ConfigurationSymbol configuration) { TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration(); @@ -119,23 +127,28 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { if (rlAlgorithm.equals(RLAlgorithm.DDPG) || rlAlgorithm.equals(RLAlgorithm.TD3)) { - CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator(); - criticNetworkGenerator.setGenerationTargetPath( - Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString()); - if (getRootProjectModelsDir().isPresent()) { - criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get()); - } else { - Log.error("No root model dir set"); + + if (!configuration.getCriticNetwork().isPresent()) { + Log.error("No architecture model for critic available but is required for chosen " + + "actor-critic algorithm"); } + NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get(); + final String criticComponentName = genericArchitectureSymbol.getName().replace('.', '_'); + ArchitectureSymbol architectureSymbol + = ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol(); + + CNNArch2Gluon gluonGenerator = new CNNArch2Gluon(); + gluonGenerator.setGenerationTargetPath( + Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString()); + Map<String, String> architectureFileContentMap + = gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true); - CriticNetworkGenerationPair criticNetworkResult - = criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration); - fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap( + fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap( k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(), Map.Entry::getValue)) ); - ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName()); + ftlContext.put("criticInstanceName", criticComponentName); } ftlContext.put("trainerName", trainerName); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationException.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationException.java deleted file mode 100644 index 70389b6fe81945552c177fa614315446a870a438..0000000000000000000000000000000000000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationException.java +++ /dev/null @@ -1,7 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic; - -public class CriticNetworkGenerationException extends RuntimeException { - public CriticNetworkGenerationException(String s) { - super("Generation of critic network failed: " + s); - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationPair.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationPair.java deleted file mode 100644 index 8ea2d62fb68d1f356e7b357d9ed71e1f2fed9ead..0000000000000000000000000000000000000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerationPair.java +++ /dev/null @@ -1,21 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic; - -import java.util.Map; - -public class CriticNetworkGenerationPair { - private String criticNetworkName; - private Map<String, String> fileContent; - - public CriticNetworkGenerationPair(String criticNetworkName, Map<String, String> fileContent) { - this.criticNetworkName = criticNetworkName; - this.fileContent = fileContent; - } - - public String getCriticNetworkName() { - return criticNetworkName; - } - - public Map<String, String> getFileContent() { - return fileContent; - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerator.java deleted file mode 100644 index e524c3afe9c9d060baa2a5d92a3ff2c6cc843460..0000000000000000000000000000000000000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/critic/CriticNetworkGenerator.java +++ /dev/null @@ -1,213 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic; - -import com.google.common.collect.Lists; -import de.monticore.lang.monticar.cnnarch._symboltable.*; -import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon; -import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker; -import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker; -import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler; -import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration; -import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; -import de.monticore.lang.monticar.cnntrain.annotations.Range; -import de.se_rwth.commons.logging.Log; -import org.apache.commons.io.FileUtils; -import org.apache.commons.lang3.StringUtils; - -import java.io.IOException; -import java.nio.charset.Charset; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.*; - -import static com.google.common.base.Preconditions.checkState; - -public class CriticNetworkGenerator { - private static final String START_SEQUENCE = "implementationCritic(state,action)"; - - private String generationTargetPath; - private String rootModelsDir; - - protected String getGenerationTargetPath() { - return generationTargetPath; - } - - public void setGenerationTargetPath(String getGenerationTargetPath) { - this.generationTargetPath = getGenerationTargetPath; - } - - protected String getRootModelsDir() { - return rootModelsDir; - } - - public void setRootModelsDir(String rootModelsDir) { - this.rootModelsDir = rootModelsDir; - } - - public CriticNetworkGenerationPair generateCriticNetworkContent(TemplateConfiguration templateConfiguration, - ConfigurationSymbol configurationSymbol) { - checkState(getRootModelsDir() != null, "Root project directory is not set"); - checkState(getGenerationTargetPath() != null, "Target path is not set"); - - failIfArchitectureNotAvailable(configurationSymbol); - assert configurationSymbol.getTrainedArchitecture().isPresent(); - NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get(); - failIfActorHasMultipleIO(trainedArchitecture); - - List<String> criticNetwork = retrieveFullNameOfCriticsNetworkFromConfiguration(configurationSymbol); - final String criticNetworkName = criticNetwork.get(criticNetwork.size()-1); - final Path pathTocriticNetworkFile = retrievePathToCriticNetworkFileFromFullName(criticNetwork, criticNetworkName); - final String networkImplementation = parseNetworkImplementationFromFile(pathTocriticNetworkFile); - - Map<String, Object> context = makeTemplateContextMap(trainedArchitecture, criticNetworkName, networkImplementation); - Path directoryOfCnnArchFile = makeCnnArchFileFromContext(templateConfiguration, context); - - Map<String, String> fileContentMap = generatePythonNetworkFiles(criticNetworkName, directoryOfCnnArchFile); - deleteOutputDirectory(directoryOfCnnArchFile); - - return new CriticNetworkGenerationPair(criticNetworkName, fileContentMap); - } - - private void deleteOutputDirectory(Path directoryOfCnnArchFile) { - try { - FileUtils.deleteDirectory(directoryOfCnnArchFile.toFile()); - } catch (IOException e) { - Log.warn("Cannot delete temporary CNN arch directory: " + directoryOfCnnArchFile.toString()); - } - } - - private Map<String, String> generatePythonNetworkFiles(String criticNetworkName, Path directoryOfCnnArchFile) { - CNNArch2Gluon gluonGenerator = new CNNArch2Gluon(); - gluonGenerator.setGenerationTargetPath(this.getGenerationTargetPath()); - - Map<String, String> fileContentMap = new HashMap<>(); - CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2GluonArchitectureSupportChecker(), - new CNNArch2GluonLayerSupportChecker()); - ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbolFromModelsDir(directoryOfCnnArchFile, criticNetworkName); - architectureSymbol.setComponentName(criticNetworkName); - fileContentMap.putAll(gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true)); - return fileContentMap; - } - - private Path makeCnnArchFileFromContext(TemplateConfiguration templateConfiguration, Map<String, Object> context) { - final String architectureContent = templateConfiguration.processTemplate( - context, "reinforcement/architecture/CriticArchitecture.ftl"); - - Path tmpDirectoryToArchFile = Paths.get(this.getGenerationTargetPath(), "tmp"); - try { - if (!tmpDirectoryToArchFile.toFile().exists()) { - Files.createDirectories(tmpDirectoryToArchFile); - } - Files.write( - Paths.get(tmpDirectoryToArchFile.toString(), context.get("architectureName") + ".cnna"), architectureContent.getBytes()); - } catch (IOException e) { - failWithMessage(e.getMessage()); - } - return tmpDirectoryToArchFile; - } - - private Map<String, Object> makeTemplateContextMap(NNArchitectureSymbol trainedArchitecture, String criticNetworkName, String networkImplementation) { - final String stateName = trainedArchitecture.getInputs().get(0); - final String actionName = trainedArchitecture.getOutputs().get(0); - - Map<String, List<Integer>> dimensions = trainedArchitecture.getDimensions(); - List<Integer> stateDimensions = dimensions.get(stateName); - List<Integer> actionDimensions = dimensions.get(actionName); - - Map<String, Range> ranges = trainedArchitecture.getRanges(); - Range stateRange = ranges.get(stateName); - Range actionRange = ranges.get(actionName); - - Map<String, String> types = trainedArchitecture.getTypes(); - final String stateType = types.get(stateName); - final String actionType = types.get(actionName); - - Map<String, Object> context = new HashMap<>(); - context.put("stateDimension", stateDimensions); - context.put("actionDimension", actionDimensions); - context.put("stateRange", stateRange); - context.put("actionRange", actionRange); - context.put("stateType", stateType); - context.put("actionType", actionType); - context.put("implementation", networkImplementation); - context.put("architectureName", criticNetworkName); - return context; - } - - private String parseNetworkImplementationFromFile(Path criticNetworkFile) { - String criticNetworkFileContent = null; - try { - criticNetworkFileContent = new String(Files.readAllBytes(criticNetworkFile), Charset.forName("UTF-8")); - } catch (IOException e) { - failWithMessage("Cannot create critic network file:" + e.getMessage()); - } - - String contentWhiteSpaceRemoved = criticNetworkFileContent.replaceAll("\\s+",""); - - if (!contentWhiteSpaceRemoved.contains(START_SEQUENCE) - || StringUtils.countMatches(contentWhiteSpaceRemoved, "{") != 1 - || StringUtils.countMatches(contentWhiteSpaceRemoved, "}") != 1 - || contentWhiteSpaceRemoved.charAt(contentWhiteSpaceRemoved.length() - 1) != '}') { - failWithMessage("Cannot parse critic file"); - } - - final int startOfNNImplementation = contentWhiteSpaceRemoved.indexOf("{") + 1; - final int endOfNNImplementation = contentWhiteSpaceRemoved.indexOf("}"); - return contentWhiteSpaceRemoved.substring(startOfNNImplementation, endOfNNImplementation); - } - - private Path retrievePathToCriticNetworkFileFromFullName(List<String> criticNetwork, String criticNetworkName) { - // Add file ending cnna to file name - criticNetwork.set(criticNetwork.size()-1, criticNetworkName + ".cnna"); - - Path root = Paths.get(this.getRootModelsDir()); - Path criticNetworkFile = criticNetwork.stream().map(Paths::get).reduce(root, Path::resolve); - - if (!criticNetworkFile.toFile().exists()) { - failWithMessage("Critic network file does not exist in " + criticNetworkFile.toString()); - } - return criticNetworkFile; - } - - private List<String> retrieveFullNameOfCriticsNetworkFromConfiguration(ConfigurationSymbol configurationSymbol) { - // Load critics network - failIfConfigurationHasNoCritic(configurationSymbol); - - assert configurationSymbol.getEntry("critic").getValue().getValue() instanceof String; - List<String> criticNetwork = Lists.newArrayList(( - (String)configurationSymbol.getEntry("critic").getValue().getValue()).split("\\.")); - - // Check if file name is upper case otherwise make it upper case - int size = criticNetwork.size(); - if (Character.isLowerCase(criticNetwork.get(size-1).charAt(0))) { - String lowerCaseFileName = criticNetwork.get(size-1); - String upperCaseFileName = lowerCaseFileName.substring(0,1).toUpperCase() + lowerCaseFileName.substring(1); - criticNetwork.set(size-1, upperCaseFileName); - } - return criticNetwork; - } - - private void failIfConfigurationHasNoCritic(ConfigurationSymbol configurationSymbol) { - if (!configurationSymbol.getEntryMap().containsKey("critic")) { - failWithMessage("No critic network file given, but is required for selected algorithm"); - } - } - - private void failIfActorHasMultipleIO(NNArchitectureSymbol trainedArchitecture) { - if (trainedArchitecture.getInputs().size() > 1 || trainedArchitecture.getOutputs().size() > 1) { - failWithMessage("Actor component with multiple inputs or outputs is not supported by this generator"); - } - } - - private void failIfArchitectureNotAvailable(ConfigurationSymbol configurationSymbol) { - if (!configurationSymbol.getTrainedArchitecture().isPresent()) { - failWithMessage("No architecture symbol found but is required for selected algorithm"); - } - } - - private void failWithMessage(final String message) { - Log.error("Critic network generation failed: " + message); - throw new CriticNetworkGenerationException(message); - } -} \ No newline at end of file diff --git a/src/main/resources/templates/gluon/reinforcement/architecture/CriticArchitecture.ftl b/src/main/resources/templates/gluon/reinforcement/architecture/CriticArchitecture.ftl deleted file mode 100644 index c923c08b41ca0b4b90c25ba58712d649eb0819f9..0000000000000000000000000000000000000000 --- a/src/main/resources/templates/gluon/reinforcement/architecture/CriticArchitecture.ftl +++ /dev/null @@ -1,7 +0,0 @@ -architecture ${architectureName}() { - def input ${stateType}<#if stateRange??>(<#if stateRange.isLowerLimitInfinity()>-oo<#else>${stateRange.lowerLimit.get()}</#if>:<#if stateRange.isUpperLimitInfinity()>oo<#else>${stateRange.upperLimit.get()}</#if>)</#if>^{<#list stateDimension as d>${d}<#if d?has_next>,</#if></#list>} state - def input ${actionType}<#if actionRange??>(<#if actionRange.isLowerLimitInfinity()>-oo<#else>${actionRange.lowerLimit.get()}</#if>:<#if actionRange.isUpperLimitInfinity()>oo<#else>${actionRange.upperLimit.get()}</#if>)</#if>^{<#list actionDimension as d>${d}<#if d?has_next>,</#if></#list>} action - def output Q(-oo:oo)^{1} qvalue - - ${implementation}->FullyConnected(units=1)->qvalue; -} \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java index 2c02d468fd5e6babdb238aa8bd2dfa2acd8a673b..d3290488acd0cb11f7c6d8215998e9547915ca5f 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java @@ -268,8 +268,10 @@ public class GenerationTest extends AbstractSymtabTest { Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg"); CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator); NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock(); + NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel( + Paths.get("./src/test/resources/valid_tests/ddpg/comp"), "CriticNetwork"); - trainGenerator.generate(modelPath, "ActorNetwork", trainedArchitecture); + trainGenerator.generate(modelPath, "ActorNetwork", trainedArchitecture, criticArchitecture); assertTrue(Log.getFindings().stream().noneMatch(Finding::isError)); checkFilesAreEqual( @@ -297,8 +299,10 @@ public class GenerationTest extends AbstractSymtabTest { Path modelPath = Paths.get("src/test/resources/valid_tests/td3"); CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator); NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock(); + NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel( + Paths.get("./src/test/resources/valid_tests/td3/comp"), "CriticNetwork"); - trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture); + trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture, criticArchitecture); assertTrue(Log.getFindings().stream().noneMatch(Finding::isError)); checkFilesAreEqual( @@ -326,8 +330,10 @@ public class GenerationTest extends AbstractSymtabTest { Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg-ros"); CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator); NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock(); + NNArchitectureSymbol criticArchitecture = NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel( + Paths.get("./src/test/resources/valid_tests/ddpg-ros/comp"), "RosCriticNetwork"); - trainGenerator.generate(modelPath, "RosActorNetwork", trainedArchitecture); + trainGenerator.generate(modelPath, "RosActorNetwork", trainedArchitecture, criticArchitecture); assertTrue(Log.getFindings().stream().noneMatch(Finding::isError)); checkFilesAreEqual( diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/util/NNArchitectureMockFactory.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/util/NNArchitectureMockFactory.java index 4fe2fe594ae779326f618f378dc020ec04789b2b..0a690724710466026af29da9920016f8a50ca9fd 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/util/NNArchitectureMockFactory.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/util/NNArchitectureMockFactory.java @@ -2,9 +2,15 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.util; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler; +import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker; +import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker; +import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter; import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; import de.monticore.lang.monticar.cnntrain.annotations.Range; +import java.nio.file.Path; import java.util.List; import java.util.Map; @@ -45,4 +51,12 @@ public class NNArchitectureMockFactory { return trainedArchitecture; } + public static NNArchitectureSymbol createArchitectureSymbolByCNNArchModel(final Path modelsDirPath, + final String rootModel) { + CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2GluonArchitectureSupportChecker(), + new CNNArch2GluonLayerSupportChecker()); + ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbolFromModelsDir(modelsDirPath, rootModel); + architectureSymbol.setComponentName(rootModel); + return new ArchitectureAdapter(rootModel, architectureSymbol); + } } diff --git a/src/test/resources/valid_tests/ddpg-ros/comp/RosCriticNetwork.cnna b/src/test/resources/valid_tests/ddpg-ros/comp/RosCriticNetwork.cnna index 8c9363e72ccfb9d36457de656b32e2c31beb120f..a17088e94003c69909d4ed8953c113979eb2f88f 100644 --- a/src/test/resources/valid_tests/ddpg-ros/comp/RosCriticNetwork.cnna +++ b/src/test/resources/valid_tests/ddpg-ros/comp/RosCriticNetwork.cnna @@ -1,4 +1,8 @@ -implementation Critic(state, action) { +architecture RosCriticNetwork { + def input Q(-oo:oo)^{8} state + def input Q(-1:1)^{3} action + def output Q(-oo:oo)^{1} qvalues + (state -> FullyConnected(units=300) -> Relu() -> @@ -10,5 +14,7 @@ implementation Critic(state, action) { ) -> Add() -> FullyConnected(units=600) -> - Relu() + Relu() -> + FullyConnected(units=1) -> + qvalues; } \ No newline at end of file diff --git a/src/test/resources/valid_tests/ddpg/comp/CriticNetwork.cnna b/src/test/resources/valid_tests/ddpg/comp/CriticNetwork.cnna index 8c9363e72ccfb9d36457de656b32e2c31beb120f..17a50cf1f5724511b804a647aa9b59937b3e3aea 100644 --- a/src/test/resources/valid_tests/ddpg/comp/CriticNetwork.cnna +++ b/src/test/resources/valid_tests/ddpg/comp/CriticNetwork.cnna @@ -1,4 +1,8 @@ -implementation Critic(state, action) { +architecture CriticNetwork { + def input Q(-oo:oo)^{8} state + def input Q(-1:1)^{3} action + def output Q(-oo:oo)^{1} qvalues + (state -> FullyConnected(units=300) -> Relu() -> @@ -10,5 +14,7 @@ implementation Critic(state, action) { ) -> Add() -> FullyConnected(units=600) -> - Relu() + Relu() -> + FullyConnected(units=1) -> + qvalues; } \ No newline at end of file diff --git a/src/test/resources/valid_tests/td3/comp/CriticNetwork.cnna b/src/test/resources/valid_tests/td3/comp/CriticNetwork.cnna index 8c9363e72ccfb9d36457de656b32e2c31beb120f..17a50cf1f5724511b804a647aa9b59937b3e3aea 100644 --- a/src/test/resources/valid_tests/td3/comp/CriticNetwork.cnna +++ b/src/test/resources/valid_tests/td3/comp/CriticNetwork.cnna @@ -1,4 +1,8 @@ -implementation Critic(state, action) { +architecture CriticNetwork { + def input Q(-oo:oo)^{8} state + def input Q(-1:1)^{3} action + def output Q(-oo:oo)^{1} qvalues + (state -> FullyConnected(units=300) -> Relu() -> @@ -10,5 +14,7 @@ implementation Critic(state, action) { ) -> Add() -> FullyConnected(units=600) -> - Relu() + Relu() -> + FullyConnected(units=1) -> + qvalues; } \ No newline at end of file