Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
8ad4df06
Commit
8ad4df06
authored
Oct 03, 2019
by
Nicola Gatto
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add state check and add tests for reward check
parent
3f811ad0
Pipeline
#191050
passed with stages
in 4 minutes and 5 seconds
Changes
3
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
252 additions
and
5 deletions
+252
-5
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
.../lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+11
-3
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
...luongenerator/reinforcement/FunctionParameterChecker.java
+26
-2
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterCheckerTest.java
...generator/reinforcement/FunctionParameterCheckerTest.java
+215
-0
No files found.
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
8ad4df06
...
...
@@ -151,7 +151,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
// Generate Reward function if necessary
if
(
configuration
.
getRlRewardFunction
().
isPresent
())
{
generateRewardFunction
(
configuration
.
getRlRewardFunction
().
get
(),
Paths
.
get
(
rootProjectModelsDir
));
if
(
configuration
.
getTrainedArchitecture
().
isPresent
())
{
generateRewardFunction
(
configuration
.
getTrainedArchitecture
().
get
(),
configuration
.
getRlRewardFunction
().
get
(),
Paths
.
get
(
rootProjectModelsDir
));
}
else
{
Log
.
error
(
"No architecture model for the trained neural network but is required for "
+
"reinforcement learning configuration."
);
}
}
ftlContext
.
put
(
"trainerName"
,
trainerName
);
...
...
@@ -167,7 +174,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return
fileContentMap
;
}
private
void
generateRewardFunction
(
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
private
void
generateRewardFunction
(
NNArchitectureSymbol
trainedArchitecture
,
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
GeneratorPythonWrapperStandaloneApi
pythonWrapperApi
=
new
GeneratorPythonWrapperStandaloneApi
();
List
<
String
>
fullNameOfComponent
=
rewardFunctionSymbol
.
getRewardFunctionComponentName
();
...
...
@@ -200,7 +208,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
componentPortInformation
=
pythonWrapperApi
.
generate
(
emaSymbol
,
pythonWrapperOutputPath
);
}
RewardFunctionParameterAdapter
functionParameter
=
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
new
FunctionParameterChecker
().
check
(
functionParameter
);
new
FunctionParameterChecker
().
check
(
functionParameter
,
trainedArchitecture
);
rewardFunctionSymbol
.
setRewardFunctionParameter
(
functionParameter
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
View file @
8ad4df06
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.List
;
/**
*
*/
...
...
@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
// Name of an input port of the reward function; it is not assigned in the visible code.
// NOTE(review): presumably the terminal-flag input, filled in by retrieveParameterNames() — confirm.
private
String
inputTerminalParameterName
;
// Name of an output port of the reward function; it is not assigned in the visible code.
// NOTE(review): presumably filled in by retrieveParameterNames() — confirm.
private
String
outputParameterName
;
// Reward function adapter under inspection; assigned at the start of check().
private
RewardFunctionParameterAdapter
rewardFunctionParameter
;
// Architecture symbol of the trained network; assigned at the start of check() and
// read by stateInputOfNNArchitectureIsEqualToRewardState().
private
NNArchitectureSymbol
trainedArchitecture
;
/**
 * Creates a checker without any state; the objects to validate are handed over
 * on each call to check(...), which stores them in this instance's fields.
 */
public
FunctionParameterChecker
()
{
}
public
void
check
(
final
RewardFunctionParameterAdapter
rewardFunctionParameter
)
{
/**
 * Runs all reward function parameter checks for the given adapter against the
 * given trained architecture.
 *
 * <p>Both arguments are stored in this instance before the individual checks
 * run, so a checker instance must not be shared between concurrent calls.
 * Each check is executed exactly once; the state dimension check used to be
 * invoked twice, which could report the same problem two times.
 *
 * @param rewardFunctionParameter adapter describing the reward function's ports; must not be null
 * @param trainedArchitecture architecture symbol of the trained network; must not be null
 */
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter,
                  final NNArchitectureSymbol trainedArchitecture) {
    assert rewardFunctionParameter != null : "Reward function parameter must not be null.";
    assert trainedArchitecture != null : "Trained architecture must not be null.";

    this.rewardFunctionParameter = rewardFunctionParameter;
    this.trainedArchitecture = trainedArchitecture;

    // The port names have to be resolved first; the checks below refer to them.
    retrieveParameterNames();

    checkHasExactlyTwoInputs();
    checkHasExactlyOneOutput();
    checkHasStateAndTerminalInput();
    checkInputStateDimension();
    checkInputTerminalTypeAndDimension();
    checkStateDimensionEqualsTrainedArchitectureState();
    checkOutputDimension();
}
/**
 * Hands the result of the state dimension comparison between the reward
 * function and the trained architecture to failIfConditionFails, together
 * with the message to use when the comparison is false.
 */
private void checkStateDimensionEqualsTrainedArchitectureState() {
    final boolean stateDimensionsMatch = stateInputOfNNArchitectureIsEqualToRewardState();
    failIfConditionFails(
        stateDimensionsMatch,
        "State dimension of trained architecture is not equal to reward state dimensions.");
}
/**
 * Compares the dimensions of the trained architecture's single input with the
 * dimensions of the reward function's state input port.
 *
 * @return true only if the reward function reports dimensions for its state
 *         input port and they equal the dimensions registered for the
 *         architecture's input; false otherwise
 */
private boolean stateInputOfNNArchitectureIsEqualToRewardState() {
    assert trainedArchitecture.getInputs().size() == 1 : "Trained architecture is not a policy network.";

    // The single input of the architecture is taken as its state input.
    final String architectureInputName = trainedArchitecture.getInputs().get(0);
    final List<Integer> architectureDimensions =
        trainedArchitecture.getDimensions().get(architectureInputName);

    // Without known dimensions on the reward side there is nothing to compare.
    if (!rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).isPresent()) {
        return false;
    }
    return rewardFunctionParameter
        .getInputPortDimensionOfPort(inputStateParameterName)
        .get()
        .equals(architectureDimensions);
}
private
void
checkHasExactlyTwoInputs
()
{
failIfConditionFails
(
functionHasTwoInputs
(),
"Reward function must have exactly two input parameters: "
+
"One input needs to represents the environment's state and another input needs to be a "
...
...
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterCheckerTest.java
0 → 100644
View file @
8ad4df06
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
com.google.common.collect.ImmutableMap
;
import
com.google.common.collect.Lists
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortDirection
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable
;
import
de.se_rwth.commons.logging.Finding
;
import
de.se_rwth.commons.logging.Log
;
import
org.junit.Before
;
import
org.junit.Test
;
import
java.util.List
;
import
java.util.Map
;
import
static
org
.
junit
.
Assert
.*;
import
static
org
.
mockito
.
Mockito
.
mock
;
import
static
org
.
mockito
.
Mockito
.
when
;
public
class
FunctionParameterCheckerTest
{
private
static
final
List
<
Integer
>
STATE_DIMENSIONS
=
Lists
.
newArrayList
(
3
,
2
,
4
);
private
static
final
PortVariable
STATE_PORT
=
PortVariable
.
multidimensionalVariableFrom
(
"input1"
,
EmadlType
.
Q
,
PortDirection
.
INPUT
,
STATE_DIMENSIONS
);
private
static
final
PortVariable
TERMINAL_PORT
=
PortVariable
.
primitiveVariableFrom
(
"input2"
,
EmadlType
.
B
,
PortDirection
.
INPUT
);
private
static
final
PortVariable
OUTPUT_PORT
=
PortVariable
.
primitiveVariableFrom
(
"output1"
,
EmadlType
.
Q
,
PortDirection
.
OUTPUT
);
private
static
final
String
COMPONENT_NAME
=
"TestRewardComponent"
;
FunctionParameterChecker
uut
=
new
FunctionParameterChecker
();
@Before
public
void
setup
()
{
Log
.
getFindings
().
clear
();
Log
.
enableFailQuick
(
false
);
}
@Test
public
void
validReward
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertEquals
(
0
,
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
());
}
@Test
public
void
invalidRewardWithOneInput
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithOneInput
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithTwoOutputs
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithTwoOutputs
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithTerminalHasQType
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithTwoQInputs
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithNonScalarOutput
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithNonScalarOutput
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState1
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
6
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState2
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
3
,
8
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState3
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
2
,
4
,
3
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
private
RewardFunctionParameterAdapter
getComponentWithNonScalarOutput
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
List
<
PortVariable
>
outputs
=
Lists
.
newArrayList
(
PortVariable
.
multidimensionalVariableFrom
(
"output"
,
EmadlType
.
Q
,
PortDirection
.
OUTPUT
,
Lists
.
newArrayList
(
2
,
2
)));
componentPortInformation
.
addAllOutputs
(
outputs
);
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithTwoQInputs
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
List
<
PortVariable
>
inputs
=
Lists
.
newArrayList
(
STATE_PORT
,
PortVariable
.
multidimensionalVariableFrom
(
"input2"
,
EmadlType
.
Q
,
PortDirection
.
INPUT
,
Lists
.
newArrayList
(
2
,
3
,
2
)));
componentPortInformation
.
addAllInputs
(
inputs
);
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithTwoOutputs
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
List
<
PortVariable
>
outputs
=
getValidOutputPorts
();
outputs
.
add
(
PortVariable
.
primitiveVariableFrom
(
"output2"
,
EmadlType
.
B
,
PortDirection
.
OUTPUT
));
componentPortInformation
.
addAllOutputs
(
outputs
);
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithOneInput
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
Lists
.
newArrayList
(
STATE_PORT
));
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getValidRewardAdapter
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
List
<
PortVariable
>
getValidOutputPorts
()
{
return
Lists
.
newArrayList
(
OUTPUT_PORT
);
}
private
List
<
PortVariable
>
getValidInputPortVariables
()
{
return
Lists
.
newArrayList
(
STATE_PORT
,
TERMINAL_PORT
);
}
private
NNArchitectureSymbol
getValidTrainedArchitecture
()
{
NNArchitectureSymbol
nnArchitectureSymbol
=
mock
(
NNArchitectureSymbol
.
class
);
final
String
stateInputName
=
"stateInput"
;
when
(
nnArchitectureSymbol
.
getInputs
()).
thenReturn
(
Lists
.
newArrayList
(
stateInputName
));
when
(
nnArchitectureSymbol
.
getDimensions
()).
thenReturn
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
stateInputName
,
STATE_DIMENSIONS
)
.
build
());
return
nnArchitectureSymbol
;
}
private
NNArchitectureSymbol
getTrainedArchitectureWithStateDimensions
(
final
List
<
Integer
>
dimensions
)
{
NNArchitectureSymbol
nnArchitectureSymbol
=
mock
(
NNArchitectureSymbol
.
class
);
final
String
stateInputName
=
"stateInput"
;
when
(
nnArchitectureSymbol
.
getInputs
()).
thenReturn
(
Lists
.
newArrayList
(
stateInputName
));
when
(
nnArchitectureSymbol
.
getDimensions
()).
thenReturn
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
stateInputName
,
dimensions
)
.
build
());
return
nnArchitectureSymbol
;
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment