Aufgrund einer Wartung wird GitLab am 21.09. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 21.09. between 8:00 and 9:00 am.

Commit 35300ca2 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add tests for inter cocos and minor changes

parent dc20415c
......@@ -132,6 +132,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
......
......@@ -208,7 +208,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue;
interface RosEnvironmentEntry extends Entry;
RosEnvironmentValue implements EnvironmentValue = | name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentValue implements EnvironmentValue = name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
......
......@@ -39,7 +39,8 @@ public class CNNTrainCocos {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension())
.addCoCo(new CheckTrainedArchitectureHasVectorAction());
checker.checkAll(configurationSymbol);
}
......
......@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
......@@ -57,7 +58,8 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionVectorDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionVectorDimension,
configurationSymbol.getSourcePosition());
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
public class CheckTrainedArchitectureHasVectorAction implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getTrainedArchitecture().isPresent()
&& configurationSymbol.isReinforcementLearningMethod()) {
final NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
if (trainedArchitecture.getOutputs().size() == 1) {
final String actionName = trainedArchitecture.getOutputs().get(0);
final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionName);
if (actionDimensions.size() != 1) {
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Output of actor network must be a vector", configurationSymbol.getSourcePosition());
}
}
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain.cocos;
import de.monticore.lang.monticar.cnntrain._cocos.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.junit.Test;
public class InterCocoTest extends AbstractCoCoTest {
NNArchitecturerBuilder NNBuilder = new NNArchitecturerBuilder();
@Test
public void testValidTD3ActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testValidDDPGActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "DdpgConfig",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testInvalidTrainingArchitectureWithTwoInputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneInput();
NNArchitectureSymbol nnWithTwoInputs = NNBuilder.getTrainedArchitectureWithTwoInputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoInputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidTrainingArchitectureWithTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneOutput();
NNArchitectureSymbol nnWithTwoOutputs = NNBuilder.getTrainedArchitectureWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoOutputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension1() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim1", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(3, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension2() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim2", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(2, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension3() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim3", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidCriticHasNotOneDimensionalOutput() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithThreeDimensionalOutput();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateDimensions() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentDimensions = NNBuilder.getCriticWithDifferentStateDimensions();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentDimensions);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionDimensions() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentDimensions = NNBuilder.getCriticWithDifferentActionDimensions();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentDimensions);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateTypes() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentStateTypes = NNBuilder.getCriticWithDifferentStateTypes();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentStateTypes);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionTypes() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentActionTypes = NNBuilder.getCriticWithDifferentActionTypes();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentActionTypes);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateRanges() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentStateRanges = NNBuilder.getCriticWithDifferentStateRanges();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentStateRanges);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionRanges() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentActionRanges = NNBuilder.getCriticWithDifferentActionRanges();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentActionRanges);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidTrainedArchitectureWithMultidimensionalAction() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedArchitectureHasVectorAction();
NNArchitectureSymbol actorWithMultidimensionalAction = NNBuilder.getTrainedArchitectureWithMultidimensionalAction();
NNArchitectureSymbol criticWithMultidimensionalAction = NNBuilder.getCriticWithMultidimensionalAction();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
actorWithMultidimensionalAction, criticWithMultidimensionalAction);
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
private ConfigurationSymbol getConfigurationSymbolFrom(final String modelPath, final String model,
final NNArchitectureSymbol actorArchitecture, final NNArchitectureSymbol criticArchitecture) {
final ConfigurationSymbol configurationSymbol = getConfigurationSymbolByPath( modelPath, model);
configurationSymbol.setTrainedArchitecture(actorArchitecture);
configurationSymbol.setCriticNetwork(criticArchitecture);
return configurationSymbol;
}
private ConfigurationSymbol getConfigurationSymbolByPath(final String modelPath, final String model) {
return getCompilationUnitSymbol(modelPath, model).getConfiguration();
}
private enum CheckOption {
TRAINED_ARCHITECTURE_COCOS,
CRITIC_ARCHITECTURE_COCOS,
}
private void checkInvalidArchitecture(
final ConfigurationSymbol configurationSymbol,
final CNNTrainConfigurationSymbolCoCo cocoUUT,
final ExpectedErrorInfo expectedErrors,
final CheckOption checkOption) {
Log.getFindings().clear();
if (checkOption.equals(CheckOption.TRAINED_ARCHITECTURE_COCOS)) {
CNNTrainCocos.checkTrainedArchitectureCoCos(configurationSymbol);
} else {
CNNTrainCocos.checkCriticCocos(configurationSymbol);
}
expectedErrors.checkExpectedPresent(Log.getFindings(), "Got no findings when checking all "
+ "cocos. Did you forget to add the new coco to MontiArcCocos?");
Log.getFindings().clear();
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker().addCoCo(cocoUUT);
checker.checkAll(configurationSymbol);
expectedErrors.checkOnlyExpectedPresent(Log.getFindings(), "Got no findings when checking only "
+ "the given coco. Did you pass an empty coco checker?");
}
private void checkInvalidTrainedArchitecture(
final ConfigurationSymbol configurationSymbol,
final CNNTrainConfigurationSymbolCoCo cocoUUT,
ExpectedErrorInfo expectedErrors) {
checkInvalidArchitecture(configurationSymbol, cocoUUT, expectedErrors, CheckOption.TRAINED_ARCHITECTURE_COCOS);
}
private void checkInvalidCriticArchitecture(
final ConfigurationSymbol configurationSymbol,
final CNNTrainConfigurationSymbolCoCo cocoUUT,
ExpectedErrorInfo expectedErrors) {
checkInvalidArchitecture(configurationSymbol, cocoUUT, expectedErrors, CheckOption.CRITIC_ARCHITECTURE_COCOS);
}
private void checkValidTrainedArchitecture(final ConfigurationSymbol configurationSymbol) {
Log.getFindings().clear();
CNNTrainCocos.checkTrainedArchitectureCoCos(configurationSymbol);
new ExpectedErrorInfo().checkOnlyExpectedPresent(Log.getFindings());
}
private void checkValidCriticArchitecture(final ConfigurationSymbol configurationSymbol) {
Log.getFindings().clear();
CNNTrainCocos.checkCriticCocos(configurationSymbol);
new ExpectedErrorInfo().checkOnlyExpectedPresent(Log.getFindings());
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain.cocos;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import java.awt.*;
import java.util.List;
import java.util.Map;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class NNArchitecturerBuilder {
private static final String ACTOR_NN_NAME ="ActorNetwork";
private static final String ACTOR_STATE_NAME = "actorState";
private static final List<Integer> ACTOR_STATE_DIM = Lists.newArrayList(25);
private static final String ACTOR_STATE_TYPE = "Q";
private static final Range ACTOR_STATE_RANGE = Range.withInfinityLimits();
private static final String ACTOR_ACTION_NAME = "actorAction";
private static final List<Integer> ACTOR_ACTION_DIM = Lists.newArrayList(3);
private static final String ACTOR_ACTION_TYPE = "Q";
private static final Range ACTOR_ACTION_RANGE = Range.withLimits(0, 1);
private static final String CRITIC_NN_NAME = "CriticNetwork";
private static final String CRITIC_STATE_NAME = "criticState";
private static final String CRITIC_ACTION_NAME = "criticAction";
private static final String CRITIC_QVALUE_NAME = "criticQValue";
public NNArchitectureSymbol getCriticWithDifferentStateRanges() {
Map<String, Range> ranges = getValidCriticRanges();
ranges.put(CRITIC_STATE_NAME, Range.withLowerInfinityLimit(5.0));
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(),
getValidCriticDimensions(), getValidCriticTypes(), ranges);
}
public NNArchitectureSymbol getCriticWithDifferentActionRanges() {
Map<String, Range> ranges = getValidCriticRanges();
ranges.put(CRITIC_ACTION_NAME, Range.withLimits(-3.5, 3.5));
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(),
getValidCriticDimensions(), getValidCriticTypes(), ranges);
}
public NNArchitectureSymbol getCriticWithDifferentActionTypes() {
Map<String, String> types = getValidCriticTypes();
types.put(CRITIC_ACTION_NAME, "Z");
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(),
getValidCriticDimensions(), types, getValidCriticRanges());
}
public NNArchitectureSymbol getCriticWithDifferentStateTypes() {
Map<String, String> types = getValidCriticTypes();
types.put(CRITIC_STATE_NAME, "Z");
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(),
getValidCriticDimensions(), types, getValidCriticRanges());
}
public NNArchitectureSymbol getCriticWithDifferentActionDimensions() {
Map<String, List<Integer>> dimensions = getValidCriticDimensions();
dimensions.put(CRITIC_ACTION_NAME, Lists.newArrayList(28));
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(), dimensions,
getValidCriticTypes(), getValidCriticRanges());
}
public NNArchitectureSymbol getCriticWithDifferentStateDimensions() {
Map<String, List<Integer>> dimensions = getValidCriticDimensions();
dimensions.put(CRITIC_STATE_NAME, Lists.newArrayList(12));
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(), dimensions,
getValidCriticTypes(), getValidCriticRanges());
}
public NNArchitectureSymbol getCriticWithTwoOutputs() {
final String anySecondOutputName = "qvalue2";
List<String> outputNames = getValidCriticOutputs();
outputNames.add(anySecondOutputName);
Map<String, List<Integer>> dimensions = getValidCriticDimensions();
dimensions.put(anySecondOutputName, Lists.newArrayList(2));
Map<String, String> types = getValidCriticTypes();
types.put(anySecondOutputName, "Q");
Map<String, Range> ranges = getValidCriticRanges();
ranges.put(anySecondOutputName, Range.withInfinityLimits());
return getNNArchitectureSymbolFrom(ACTOR_NN_NAME, getValidCriticInputs(), outputNames, dimensions, types, ranges);
}
public NNArchitectureSymbol getCriticWithThreeDimensionalOutput() {
Map<String, List<Integer>> dimensions = getValidCriticDimensions();
dimensions.put(CRITIC_QVALUE_NAME, Lists.newArrayList(4));
return getNNArchitectureSymbolFrom(CRITIC_NN_NAME, getValidCriticInputs(), getValidCriticOutputs(),
dimensions, getValidCriticTypes(), getValidCriticRanges());
}
public NNArchitectureSymbol getTrainedArchitectureWithTwoOutputs() {
final String anySecondOutputName = "action2";
List<String> outputNames = getValidActorOutputs();
outputNames.add(anySecondOutputName);
Map<String, List<Integer>> dimensions = getValidActorDimensions();
dimensions.put(anySecondOutputName, Lists.newArrayList(2));
Map<String, String> types = getValidActorTypes();
types.put(anySecondOutputName, "Q");
Map<String, Range> ranges = getValidActorRanges();
ranges.put(anySecondOutputName, Range.withInfinityLimits());
return getNNArchitectureSymbolFrom(ACTOR_NN_NAME, getValidActorInputs(), outputNames, dimensions, types, ranges);
}
public NNArchitectureSymbol getTrainedArchitectureWithTwoInputs() {
final String anySecondInputName = "state2";
List<String> inputNames = getValidActorInputs();
inputNames.add(anySecondInputName);
Map<String, List<Integer>> dimensions = getValidActorDimensions();
dimensions.put(anySecondInputName, Lists.newArrayList(2));
Map<String, String> types = getValidActorTypes();
types.put(anySecondInputName, "Q");
Map<String, Range> ranges = getValidActorRanges();
ranges.put(anySecondInputName, Range.withInfinityLimits());
return getNNArchitectureSymbolFrom(ACTOR_NN_NAME, inputNames, getValidActorOutputs(), dimensions, types, ranges);
}
public NNArchitectureSymbol getNNArchitectureSymbolFrom(String name, List<String> inputs, List<String> outputs,
Map<String, List<Integer>> dimensions, Map<String, String> types, Map<String, Range> ranges)
{
NNArchitectureSymbol architectureSymbolMock = mock(NNArchitectureSymbol.class);
when(architectureSymbolMock.getName()).thenReturn(name);
when(architectureSymbolMock.getInputs()).thenReturn(inputs);
when(architectureSymbolMock.getOutputs()).thenReturn(outputs);
when(architectureSymbolMock.getDimensions()).thenReturn(dimensions);
when(architectureSymbolMock.getTypes()).thenReturn(types);
when(architectureSymbolMock.getRanges()).thenReturn(ranges);
return architectureSymbolMock;
}
public NNArchitectureSymbol getValidTrainedArchitecture() {
return getNNArchitectureSymbolFrom(ACTOR_NN_NAME,
getValidActorInputs(), getValidActorOutputs(), getValidActorDimensions(), getValidActorTypes(),
getValidActorRanges());
}
public Map<String, Range> getValidActorRanges() {
return Maps.newHashMap(ImmutableMap.<String, Range>builder()
.put(ACTOR_STATE_NAME, ACTOR_STATE_RANGE)
.put(ACTOR_ACTION_NAME, ACTOR_ACTION_RANGE)
.build());
}
public Map<String, List<Integer>> getValidActorDimensions() {
return Maps.newHashMap(ImmutableMap.<String, List<Integer>>builder()
.put(ACTOR_STATE_NAME, ACTOR_STATE_DIM)
.put(ACTOR_ACTION_NAME, ACTOR_ACTION_DIM)
.build());
}
public List<String> getValidActorInputs() {
return Lists.newArrayList(ACTOR_STATE_NAME);
}
public List<String> getValidActorOutputs() {
return Lists.newArrayList(ACTOR_ACTION_NAME);
}
public Map<String, String> getValidActorTypes() {
return Maps.newHashMap(ImmutableMap.<String, String>builder()
.put(ACTOR_STATE_NAME, ACTOR_STATE_TYPE)
.put(ACTOR_ACTION_NAME, ACTOR_ACTION_TYPE)
.build());
}