Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
35300ca2
Commit
35300ca2
authored
Sep 30, 2019
by
Nicola Gatto
Browse files
Add tests for inter cocos and minor changes
parent
dc20415c
Changes
10
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
35300ca2
...
...
@@ -132,6 +132,13 @@
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
org.mockito
</groupId>
<artifactId>
mockito-core
</artifactId>
<version>
1.10.19
</version>
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
ch.qos.logback
</groupId>
<artifactId>
logback-classic
</artifactId>
...
...
src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
View file @
35300ca2
...
...
@@ -208,7 +208,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
GymEnvironmentNameEntry
implements
GymEnvironmentEntry
=
name
:
"name"
":"
value
:
StringValue
;
interface
RosEnvironmentEntry
extends
Entry
;
RosEnvironmentValue
implements
EnvironmentValue
=
|
name
:
"ros_interface"
(
"{"
params
:
RosEnvironmentEntry
*
"}"
)?;
RosEnvironmentValue
implements
EnvironmentValue
=
name
:
"ros_interface"
(
"{"
params
:
RosEnvironmentEntry
*
"}"
)?;
RosEnvironmentStateTopicEntry
implements
RosEnvironmentEntry
=
name
:
"state_topic"
":"
value
:
StringValue
;
RosEnvironmentActionTopicEntry
implements
RosEnvironmentEntry
=
name
:
"action_topic"
":"
value
:
StringValue
;
RosEnvironmentResetTopicEntry
implements
RosEnvironmentEntry
=
name
:
"reset_topic"
":"
value
:
StringValue
;
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
View file @
35300ca2
...
...
@@ -39,7 +39,8 @@ public class CNNTrainCocos {
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
()
.
addCoCo
(
new
CheckTrainedRlNetworkHasExactlyOneInput
())
.
addCoCo
(
new
CheckTrainedRlNetworkHasExactlyOneOutput
())
.
addCoCo
(
new
CheckOUParameterDimensionEqualsActionDimension
());
.
addCoCo
(
new
CheckOUParameterDimensionEqualsActionDimension
())
.
addCoCo
(
new
CheckTrainedArchitectureHasVectorAction
());
checker
.
checkAll
(
configurationSymbol
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java
View file @
35300ca2
...
...
@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.Collection
;
...
...
@@ -57,7 +58,8 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
String
ouParameterName
)
{
final
int
ouParameterDimension
=
((
Collection
<?>)
strategyParameters
.
getParameter
(
ouParameterName
)).
size
();
if
(
ouParameterDimension
!=
actionVectorDimension
)
{
Log
.
error
(
"Vector parameter "
+
ouParameterName
+
" of parameter "
+
STRATEGY_OU
+
" must have"
Log
.
error
(
"0"
+
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
+
" Vector parameter "
+
ouParameterName
+
" of parameter "
+
STRATEGY_OU
+
" must have"
+
" the same dimensions as the action dimension of output "
+
outputNameOfTrainedArchitecture
+
" which is "
+
actionVectorDimension
,
configurationSymbol
.
getSourcePosition
());
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedArchitectureHasVectorAction.java
0 → 100644
View file @
35300ca2
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.List
;
public
class
CheckTrainedArchitectureHasVectorAction
implements
CNNTrainConfigurationSymbolCoCo
{
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
if
(
configurationSymbol
.
getTrainedArchitecture
().
isPresent
()
&&
configurationSymbol
.
isReinforcementLearningMethod
())
{
final
NNArchitectureSymbol
trainedArchitecture
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
if
(
trainedArchitecture
.
getOutputs
().
size
()
==
1
)
{
final
String
actionName
=
trainedArchitecture
.
getOutputs
().
get
(
0
);
final
List
<
Integer
>
actionDimensions
=
trainedArchitecture
.
getDimensions
().
get
(
actionName
);
if
(
actionDimensions
.
size
()
!=
1
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
+
" Output of actor network must be a vector"
,
configurationSymbol
.
getSourcePosition
());
}
}
}
}
}
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java
0 → 100644
View file @
35300ca2
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain.cocos
;
import
de.monticore.lang.monticar.cnntrain._cocos.*
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
org.junit.Test
;
public
class
InterCocoTest
extends
AbstractCoCoTest
{
NNArchitecturerBuilder
NNBuilder
=
new
NNArchitecturerBuilder
();
@Test
public
void
testValidTD3ActorCritic
()
{
// given
final
NNArchitectureSymbol
validActor
=
NNBuilder
.
getValidTrainedArchitecture
();
final
NNArchitectureSymbol
validCritic
=
NNBuilder
.
getValidCriticArchitecture
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
validActor
,
validCritic
);
// when
checkValidTrainedArchitecture
(
configurationSymbol
);
checkValidCriticArchitecture
(
configurationSymbol
);
}
@Test
public
void
testValidDDPGActorCritic
()
{
// given
final
NNArchitectureSymbol
validActor
=
NNBuilder
.
getValidTrainedArchitecture
();
final
NNArchitectureSymbol
validCritic
=
NNBuilder
.
getValidCriticArchitecture
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"DdpgConfig"
,
validActor
,
validCritic
);
// when
checkValidTrainedArchitecture
(
configurationSymbol
);
checkValidCriticArchitecture
(
configurationSymbol
);
}
@Test
public
void
testInvalidTrainingArchitectureWithTwoInputs
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckTrainedRlNetworkHasExactlyOneInput
();
NNArchitectureSymbol
nnWithTwoInputs
=
NNBuilder
.
getTrainedArchitectureWithTwoInputs
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
nnWithTwoInputs
,
NNBuilder
.
getValidCriticArchitecture
());
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidTrainingArchitectureWithTwoOutputs
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckTrainedRlNetworkHasExactlyOneOutput
();
NNArchitectureSymbol
nnWithTwoOutputs
=
NNBuilder
.
getTrainedArchitectureWithTwoOutputs
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
nnWithTwoOutputs
,
NNBuilder
.
getValidCriticArchitecture
());
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidActionDimensionUnequalToOUParameterDimension1
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckOUParameterDimensionEqualsActionDimension
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"invalid_cocos_tests"
,
"UnequalOUDim1"
,
NNBuilder
.
getValidTrainedArchitecture
(),
NNBuilder
.
getValidCriticArchitecture
());
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
3
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidActionDimensionUnequalToOUParameterDimension2
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckOUParameterDimensionEqualsActionDimension
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"invalid_cocos_tests"
,
"UnequalOUDim2"
,
NNBuilder
.
getValidTrainedArchitecture
(),
NNBuilder
.
getValidCriticArchitecture
());
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
2
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidActionDimensionUnequalToOUParameterDimension3
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckOUParameterDimensionEqualsActionDimension
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"invalid_cocos_tests"
,
"UnequalOUDim3"
,
NNBuilder
.
getValidTrainedArchitecture
(),
NNBuilder
.
getValidCriticArchitecture
());
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
@Test
public
void
testInvalidCriticHasNotOneDimensionalOutput
()
{
//given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkHasExactlyAOneDimensionalOutput
();
NNArchitectureSymbol
criticWithThreeDimensionalOutput
=
NNBuilder
.
getCriticWithThreeDimensionalOutput
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithThreeDimensionalOutput
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasTwoOutputs
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkHasExactlyAOneDimensionalOutput
();
NNArchitectureSymbol
criticWithThreeDimensionalOutput
=
NNBuilder
.
getCriticWithTwoOutputs
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithThreeDimensionalOutput
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentStateDimensions
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentDimensions
=
NNBuilder
.
getCriticWithDifferentStateDimensions
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentDimensions
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentActionDimensions
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentDimensions
=
NNBuilder
.
getCriticWithDifferentActionDimensions
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentDimensions
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentStateTypes
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentStateTypes
=
NNBuilder
.
getCriticWithDifferentStateTypes
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentStateTypes
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentActionTypes
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentActionTypes
=
NNBuilder
.
getCriticWithDifferentActionTypes
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentActionTypes
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentStateRanges
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentStateRanges
=
NNBuilder
.
getCriticWithDifferentStateRanges
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentStateRanges
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidCriticHasDifferentActionRanges
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckCriticNetworkInputs
();
NNArchitectureSymbol
criticWithDifferentActionRanges
=
NNBuilder
.
getCriticWithDifferentActionRanges
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
NNBuilder
.
getValidTrainedArchitecture
(),
criticWithDifferentActionRanges
);
// when
checkInvalidCriticArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
CRITIC_NETWORK_ERROR
));
}
@Test
public
void
testInvalidTrainedArchitectureWithMultidimensionalAction
()
{
// given
CNNTrainConfigurationSymbolCoCo
cocoUUT
=
new
CheckTrainedArchitectureHasVectorAction
();
NNArchitectureSymbol
actorWithMultidimensionalAction
=
NNBuilder
.
getTrainedArchitectureWithMultidimensionalAction
();
NNArchitectureSymbol
criticWithMultidimensionalAction
=
NNBuilder
.
getCriticWithMultidimensionalAction
();
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolFrom
(
"valid_tests"
,
"TD3Config"
,
actorWithMultidimensionalAction
,
criticWithMultidimensionalAction
);
// when
checkInvalidTrainedArchitecture
(
configurationSymbol
,
cocoUUT
,
new
ExpectedErrorInfo
(
1
,
ErrorCodes
.
TRAINED_ARCHITECTURE_ERROR
));
}
private
ConfigurationSymbol
getConfigurationSymbolFrom
(
final
String
modelPath
,
final
String
model
,
final
NNArchitectureSymbol
actorArchitecture
,
final
NNArchitectureSymbol
criticArchitecture
)
{
final
ConfigurationSymbol
configurationSymbol
=
getConfigurationSymbolByPath
(
modelPath
,
model
);
configurationSymbol
.
setTrainedArchitecture
(
actorArchitecture
);
configurationSymbol
.
setCriticNetwork
(
criticArchitecture
);
return
configurationSymbol
;
}
private
ConfigurationSymbol
getConfigurationSymbolByPath
(
final
String
modelPath
,
final
String
model
)
{
return
getCompilationUnitSymbol
(
modelPath
,
model
).
getConfiguration
();
}
private
enum
CheckOption
{
TRAINED_ARCHITECTURE_COCOS
,
CRITIC_ARCHITECTURE_COCOS
,
}
private
void
checkInvalidArchitecture
(
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
final
ExpectedErrorInfo
expectedErrors
,
final
CheckOption
checkOption
)
{
Log
.
getFindings
().
clear
();
if
(
checkOption
.
equals
(
CheckOption
.
TRAINED_ARCHITECTURE_COCOS
))
{
CNNTrainCocos
.
checkTrainedArchitectureCoCos
(
configurationSymbol
);
}
else
{
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
}
expectedErrors
.
checkExpectedPresent
(
Log
.
getFindings
(),
"Got no findings when checking all "
+
"cocos. Did you forget to add the new coco to MontiArcCocos?"
);
Log
.
getFindings
().
clear
();
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
().
addCoCo
(
cocoUUT
);
checker
.
checkAll
(
configurationSymbol
);
expectedErrors
.
checkOnlyExpectedPresent
(
Log
.
getFindings
(),
"Got no findings when checking only "
+
"the given coco. Did you pass an empty coco checker?"
);
}
private
void
checkInvalidTrainedArchitecture
(
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
ExpectedErrorInfo
expectedErrors
)
{
checkInvalidArchitecture
(
configurationSymbol
,
cocoUUT
,
expectedErrors
,
CheckOption
.
TRAINED_ARCHITECTURE_COCOS
);
}
private
void
checkInvalidCriticArchitecture
(
final
ConfigurationSymbol
configurationSymbol
,
final
CNNTrainConfigurationSymbolCoCo
cocoUUT
,
ExpectedErrorInfo
expectedErrors
)
{
checkInvalidArchitecture
(
configurationSymbol
,
cocoUUT
,
expectedErrors
,
CheckOption
.
CRITIC_ARCHITECTURE_COCOS
);
}
private
void
checkValidTrainedArchitecture
(
final
ConfigurationSymbol
configurationSymbol
)
{
Log
.
getFindings
().
clear
();
CNNTrainCocos
.
checkTrainedArchitectureCoCos
(
configurationSymbol
);
new
ExpectedErrorInfo
().
checkOnlyExpectedPresent
(
Log
.
getFindings
());
}
private
void
checkValidCriticArchitecture
(
final
ConfigurationSymbol
configurationSymbol
)
{
Log
.
getFindings
().
clear
();
CNNTrainCocos
.
checkCriticCocos
(
configurationSymbol
);
new
ExpectedErrorInfo
().
checkOnlyExpectedPresent
(
Log
.
getFindings
());
}
}
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java
0 → 100644
View file @
35300ca2
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package
de.monticore.lang.monticar.cnntrain.cocos
;
import
com.google.common.collect.ImmutableMap
;
import
com.google.common.collect.Lists
;
import
com.google.common.collect.Maps
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.annotations.Range
;
import
java.awt.*
;
import
java.util.List
;
import
java.util.Map
;
import
static
org
.
mockito
.
Mockito
.
mock
;
import
static
org
.
mockito
.
Mockito
.
when
;
public
class
NNArchitecturerBuilder
{
private
static
final
String
ACTOR_NN_NAME
=
"ActorNetwork"
;
private
static
final
String
ACTOR_STATE_NAME
=
"actorState"
;
private
static
final
List
<
Integer
>
ACTOR_STATE_DIM
=
Lists
.
newArrayList
(
25
);
private
static
final
String
ACTOR_STATE_TYPE
=
"Q"
;
private
static
final
Range
ACTOR_STATE_RANGE
=
Range
.
withInfinityLimits
();
private
static
final
String
ACTOR_ACTION_NAME
=
"actorAction"
;
private
static
final
List
<
Integer
>
ACTOR_ACTION_DIM
=
Lists
.
newArrayList
(
3
);
private
static
final
String
ACTOR_ACTION_TYPE
=
"Q"
;
private
static
final
Range
ACTOR_ACTION_RANGE
=
Range
.
withLimits
(
0
,
1
);
private
static
final
String
CRITIC_NN_NAME
=
"CriticNetwork"
;
private
static
final
String
CRITIC_STATE_NAME
=
"criticState"
;
private
static
final
String
CRITIC_ACTION_NAME
=
"criticAction"
;
private
static
final
String
CRITIC_QVALUE_NAME
=
"criticQValue"
;
public
NNArchitectureSymbol
getCriticWithDifferentStateRanges
()
{
Map
<
String
,
Range
>
ranges
=
getValidCriticRanges
();
ranges
.
put
(
CRITIC_STATE_NAME
,
Range
.
withLowerInfinityLimit
(
5.0
));
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
getValidCriticDimensions
(),
getValidCriticTypes
(),
ranges
);
}
public
NNArchitectureSymbol
getCriticWithDifferentActionRanges
()
{
Map
<
String
,
Range
>
ranges
=
getValidCriticRanges
();
ranges
.
put
(
CRITIC_ACTION_NAME
,
Range
.
withLimits
(-
3.5
,
3.5
));
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
getValidCriticDimensions
(),
getValidCriticTypes
(),
ranges
);
}
public
NNArchitectureSymbol
getCriticWithDifferentActionTypes
()
{
Map
<
String
,
String
>
types
=
getValidCriticTypes
();
types
.
put
(
CRITIC_ACTION_NAME
,
"Z"
);
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
getValidCriticDimensions
(),
types
,
getValidCriticRanges
());
}
public
NNArchitectureSymbol
getCriticWithDifferentStateTypes
()
{
Map
<
String
,
String
>
types
=
getValidCriticTypes
();
types
.
put
(
CRITIC_STATE_NAME
,
"Z"
);
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
getValidCriticDimensions
(),
types
,
getValidCriticRanges
());
}
public
NNArchitectureSymbol
getCriticWithDifferentActionDimensions
()
{
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidCriticDimensions
();
dimensions
.
put
(
CRITIC_ACTION_NAME
,
Lists
.
newArrayList
(
28
));
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
dimensions
,
getValidCriticTypes
(),
getValidCriticRanges
());
}
public
NNArchitectureSymbol
getCriticWithDifferentStateDimensions
()
{
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidCriticDimensions
();
dimensions
.
put
(
CRITIC_STATE_NAME
,
Lists
.
newArrayList
(
12
));
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
dimensions
,
getValidCriticTypes
(),
getValidCriticRanges
());
}
public
NNArchitectureSymbol
getCriticWithTwoOutputs
()
{
final
String
anySecondOutputName
=
"qvalue2"
;
List
<
String
>
outputNames
=
getValidCriticOutputs
();
outputNames
.
add
(
anySecondOutputName
);
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidCriticDimensions
();
dimensions
.
put
(
anySecondOutputName
,
Lists
.
newArrayList
(
2
));
Map
<
String
,
String
>
types
=
getValidCriticTypes
();
types
.
put
(
anySecondOutputName
,
"Q"
);
Map
<
String
,
Range
>
ranges
=
getValidCriticRanges
();
ranges
.
put
(
anySecondOutputName
,
Range
.
withInfinityLimits
());
return
getNNArchitectureSymbolFrom
(
ACTOR_NN_NAME
,
getValidCriticInputs
(),
outputNames
,
dimensions
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getCriticWithThreeDimensionalOutput
()
{
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidCriticDimensions
();
dimensions
.
put
(
CRITIC_QVALUE_NAME
,
Lists
.
newArrayList
(
4
));
return
getNNArchitectureSymbolFrom
(
CRITIC_NN_NAME
,
getValidCriticInputs
(),
getValidCriticOutputs
(),
dimensions
,
getValidCriticTypes
(),
getValidCriticRanges
());
}
public
NNArchitectureSymbol
getTrainedArchitectureWithTwoOutputs
()
{
final
String
anySecondOutputName
=
"action2"
;
List
<
String
>
outputNames
=
getValidActorOutputs
();
outputNames
.
add
(
anySecondOutputName
);
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidActorDimensions
();
dimensions
.
put
(
anySecondOutputName
,
Lists
.
newArrayList
(
2
));
Map
<
String
,
String
>
types
=
getValidActorTypes
();
types
.
put
(
anySecondOutputName
,
"Q"
);
Map
<
String
,
Range
>
ranges
=
getValidActorRanges
();
ranges
.
put
(
anySecondOutputName
,
Range
.
withInfinityLimits
());
return
getNNArchitectureSymbolFrom
(
ACTOR_NN_NAME
,
getValidActorInputs
(),
outputNames
,
dimensions
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getTrainedArchitectureWithTwoInputs
()
{
final
String
anySecondInputName
=
"state2"
;
List
<
String
>
inputNames
=
getValidActorInputs
();
inputNames
.
add
(
anySecondInputName
);
Map
<
String
,
List
<
Integer
>>
dimensions
=
getValidActorDimensions
();
dimensions
.
put
(
anySecondInputName
,
Lists
.
newArrayList
(
2
));
Map
<
String
,
String
>
types
=
getValidActorTypes
();
types
.
put
(
anySecondInputName
,
"Q"
);
Map
<
String
,
Range
>
ranges
=
getValidActorRanges
();
ranges
.
put
(
anySecondInputName
,
Range
.
withInfinityLimits
());
return
getNNArchitectureSymbolFrom
(
ACTOR_NN_NAME
,
inputNames
,
getValidActorOutputs
(),
dimensions
,
types
,
ranges
);
}
public
NNArchitectureSymbol
getNNArchitectureSymbolFrom
(
String
name
,
List
<
String
>
inputs
,
List
<
String
>
outputs
,
Map
<
String
,
List
<
Integer
>>
dimensions
,
Map
<
String
,
String
>
types
,
Map
<
String
,
Range
>
ranges
)
{
NNArchitectureSymbol
architectureSymbolMock
=
mock
(
NNArchitectureSymbol
.
class
);
when
(
architectureSymbolMock
.
getName
()).
thenReturn
(
name
);
when
(
architectureSymbolMock
.
getInputs
()).
thenReturn
(
inputs
);
when
(
architectureSymbolMock
.
getOutputs
()).
thenReturn
(
outputs
);
when
(
architectureSymbolMock
.
getDimensions
()).
thenReturn
(
dimensions
);
when
(
architectureSymbolMock
.
getTypes
()).
thenReturn
(
types
);
when
(
architectureSymbolMock
.
getRanges
()).
thenReturn
(
ranges
);
return
architectureSymbolMock
;
}
public
NNArchitectureSymbol
getValidTrainedArchitecture
()
{
return
getNNArchitectureSymbolFrom
(
ACTOR_NN_NAME
,
getValidActorInputs
(),
getValidActorOutputs
(),
getValidActorDimensions
(),
getValidActorTypes
(),
getValidActorRanges
());
}
public
Map
<
String
,
Range
>
getValidActorRanges
()
{
return
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
Range
>
builder
()
.
put
(
ACTOR_STATE_NAME
,
ACTOR_STATE_RANGE
)
.
put
(
ACTOR_ACTION_NAME
,
ACTOR_ACTION_RANGE
)
.
build
());
}
public
Map
<
String
,
List
<
Integer
>>
getValidActorDimensions
()
{
return
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
ACTOR_STATE_NAME
,
ACTOR_STATE_DIM
)
.
put
(
ACTOR_ACTION_NAME
,
ACTOR_ACTION_DIM
)
.
build
());
}
public
List
<
String
>
getValidActorInputs
()
{
return
Lists
.
newArrayList
(
ACTOR_STATE_NAME
);
}
public
List
<
String
>
getValidActorOutputs
()
{
return
Lists
.
newArrayList
(
ACTOR_ACTION_NAME
);
}
public
Map
<
String
,
String
>
getValidActorTypes
()
{
return
Maps
.
newHashMap
(
ImmutableMap
.<
String
,
String
>
builder
()
.
put
(
ACTOR_STATE_NAME
,
ACTOR_STATE_TYPE
)
.
put
(
ACTOR_ACTION_NAME
,
ACTOR_ACTION_TYPE
)
.
build
());
}
public
List
<
String
>
getValidCriticInputs
()
{
return
Lists
.
newArrayList
(
CRITIC_STATE_NAME
,
CRITIC_ACTION_NAME
);