Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
3d15a4de
Commit
3d15a4de
authored
Jul 20, 2019
by
Nicola Gatto
Browse files
Add critic network input cocos
parent
7ff494bf
Changes
4
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
View file @
3d15a4de
...
...
@@ -50,7 +50,8 @@ public class CNNTrainCocos {
public
static
void
checkCriticCocos
(
final
ConfigurationSymbol
configurationSymbol
)
{
CNNTrainConfigurationSymbolChecker
checker
=
new
CNNTrainConfigurationSymbolChecker
()
.
addCoCo
(
new
CheckCriticNetworkHasExactlyAOneDimensionalOutput
());
.
addCoCo
(
new
CheckCriticNetworkHasExactlyAOneDimensionalOutput
())
.
addCoCo
(
new
CheckCriticNetworkInputs
());
int
findings
=
Log
.
getFindings
().
size
();
checker
.
checkAll
(
configurationSymbol
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java
0 → 100644
View file @
3d15a4de
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.cnntrain.annotations.Range
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.List
;
import
java.util.stream.Collectors
;
/**
*
*/
public
class
CheckCriticNetworkInputs
implements
CNNTrainConfigurationSymbolCoCo
{
@Override
public
void
check
(
ConfigurationSymbol
configurationSymbol
)
{
if
(
configurationSymbol
.
getCriticNetwork
().
isPresent
())
{
if
(!
configurationSymbol
.
getTrainedArchitecture
().
isPresent
())
{
Log
.
error
(
"0"
+
ErrorCodes
.
MISSING_TRAINED_ARCHITECTURE
+
"No architecture found that is trained by this configuration."
,
configurationSymbol
.
getSourcePosition
());
}
NNArchitectureSymbol
trainedArchitecture
=
configurationSymbol
.
getTrainedArchitecture
().
get
();
NNArchitectureSymbol
criticNetwork
=
configurationSymbol
.
getCriticNetwork
().
get
();
if
(
trainedArchitecture
.
getInputs
().
size
()
!=
1
||
trainedArchitecture
.
getOutputs
().
size
()
!=
1
)
{
Log
.
error
(
"Malformed trained architecture"
);
}
if
(
trainedArchitecture
.
getInputs
().
size
()
!=
2
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
"Number of critic network inputs is wrong. Critic network has two inputs,"
+
"first needs to be a state input and second needs to be the action input."
);
}
final
String
stateInput
=
trainedArchitecture
.
getInputs
().
get
(
0
);
final
String
actionOutput
=
trainedArchitecture
.
getOutputs
().
get
(
0
);
final
List
<
Integer
>
stateDimensions
=
trainedArchitecture
.
getDimensions
().
get
(
stateInput
);
final
List
<
Integer
>
actionDimensions
=
trainedArchitecture
.
getDimensions
().
get
(
actionOutput
);
final
Range
stateRange
=
trainedArchitecture
.
getRanges
().
get
(
stateInput
);
final
Range
actionRange
=
trainedArchitecture
.
getRanges
().
get
(
actionOutput
);
final
String
stateType
=
trainedArchitecture
.
getTypes
().
get
(
stateInput
);
final
String
actionType
=
trainedArchitecture
.
getTypes
().
get
(
actionOutput
);
String
criticInput1
=
criticNetwork
.
getInputs
().
get
(
0
);
String
criticInput2
=
criticNetwork
.
getInputs
().
get
(
1
);
if
(
criticNetwork
.
getDimensions
().
get
(
criticInput1
).
equals
(
stateDimensions
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Dimensions of first input of critic architecture must be"
+
" equal to state's dimensions "
+
stateDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"{"
,
","
,
"}"
))
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getDimensions
().
get
(
criticInput2
).
equals
(
actionDimensions
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Dimensions of second input of critic architecture must be"
+
" equal to action's dimensions "
+
actionDimensions
.
stream
().
map
(
Object:
:
toString
).
collect
(
Collectors
.
joining
(
"{"
,
","
,
"}"
))
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getRanges
().
get
(
criticInput1
).
equals
(
stateRange
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Ranges of first input of critic architecture must be"
+
" equal to state's ranges "
+
stateRange
.
toString
()
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getRanges
().
get
(
criticInput2
).
equals
(
actionRange
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Ranges of second input of critic architecture must be"
+
" equal to action's ranges "
+
actionRange
.
toString
()
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getTypes
().
get
(
criticInput1
).
equals
(
stateType
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Type of first input of critic architecture must be"
+
" equal to state's types "
+
stateType
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
if
(
criticNetwork
.
getTypes
().
get
(
criticInput2
).
equals
(
actionType
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
CRITIC_NETWORK_ERROR
+
" Declared critic network is not a critic: Type of second input of critic architecture must be"
+
" equal to action's types "
+
stateType
+
"."
,
configurationSymbol
.
getSourcePosition
());
}
}
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
View file @
3d15a4de
...
...
@@ -20,6 +20,7 @@
*/
package
de.monticore.lang.monticar.cnntrain.annotations
;
import
java.util.Objects
;
import
java.util.Optional
;
public
class
Range
{
...
...
@@ -66,4 +67,28 @@ public class Range {
public
static
Range
withLowerInfinityLimit
(
double
upperLimit
)
{
return
new
Range
(
true
,
false
,
null
,
upperLimit
);
}
@Override
public
String
toString
()
{
final
String
lowerLimit
=
isLowerLimitInfinity
()
||
!
getLowerLimit
().
isPresent
()
?
"-oo"
:
getLowerLimit
().
get
().
toString
();
final
String
upperLimit
=
isUpperLimitInfinity
()
||
!
getUpperLimit
().
isPresent
()
?
"oo"
:
getUpperLimit
().
get
().
toString
();
return
"["
+
lowerLimit
+
", "
+
upperLimit
+
"]"
;
}
@Override
public
boolean
equals
(
Object
o
)
{
if
(
this
==
o
)
return
true
;
if
(!(
o
instanceof
Range
))
return
false
;
Range
range
=
(
Range
)
o
;
return
lowerLimitIsInfinity
==
range
.
lowerLimitIsInfinity
&&
upperLimitIsInfinity
==
range
.
upperLimitIsInfinity
&&
Objects
.
equals
(
lowerLimit
,
range
.
lowerLimit
)
&&
Objects
.
equals
(
upperLimit
,
range
.
upperLimit
);
}
@Override
public
int
hashCode
()
{
return
Objects
.
hash
(
lowerLimitIsInfinity
,
upperLimitIsInfinity
,
lowerLimit
,
upperLimit
);
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java
View file @
3d15a4de
...
...
@@ -32,4 +32,5 @@ public class ErrorCodes {
public
static
final
String
STRATEGY_NOT_APPLICABLE
=
"xC8857"
;
public
static
final
String
CONTRADICTING_PARAMETERS
=
"xC8858"
;
public
static
final
String
CRITIC_NETWORK_ERROR
=
"xC7100"
;
public
static
final
String
MISSING_TRAINED_ARCHITECTURE
=
"xC7101"
;
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment