Skip to content
Snippets Groups Projects
Commit 8ad4df06 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add state check and add tests for reward check

parent 3f811ad0
No related branches found
No related tags found
2 merge requests!23Added Unroll-related features and layers,!22Pipeline fix and new CNNTrain Integration
Pipeline #191050 passed
......@@ -151,7 +151,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
if (configuration.getTrainedArchitecture().isPresent()) {
generateRewardFunction(configuration.getTrainedArchitecture().get(),
configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
} else {
Log.error("No architecture model for the trained neural network but is required for " +
"reinforcement learning configuration.");
}
}
ftlContext.put("trainerName", trainerName);
......@@ -167,7 +174,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return fileContentMap;
}
private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
......@@ -200,7 +208,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
new FunctionParameterChecker().check(functionParameter, trainedArchitecture);
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
/**
*
*/
......@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
private String inputTerminalParameterName;
private String outputParameterName;
private RewardFunctionParameterAdapter rewardFunctionParameter;
private NNArchitectureSymbol trainedArchitecture;
public FunctionParameterChecker() {
}
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter) {
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter,
final NNArchitectureSymbol trainedArchitecture) {
assert rewardFunctionParameter != null;
assert trainedArchitecture != null;
this.rewardFunctionParameter = rewardFunctionParameter;
this.trainedArchitecture = trainedArchitecture;
retrieveParameterNames();
checkHasExactlyTwoInputs();
checkHasExactlyOneOutput();
checkHasStateAndTerminalInput();
checkInputStateDimension();
checkInputTerminalTypeAndDimension();
checkStateDimensionEqualsTrainedArchitectureState();
checkInputStateDimension();
checkOutputDimension();
}
private void checkStateDimensionEqualsTrainedArchitectureState() {
failIfConditionFails(stateInputOfNNArchitectureIsEqualToRewardState(),
"State dimension of trained architecture is not equal to reward state dimensions.");
}
private boolean stateInputOfNNArchitectureIsEqualToRewardState() {
assert trainedArchitecture.getInputs().size() == 1: "Trained architecture is not a policy network.";
final String nnStateInputName = trainedArchitecture.getInputs().get(0);
final List<Integer> dimensions = trainedArchitecture.getDimensions().get(nnStateInputName);
return rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).isPresent()
&& rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get().equals(dimensions);
}
private void checkHasExactlyTwoInputs() {
failIfConditionFails(functionHasTwoInputs(), "Reward function must have exactly two input parameters: "
+ "One input needs to represents the environment's state and another input needs to be a "
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortDirection;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Test;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class FunctionParameterCheckerTest {
private static final List<Integer> STATE_DIMENSIONS = Lists.newArrayList(3,2,4);
private static final PortVariable STATE_PORT = PortVariable.multidimensionalVariableFrom("input1", EmadlType.Q,
PortDirection.INPUT, STATE_DIMENSIONS);
private static final PortVariable TERMINAL_PORT = PortVariable.primitiveVariableFrom("input2", EmadlType.B,
PortDirection.INPUT);
private static final PortVariable OUTPUT_PORT = PortVariable.primitiveVariableFrom("output1", EmadlType.Q,
PortDirection.OUTPUT);
private static final String COMPONENT_NAME = "TestRewardComponent";
FunctionParameterChecker uut = new FunctionParameterChecker();
@Before
public void setup() {
Log.getFindings().clear();
Log.enableFailQuick(false);
}
@Test
public void validReward() {
// given
RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
// when
uut.check(adapter, getValidTrainedArchitecture());
List<Finding> findings = Log.getFindings();
assertEquals(0, findings.stream().filter(Finding::isError).count());
}
@Test
public void invalidRewardWithOneInput() {
// given
RewardFunctionParameterAdapter adapter = getComponentWithOneInput();
// when
uut.check(adapter, getValidTrainedArchitecture());
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().anyMatch(Finding::isError));
}
@Test
public void invalidRewardWithTwoOutputs() {
// given
RewardFunctionParameterAdapter adapter = getComponentWithTwoOutputs();
// when
uut.check(adapter, getValidTrainedArchitecture());
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().anyMatch(Finding::isError));
}
@Test
public void invalidRewardWithTerminalHasQType() {
// given
RewardFunctionParameterAdapter adapter = getComponentWithTwoQInputs();
// when
uut.check(adapter, getValidTrainedArchitecture());
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().anyMatch(Finding::isError));
}
@Test
public void invalidRewardWithNonScalarOutput() {
// given
RewardFunctionParameterAdapter adapter = getComponentWithNonScalarOutput();
// when
uut.check(adapter, getValidTrainedArchitecture());
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().filter(Finding::isError).count() == 1);
}
@Test
public void invalidRewardStateUnequalToTrainedArchitectureState1() {
// given
RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
NNArchitectureSymbol trainedArchitectureWithDifferenDimension = getTrainedArchitectureWithStateDimensions(
Lists.newArrayList(6));
// when
uut.check(adapter, trainedArchitectureWithDifferenDimension);
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().filter(Finding::isError).count() == 1);
}
@Test
public void invalidRewardStateUnequalToTrainedArchitectureState2() {
// given
RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
NNArchitectureSymbol trainedArchitectureWithDifferenDimension = getTrainedArchitectureWithStateDimensions(
Lists.newArrayList(3, 8));
// when
uut.check(adapter, trainedArchitectureWithDifferenDimension);
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().filter(Finding::isError).count() == 1);
}
@Test
public void invalidRewardStateUnequalToTrainedArchitectureState3() {
// given
RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
NNArchitectureSymbol trainedArchitectureWithDifferenDimension = getTrainedArchitectureWithStateDimensions(
Lists.newArrayList(2,4,3));
// when
uut.check(adapter, trainedArchitectureWithDifferenDimension);
List<Finding> findings = Log.getFindings();
assertTrue(findings.stream().filter(Finding::isError).count() == 1);
}
private RewardFunctionParameterAdapter getComponentWithNonScalarOutput() {
ComponentPortInformation componentPortInformation = new ComponentPortInformation(COMPONENT_NAME);
componentPortInformation.addAllInputs(getValidInputPortVariables());
List<PortVariable> outputs = Lists.newArrayList(PortVariable.multidimensionalVariableFrom(
"output", EmadlType.Q, PortDirection.OUTPUT, Lists.newArrayList(2,2)));
componentPortInformation.addAllOutputs(outputs);
return new RewardFunctionParameterAdapter(componentPortInformation);
}
private RewardFunctionParameterAdapter getComponentWithTwoQInputs() {
ComponentPortInformation componentPortInformation
= new ComponentPortInformation(COMPONENT_NAME);
List<PortVariable> inputs = Lists.newArrayList(STATE_PORT,
PortVariable.multidimensionalVariableFrom("input2", EmadlType.Q, PortDirection.INPUT,
Lists.newArrayList(2,3,2)));
componentPortInformation.addAllInputs(inputs);
componentPortInformation.addAllOutputs(getValidOutputPorts());
return new RewardFunctionParameterAdapter(componentPortInformation);
}
private RewardFunctionParameterAdapter getComponentWithTwoOutputs() {
ComponentPortInformation componentPortInformation
= new ComponentPortInformation(COMPONENT_NAME);
componentPortInformation.addAllInputs(getValidInputPortVariables());
List<PortVariable> outputs = getValidOutputPorts();
outputs.add(PortVariable.primitiveVariableFrom("output2", EmadlType.B, PortDirection.OUTPUT));
componentPortInformation.addAllOutputs(outputs);
return new RewardFunctionParameterAdapter(componentPortInformation);
}
private RewardFunctionParameterAdapter getComponentWithOneInput() {
ComponentPortInformation componentPortInformation
= new ComponentPortInformation(COMPONENT_NAME);
componentPortInformation.addAllInputs(Lists.newArrayList(STATE_PORT));
componentPortInformation.addAllOutputs(getValidOutputPorts());
return new RewardFunctionParameterAdapter(componentPortInformation);
}
private RewardFunctionParameterAdapter getValidRewardAdapter() {
ComponentPortInformation componentPortInformation
= new ComponentPortInformation(COMPONENT_NAME);
componentPortInformation.addAllInputs(getValidInputPortVariables());
componentPortInformation.addAllOutputs(getValidOutputPorts());
return new RewardFunctionParameterAdapter(componentPortInformation);
}
private List<PortVariable> getValidOutputPorts() {
return Lists.newArrayList(OUTPUT_PORT);
}
private List<PortVariable> getValidInputPortVariables() {
return Lists.newArrayList(STATE_PORT, TERMINAL_PORT);
}
private NNArchitectureSymbol getValidTrainedArchitecture() {
NNArchitectureSymbol nnArchitectureSymbol = mock(NNArchitectureSymbol.class);
final String stateInputName = "stateInput";
when(nnArchitectureSymbol.getInputs()).thenReturn(Lists.newArrayList(stateInputName));
when(nnArchitectureSymbol.getDimensions()).thenReturn(ImmutableMap.<String, List<Integer>>builder()
.put(stateInputName, STATE_DIMENSIONS)
.build());
return nnArchitectureSymbol;
}
private NNArchitectureSymbol getTrainedArchitectureWithStateDimensions(final List<Integer> dimensions) {
NNArchitectureSymbol nnArchitectureSymbol = mock(NNArchitectureSymbol.class);
final String stateInputName = "stateInput";
when(nnArchitectureSymbol.getInputs()).thenReturn(Lists.newArrayList(stateInputName));
when(nnArchitectureSymbol.getDimensions()).thenReturn(ImmutableMap.<String, List<Integer>>builder()
.put(stateInputName, dimensions)
.build());
return nnArchitectureSymbol;
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment