Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
cdc5657a
Commit
cdc5657a
authored
Jul 09, 2019
by
Nicola Gatto
Browse files
Add TD3 reinforcement learning parameter
parent
b5b9cdaa
Pipeline
#158410
passed with stages
in 8 minutes and 46 seconds
Changes
9
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
View file @
cdc5657a
...
...
@@ -145,7 +145,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
LearningMethodValue
implements
ConfigValue
=
(
supervisedLearning
:
"supervised"
|
reinforcement
:
"reinforcement"
);
RLAlgorithmValue
implements
ConfigValue
=
(
dqn
:
"dqn-algorithm"
|
ddpg
:
"ddpg-algorithm"
);
RLAlgorithmValue
implements
ConfigValue
=
(
dqn
:
"dqn-algorithm"
|
ddpg
:
"ddpg-algorithm"
|
tdThree
:
"td3-algorithm"
);
interface
MultiParamConfigEntry
extends
ConfigEntry
;
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
View file @
cdc5657a
...
...
@@ -41,8 +41,16 @@ class ASTConfigurationUtils {
e
->
(
e
instanceof
ASTRLAlgorithmEntry
)
&&
((
ASTRLAlgorithmEntry
)
e
).
getValue
().
isPresentDdpg
());
}
static
boolean
isTd3Algorithm
(
final
ASTConfiguration
configuration
)
{
return
isReinforcementLearning
(
configuration
)
&&
configuration
.
getEntriesList
().
stream
().
anyMatch
(
e
->
(
e
instanceof
ASTRLAlgorithmEntry
)
&&
((
ASTRLAlgorithmEntry
)
e
).
getValue
().
isPresentTdThree
());
}
static
boolean
isDqnAlgorithm
(
final
ASTConfiguration
configuration
)
{
return
isReinforcementLearning
(
configuration
)
&&
!
isDdpgAlgorithm
(
configuration
);
return
isReinforcementLearning
(
configuration
)
&&
!
isDdpgAlgorithm
(
configuration
)
&&
!
isTd3Algorithm
(
configuration
);
}
static
boolean
hasEntry
(
final
ASTConfiguration
configuration
,
final
Class
<?
extends
ASTConfigEntry
>
entryClazz
)
{
...
...
@@ -84,4 +92,18 @@ class ASTConfigurationUtils {
}
return
false
;
}
static
boolean
isActorCriticAlgorithm
(
final
ASTConfiguration
node
)
{
return
isDdpgAlgorithm
(
node
)
||
isTd3Algorithm
(
node
);
}
static
boolean
hasCriticEntry
(
final
ASTConfiguration
node
)
{
return
node
.
getEntriesList
().
stream
()
.
anyMatch
(
e
->
((
e
instanceof
ASTCriticNetworkEntry
)
&&
!((
ASTCriticNetworkEntry
)
e
).
getValue
().
getNameList
().
isEmpty
()));
}
public
static
boolean
isContinuousAlgorithm
(
final
ASTConfiguration
node
)
{
return
isDdpgAlgorithm
(
node
)
||
isTd3Algorithm
(
node
);
}
}
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
View file @
cdc5657a
...
...
@@ -34,7 +34,7 @@ public class CNNTrainCocos {
.
addCoCo
(
new
CheckReinforcementRequiresEnvironment
())
.
addCoCo
(
new
CheckLearningParameterCombination
())
.
addCoCo
(
new
CheckRosEnvironmentRequiresRewardFunction
())
.
addCoCo
(
new
Check
Ddpg
RequiresCriticNetwork
())
.
addCoCo
(
new
Check
ActorCritic
RequiresCriticNetwork
())
.
addCoCo
(
new
CheckRlAlgorithmParameter
())
.
addCoCo
(
new
CheckDiscreteRLAlgorithmUsesDiscreteStrategy
())
.
addCoCo
(
new
CheckContinuousRLAlgorithmUsesContinuousStrategy
())
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java
View file @
cdc5657a
...
...
@@ -35,7 +35,7 @@ public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrai
@Override
public
void
check
(
ASTConfiguration
node
)
{
if
(
ASTConfigurationUtils
.
is
Ddpg
Algorithm
(
node
)
if
(
ASTConfigurationUtils
.
is
Continuous
Algorithm
(
node
)
&&
ASTConfigurationUtils
.
hasStrategy
(
node
)
&&
ASTConfigurationUtils
.
getStrategyMethod
(
node
).
isPresent
())
{
final
String
usedStrategy
=
ASTConfigurationUtils
.
getStrategyMethod
(
node
).
get
();
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java
View file @
cdc5657a
...
...
@@ -20,6 +20,7 @@
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTEntry
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry
;
import
de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm
;
...
...
@@ -29,61 +30,58 @@ import de.se_rwth.commons.logging.Log;
public
class
CheckRlAlgorithmParameter
implements
CNNTrainASTEntryCoCo
{
private
final
ParameterAlgorithmMapping
parameterAlgorithmMapping
;
boolean
algorithmKnown
;
private
boolean
isDqn
=
true
;
private
boolean
isDdpg
=
true
;
private
boolean
isTd3
=
true
;
RLAlgorithm
algorithm
;
public
CheckRlAlgorithmParameter
()
{
parameterAlgorithmMapping
=
new
ParameterAlgorithmMapping
();
algorithmKnown
=
false
;
algorithm
=
null
;
}
@Override
public
void
check
(
ASTEntry
node
)
{
final
boolean
isDdpgParameter
=
parameterAlgorithmMapping
.
is
Ddp
gParameter
(
node
.
getClass
())
;
final
boolean
isDqnParameter
=
parameterAlgorithmMapping
.
isDqnParameter
(
node
.
getClass
())
;
if
(!
parameterAlgorithmMapping
.
is
ReinforcementLearnin
gParameter
(
node
.
getClass
())
)
{
return
;
}
if
(
node
instanceof
ASTRLAlgorithmEntry
)
{
ASTRLAlgorithmEntry
algorithmEntry
=
(
ASTRLAlgorithmEntry
)
node
;
if
(
algorithmEntry
.
getValue
().
isPresentDdpg
())
{
setAlgorithmToDdpg
(
node
);
logWrongParameterIfCheckFails
(
isDdpg
,
node
);
isTd3
=
false
;
isDqn
=
false
;
}
else
if
(
algorithmEntry
.
getValue
().
isPresentTdThree
())
{
logWrongParameterIfCheckFails
(
isTd3
,
node
);
isDdpg
=
false
;
isDqn
=
false
;
}
else
{
setAlgorithmToDqn
(
node
);
logWrongParameterIfCheckFails
(
isDqn
,
node
);
isDdpg
=
false
;
isTd3
=
false
;
}
}
else
{
if
(
isDdpgParameter
&&
!
isDqnParameter
)
{
setAlgorithmToDdpg
(
node
);
}
else
if
(!
isDdpgParameter
&&
isDqnParameter
)
{
setAlgorithmToDqn
(
node
);
final
boolean
isDdpgParameter
=
parameterAlgorithmMapping
.
isDdpgParameter
(
node
.
getClass
());
final
boolean
isDqnParameter
=
parameterAlgorithmMapping
.
isDqnParameter
(
node
.
getClass
());
final
boolean
isTd3Parameter
=
parameterAlgorithmMapping
.
isTd3Parameter
(
node
.
getClass
());
if
(!
isDdpgParameter
)
{
isDdpg
=
false
;
}
if
(!
isTd3Parameter
)
{
isTd3
=
false
;
}
if
(!
isDqnParameter
)
{
isDqn
=
false
;
}
}
logWrongParameterIfCheckFails
(
isDqn
||
isTd3
||
isDdpg
,
node
);
}
private
void
logErrorIfAlgorithmIsDqn
(
final
ASTEntry
node
)
{
if
(
algorithmKnown
&&
algorithm
.
equals
(
RLAlgorithm
.
DQN
))
{
Log
.
error
(
"0"
+
ErrorCodes
.
UNSUPPORTED_PARAMETER
+
" DDPG Parameter "
+
node
.
getName
()
+
" used but algorithm is "
+
algorithm
+
"."
,
node
.
get_SourcePositionStart
());
}
}
private
void
setAlgorithmToDdpg
(
final
ASTEntry
node
)
{
logErrorIfAlgorithmIsDqn
(
node
);
algorithmKnown
=
true
;
algorithm
=
RLAlgorithm
.
DDPG
;
}
private
void
setAlgorithmToDqn
(
final
ASTEntry
node
)
{
logErrorIfAlgorithmIsDdpg
(
node
);
algorithmKnown
=
true
;
algorithm
=
RLAlgorithm
.
DQN
;
}
private
void
logErrorIfAlgorithmIsDdpg
(
final
ASTEntry
node
)
{
if
(
algorithmKnown
&&
algorithm
.
equals
(
RLAlgorithm
.
DDPG
))
{
private
void
logWrongParameterIfCheckFails
(
final
boolean
condition
,
final
ASTEntry
node
)
{
if
(!
condition
)
{
Log
.
error
(
"0"
+
ErrorCodes
.
UNSUPPORTED_PARAMETER
+
" DQN
Parameter "
+
node
.
getName
()
+
" used but
algorithm is "
+
algorithm
+
"
."
,
+
"
Parameter "
+
node
.
getName
()
+
" used but
parameter is not for chosen
algorithm."
,
node
.
get_SourcePositionStart
());
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
View file @
cdc5657a
...
...
@@ -108,6 +108,8 @@ class ParameterAlgorithmMapping {
ASTStrategyOUSigma
.
class
);
private
static
final
List
<
Class
>
EXCLUSIVE_TD3_PARAMETERS
=
Lists
.
newArrayList
(
EXCLUSIVE_DDPG_PARAMETERS
);
ParameterAlgorithmMapping
()
{
}
...
...
@@ -136,12 +138,19 @@ class ParameterAlgorithmMapping {
||
EXCLUSIVE_DDPG_PARAMETERS
.
contains
(
entryClazz
);
}
boolean
isTd3Parameter
(
Class
<?
extends
ASTEntry
>
entryClazz
)
{
return
GENERAL_PARAMETERS
.
contains
(
entryClazz
)
||
GENERAL_REINFORCEMENT_PARAMETERS
.
contains
(
entryClazz
)
||
EXCLUSIVE_TD3_PARAMETERS
.
contains
(
entryClazz
);
}
List
<
Class
>
getAllReinforcementParameters
()
{
return
ImmutableList
.<
Class
>
builder
()
.
addAll
(
GENERAL_PARAMETERS
)
.
addAll
(
GENERAL_REINFORCEMENT_PARAMETERS
)
.
addAll
(
EXCLUSIVE_DQN_PARAMETERS
)
.
addAll
(
EXCLUSIVE_DDPG_PARAMETERS
)
.
addAll
(
EXCLUSIVE_TD3_PARAMETERS
)
.
build
();
}
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
View file @
cdc5657a
...
...
@@ -351,6 +351,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
if
(
node
.
getValue
().
isPresentDdpg
())
{
value
.
setValue
(
RLAlgorithm
.
DDPG
);
}
else
if
(
node
.
getValue
().
isPresentTdThree
())
{
value
.
setValue
(
RLAlgorithm
.
TD3
);
}
else
{
value
.
setValue
(
RLAlgorithm
.
DQN
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
View file @
cdc5657a
...
...
@@ -32,5 +32,11 @@ public enum RLAlgorithm {
public
String
toString
()
{
return
"ddpg"
;
}
},
TD3
{
@Override
public
String
toString
()
{
return
"td3"
;
}
}
}
}
\ No newline at end of file
src/test/resources/valid_tests/TD3Config.cnnt
View file @
cdc5657a
configuration TD3Config {
learning_method : reinforcement
rl_algorithm :
ddpg
-algorithm
rl_algorithm :
td3
-algorithm
critic : path.to.component
environment : gym { name:"CartPole-v1" }
soft_target_update_rate: 0.001
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment