Commit 8ad4df06 authored by Nicola Gatto's avatar Nicola Gatto

Add state check and add tests for reward check

parent 3f811ad0
Pipeline #191050 passed with stages
in 4 minutes and 5 seconds
......@@ -151,7 +151,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
if (configuration.getTrainedArchitecture().isPresent()) {
generateRewardFunction(configuration.getTrainedArchitecture().get(),
configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
} else {
Log.error("No architecture model for the trained neural network but is required for " +
"reinforcement learning configuration.");
}
}
ftlContext.put("trainerName", trainerName);
......@@ -167,7 +174,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return fileContentMap;
}
private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
......@@ -200,7 +208,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
new FunctionParameterChecker().check(functionParameter, trainedArchitecture);
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
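/**
 * Checks that the ports of a reward function component fit the reinforcement learning setup:
 * the number of inputs and outputs, the presence of a state and a terminal input, the terminal
 * and output dimensions, and that the reward state dimension equals the state input dimension
 * of the trained architecture.
 */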
......@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
private String inputTerminalParameterName;
private String outputParameterName;
private RewardFunctionParameterAdapter rewardFunctionParameter;
private NNArchitectureSymbol trainedArchitecture;
public FunctionParameterChecker() {
}
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter) {
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter,
final NNArchitectureSymbol trainedArchitecture) {
assert rewardFunctionParameter != null;
assert trainedArchitecture != null;
this.rewardFunctionParameter = rewardFunctionParameter;
this.trainedArchitecture = trainedArchitecture;
retrieveParameterNames();
checkHasExactlyTwoInputs();
checkHasExactlyOneOutput();
checkHasStateAndTerminalInput();
checkInputStateDimension();
checkInputTerminalTypeAndDimension();
checkStateDimensionEqualsTrainedArchitectureState();
checkInputStateDimension();
checkOutputDimension();
}
/**
 * Reports a failure when the state input of the reward function does not have the same
 * dimensions as the state input of the trained architecture.
 */
private void checkStateDimensionEqualsTrainedArchitectureState() {
    final boolean stateDimensionsMatch = stateInputOfNNArchitectureIsEqualToRewardState();
    final String failureMessage =
            "State dimension of trained architecture is not equal to reward state dimensions.";
    failIfConditionFails(stateDimensionsMatch, failureMessage);
}
/**
 * Compares the dimensions of the reward function's state input port with the dimensions
 * registered for the single input of the trained architecture.
 *
 * @return {@code true} if the reward state port has a dimension and it equals the dimensions
 *         of the trained architecture's input; {@code false} otherwise
 */
private boolean stateInputOfNNArchitectureIsEqualToRewardState() {
    assert trainedArchitecture.getInputs().size() == 1: "Trained architecture is not a policy network.";

    // The only input of the trained architecture is treated as its state input.
    final String architectureStateName = trainedArchitecture.getInputs().get(0);
    final List<Integer> architectureStateDimensions =
            trainedArchitecture.getDimensions().get(architectureStateName);

    // No dimension known for the reward state port: cannot be equal.
    if (!rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).isPresent()) {
        return false;
    }
    return rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get()
            .equals(architectureStateDimensions);
}
private void checkHasExactlyTwoInputs() {
failIfConditionFails(functionHasTwoInputs(), "Reward function must have exactly two input parameters: "
+ "One input needs to represents the environment's state and another input needs to be a "
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortDirection;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Test;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Unit tests for {@link FunctionParameterChecker}.
 *
 * <p>Each test builds a reward function port description and a mocked trained architecture,
 * runs the checker and then inspects the error findings collected by {@link Log}.
 */
public class FunctionParameterCheckerTest {
    private static final List<Integer> STATE_DIMENSIONS = Lists.newArrayList(3,2,4);
    private static final PortVariable STATE_PORT = PortVariable.multidimensionalVariableFrom("input1", EmadlType.Q,
            PortDirection.INPUT, STATE_DIMENSIONS);
    private static final PortVariable TERMINAL_PORT = PortVariable.primitiveVariableFrom("input2", EmadlType.B,
            PortDirection.INPUT);
    private static final PortVariable OUTPUT_PORT = PortVariable.primitiveVariableFrom("output1", EmadlType.Q,
            PortDirection.OUTPUT);
    private static final String COMPONENT_NAME = "TestRewardComponent";

    /** Unit under test. */
    private final FunctionParameterChecker uut = new FunctionParameterChecker();

    /** Resets the collected findings and disables fail-quick so errors can be inspected. */
    @Before
    public void setup() {
        Log.getFindings().clear();
        Log.enableFailQuick(false);
    }

    @Test
    public void validReward() {
        // given
        RewardFunctionParameterAdapter adapter = getValidRewardAdapter();

        // when
        uut.check(adapter, getValidTrainedArchitecture());

        // then
        assertEquals("A valid reward function must not produce errors", 0, countErrorFindings());
    }

    @Test
    public void invalidRewardWithOneInput() {
        // given
        RewardFunctionParameterAdapter adapter = getComponentWithOneInput();

        // when
        uut.check(adapter, getValidTrainedArchitecture());

        // then
        assertTrue("Expected at least one error finding", countErrorFindings() > 0);
    }

    @Test
    public void invalidRewardWithTwoOutputs() {
        // given
        RewardFunctionParameterAdapter adapter = getComponentWithTwoOutputs();

        // when
        uut.check(adapter, getValidTrainedArchitecture());

        // then
        assertTrue("Expected at least one error finding", countErrorFindings() > 0);
    }

    @Test
    public void invalidRewardWithTerminalHasQType() {
        // given
        RewardFunctionParameterAdapter adapter = getComponentWithTwoQInputs();

        // when
        uut.check(adapter, getValidTrainedArchitecture());

        // then
        assertTrue("Expected at least one error finding", countErrorFindings() > 0);
    }

    @Test
    public void invalidRewardWithNonScalarOutput() {
        // given
        RewardFunctionParameterAdapter adapter = getComponentWithNonScalarOutput();

        // when
        uut.check(adapter, getValidTrainedArchitecture());

        // then
        assertEquals(1, countErrorFindings());
    }

    @Test
    public void invalidRewardStateUnequalToTrainedArchitectureState1() {
        // given: architecture state has a different rank than the reward state
        RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
        NNArchitectureSymbol trainedArchitectureWithDifferentDimension = getTrainedArchitectureWithStateDimensions(
                Lists.newArrayList(6));

        // when
        uut.check(adapter, trainedArchitectureWithDifferentDimension);

        // then
        assertEquals(1, countErrorFindings());
    }

    @Test
    public void invalidRewardStateUnequalToTrainedArchitectureState2() {
        // given: architecture state has fewer axes and a different size
        RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
        NNArchitectureSymbol trainedArchitectureWithDifferentDimension = getTrainedArchitectureWithStateDimensions(
                Lists.newArrayList(3, 8));

        // when
        uut.check(adapter, trainedArchitectureWithDifferentDimension);

        // then
        assertEquals(1, countErrorFindings());
    }

    @Test
    public void invalidRewardStateUnequalToTrainedArchitectureState3() {
        // given: same axis sizes as the reward state but in a different order
        RewardFunctionParameterAdapter adapter = getValidRewardAdapter();
        NNArchitectureSymbol trainedArchitectureWithDifferentDimension = getTrainedArchitectureWithStateDimensions(
                Lists.newArrayList(2,4,3));

        // when
        uut.check(adapter, trainedArchitectureWithDifferentDimension);

        // then
        assertEquals(1, countErrorFindings());
    }

    /** Counts the findings currently held by {@link Log} that are errors. */
    private static long countErrorFindings() {
        List<Finding> findings = Log.getFindings();
        return findings.stream().filter(Finding::isError).count();
    }

    /** Valid inputs, but the single output is a 2x2 matrix instead of a primitive. */
    private RewardFunctionParameterAdapter getComponentWithNonScalarOutput() {
        ComponentPortInformation componentPortInformation = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        List<PortVariable> outputs = Lists.newArrayList(PortVariable.multidimensionalVariableFrom(
                "output", EmadlType.Q, PortDirection.OUTPUT, Lists.newArrayList(2,2)));
        componentPortInformation.addAllOutputs(outputs);
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** Two multidimensional Q inputs, i.e. no primitive B port is available as terminal input. */
    private RewardFunctionParameterAdapter getComponentWithTwoQInputs() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        List<PortVariable> inputs = Lists.newArrayList(STATE_PORT,
                PortVariable.multidimensionalVariableFrom("input2", EmadlType.Q, PortDirection.INPUT,
                        Lists.newArrayList(2,3,2)));
        componentPortInformation.addAllInputs(inputs);
        componentPortInformation.addAllOutputs(getValidOutputPorts());
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** Valid inputs, but a second output port is added. */
    private RewardFunctionParameterAdapter getComponentWithTwoOutputs() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        List<PortVariable> outputs = getValidOutputPorts();
        outputs.add(PortVariable.primitiveVariableFrom("output2", EmadlType.B, PortDirection.OUTPUT));
        componentPortInformation.addAllOutputs(outputs);
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** Only the state input is present; the terminal input is missing. */
    private RewardFunctionParameterAdapter getComponentWithOneInput() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(Lists.newArrayList(STATE_PORT));
        componentPortInformation.addAllOutputs(getValidOutputPorts());
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** State and terminal input plus one primitive output. */
    private RewardFunctionParameterAdapter getValidRewardAdapter() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        componentPortInformation.addAllOutputs(getValidOutputPorts());
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** Returns a fresh mutable list so callers may append further ports. */
    private List<PortVariable> getValidOutputPorts() {
        return Lists.newArrayList(OUTPUT_PORT);
    }

    private List<PortVariable> getValidInputPortVariables() {
        return Lists.newArrayList(STATE_PORT, TERMINAL_PORT);
    }

    /** Mocked architecture whose single input has the same dimensions as {@link #STATE_PORT}. */
    private NNArchitectureSymbol getValidTrainedArchitecture() {
        return getTrainedArchitectureWithStateDimensions(STATE_DIMENSIONS);
    }

    /**
     * Mocks a trained architecture with exactly one input named {@code stateInput}.
     *
     * @param dimensions dimensions returned for that input
     */
    private NNArchitectureSymbol getTrainedArchitectureWithStateDimensions(final List<Integer> dimensions) {
        NNArchitectureSymbol nnArchitectureSymbol = mock(NNArchitectureSymbol.class);
        final String stateInputName = "stateInput";
        when(nnArchitectureSymbol.getInputs()).thenReturn(Lists.newArrayList(stateInputName));
        when(nnArchitectureSymbol.getDimensions()).thenReturn(ImmutableMap.<String, List<Integer>>builder()
                .put(stateInputName, dimensions)
                .build());
        return nnArchitectureSymbol;
    }
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment