Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
53fe8e9c
Commit
53fe8e9c
authored
Jul 23, 2019
by
Nicola Gatto
Browse files
Fix cocos for critic inputs
parent
944c9ed8
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java
View file @
53fe8e9c
...
...
@@ -48,12 +48,6 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
Log
.
error
(
"Malformed trained architecture"
);
}
if
(
trainedArchitecture
.
getInputs
().
size
()
!=
2
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
"Number of critic network inputs is wrong. Critic network has two inputs,"
+
"first needs to be a state input and second needs to be the action input."
);
}
final
String
stateInput
=
trainedArchitecture
.
getInputs
().
get
(
0
);
final
String
actionOutput
=
trainedArchitecture
.
getOutputs
().
get
(
0
);
final
List
<
Integer
>
stateDimensions
=
trainedArchitecture
.
getDimensions
().
get
(
stateInput
);
...
...
@@ -66,23 +60,29 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
String
criticInput1
=
criticNetwork
.
getInputs
().
get
(
0
);
String
criticInput2
=
criticNetwork
.
getInputs
().
get
(
1
);
if
(
criticNetwork
.
getDimensions
().
get
(
criticInput1
).
equals
(
stateDimensions
))
{
if
(
criticNetwork
.
getInputs
().
size
()
!=
2
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
"Number of critic network inputs is wrong. Critic network has two inputs,"
+
"first needs to be a state input and second needs to be the action input."
);
}
if
(!
criticNetwork
.
getDimensions
().
get
(
criticInput1
).
equals
(
stateDimensions
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Dimensions of first input of critic architecture must be"
+
" equal to state's dimensions "
+
stateDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"
{
"
,
"
,
"
,
"}"
))
+
stateDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"
,
"
,
"
{
"
,
"}"
))
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getDimensions
().
get
(
criticInput2
).
equals
(
actionDimensions
))
{
if
(
!
criticNetwork
.
getDimensions
().
get
(
criticInput2
).
equals
(
actionDimensions
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Dimensions of second input of critic architecture must be"
+
" equal to action's dimensions "
+
actionDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"
{
"
,
"
,
"
,
"}"
))
+
actionDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"
,
"
,
"
{
"
,
"}"
))
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getRanges
().
get
(
criticInput1
).
equals
(
stateRange
))
{
if
(
!
criticNetwork
.
getRanges
().
get
(
criticInput1
).
equals
(
stateRange
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Ranges of first input of critic architecture must be"
+
" equal to state's ranges "
...
...
@@ -90,7 +90,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getRanges
().
get
(
criticInput2
).
equals
(
actionRange
))
{
if
(
!
criticNetwork
.
getRanges
().
get
(
criticInput2
).
equals
(
actionRange
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Ranges of second input of critic architecture must be"
+
" equal to action's ranges "
...
...
@@ -98,7 +98,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getTypes
().
get
(
criticInput1
).
equals
(
stateType
))
{
if
(
!
criticNetwork
.
getTypes
().
get
(
criticInput1
).
equals
(
stateType
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Type of first input of critic architecture must be"
+
" equal to state's types "
...
...
@@ -106,7 +106,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getTypes
().
get
(
criticInput2
).
equals
(
actionType
))
{
if
(
!
criticNetwork
.
getTypes
().
get
(
criticInput2
).
equals
(
actionType
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Type of second input of critic architecture must be"
+
" equal to action's types "
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment