Commit df0b9008 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Adapt to NNArchitecture Symbol

parent 5492fb33
......@@ -11,11 +11,7 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
......@@ -95,7 +91,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
}
public void generate(Path modelsDirPath, String rootModelName, TrainedArchitecture trainedArchitecture) {
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
this.setRootProjectModelsDir(modelsDirPath.toString());
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import java.util.*;
......@@ -148,7 +146,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one input, the first one is the only input
return trainedArchitecture.getInputs().get(0);
}
......@@ -157,7 +155,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one output, the first one is the only output
return trainedArchitecture.getOutputs().get(0);
}
......@@ -167,7 +165,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return null;
}
final String inputName = getInputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(inputName);
}
......@@ -176,7 +174,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return null;
}
final String outputName = getOutputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(outputName);
}
......@@ -195,7 +193,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method");
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
final String actionPortName = getOutputNameOfTrainedArchitecture();
Range actionRange = trainedArchitecture.getRanges().get(actionPortName);
......
......@@ -2,8 +2,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.annotations;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.symboltable.CommonSymbol;
......@@ -14,11 +14,13 @@ import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkNotNull;
public class ArchitectureAdapter implements TrainedArchitecture {
public class ArchitectureAdapter extends NNArchitectureSymbol {
private ArchitectureSymbol architectureSymbol;
public ArchitectureAdapter(final ArchitectureSymbol architectureSymbol) {
public ArchitectureAdapter(final String name,
final ArchitectureSymbol architectureSymbol) {
super(name);
checkNotNull(architectureSymbol);
this.architectureSymbol = architectureSymbol;
}
......@@ -55,6 +57,10 @@ public class ArchitectureAdapter implements TrainedArchitecture {
s -> s.getDefinition().getType().getDomain().getName()));
}
public ArchitectureSymbol getArchitectureSymbol() {
return this.architectureSymbol;
}
private Range astRangeToTrainRange(final ASTRange range) {
if (range == null || (range.hasNoLowerLimit() && range.hasNoUpperLimit())) {
return Range.withInfinityLimits();
......
......@@ -8,8 +8,8 @@ import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSuppo
import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
......@@ -52,7 +52,7 @@ public class CriticNetworkGenerator {
failIfArchitectureNotAvailable(configurationSymbol);
assert configurationSymbol.getTrainedArchitecture().isPresent();
TrainedArchitecture trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
failIfActorHasMultipleIO(trainedArchitecture);
List<String> criticNetwork = retrieveFullNameOfCriticsNetworkFromConfiguration(configurationSymbol);
......@@ -107,7 +107,7 @@ public class CriticNetworkGenerator {
return tmpDirectoryToArchFile;
}
private Map<String, Object> makeTemplateContextMap(TrainedArchitecture trainedArchitecture, String criticNetworkName, String networkImplementation) {
private Map<String, Object> makeTemplateContextMap(NNArchitectureSymbol trainedArchitecture, String criticNetworkName, String networkImplementation) {
final String stateName = trainedArchitecture.getInputs().get(0);
final String actionName = trainedArchitecture.getOutputs().get(0);
......@@ -194,7 +194,7 @@ public class CriticNetworkGenerator {
}
}
private void failIfActorHasMultipleIO(TrainedArchitecture trainedArchitecture) {
private void failIfActorHasMultipleIO(NNArchitectureSymbol trainedArchitecture) {
if (trainedArchitecture.getInputs().size() > 1 || trainedArchitecture.getOutputs().size() > 1) {
failWithMessage("Actor component with multiple inputs or outputs is not supported by this generator");
}
......
......@@ -20,12 +20,9 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.util.TrainedArchitectureMockFactory;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.cnnarch.gluongenerator.util.NNArchitectureMockFactory;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
......@@ -37,15 +34,9 @@ import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.List;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import org.junit.contrib.java.lang.system.Assertion;
import org.junit.contrib.java.lang.system.ExpectedSystemExit;
import static junit.framework.TestCase.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class GenerationTest extends AbstractSymtabTest {
private RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
......@@ -200,7 +191,7 @@ public class GenerationTest extends AbstractSymtabTest {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
trainGenerator.generate(modelPath, "ReinforcementConfig2", trainedArchitecture);
......@@ -228,7 +219,7 @@ public class GenerationTest extends AbstractSymtabTest {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
trainGenerator.generate(modelPath, "ReinforcementConfig3", trainedArchitecture);
......@@ -276,7 +267,7 @@ public class GenerationTest extends AbstractSymtabTest {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
trainGenerator.generate(modelPath, "ActorNetwork", trainedArchitecture);
......@@ -305,7 +296,7 @@ public class GenerationTest extends AbstractSymtabTest {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests/td3");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
trainGenerator.generate(modelPath, "TD3Config", trainedArchitecture);
......@@ -334,7 +325,7 @@ public class GenerationTest extends AbstractSymtabTest {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests/ddpg-ros");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
NNArchitectureSymbol trainedArchitecture = NNArchitectureMockFactory.createNNArchitectureMock();
trainGenerator.generate(modelPath, "RosActorNetwork", trainedArchitecture);
......
......@@ -2,8 +2,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.util;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import java.util.List;
import java.util.Map;
......@@ -11,10 +11,10 @@ import java.util.Map;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TrainedArchitectureMockFactory {
public class NNArchitectureMockFactory {
public static TrainedArchitecture createTrainedArchitectureMock() {
TrainedArchitecture trainedArchitecture = mock(TrainedArchitecture.class);
public static NNArchitectureSymbol createNNArchitectureMock() {
NNArchitectureSymbol trainedArchitecture = mock(NNArchitectureSymbol.class);
final String inputName = "state";
final String outputName = "action";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment