Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
languages
CNNTrainLang
Commits
907f1be8
Commit
907f1be8
authored
May 27, 2019
by
Evgeny Kusmenko
Browse files
Merge branch 'prepare-ddpg-algo' into 'master'
Prepare ddpg algorithm See merge request
!16
parents
9de665e2
2df512f2
Pipeline
#144443
passed with stages
in 7 minutes and 16 seconds
Changes
13
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
907f1be8
...
...
@@ -30,7 +30,7 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnn-train
</artifactId>
<version>
0.3.
0
-SNAPSHOT
</version>
<version>
0.3.
1
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
...
src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
View file @
907f1be8
...
...
@@ -98,6 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface
MultiParamValue
extends
ConfigValue
;
LearningMethodEntry
implements
ConfigEntry
=
name
:
"learning_method"
":"
value
:
LearningMethodValue
;
// Configuration entry of the form  rl_algorithm : <RLAlgorithmValue>
RLAlgorithmEntry
implements
ConfigEntry
=
name
:
"rl_algorithm"
":"
value
:
RLAlgorithmValue
;
NumEpisodesEntry
implements
ConfigEntry
=
name
:
"num_episodes"
":"
value
:
IntegerValue
;
DiscountFactorEntry
implements
ConfigEntry
=
name
:
"discount_factor"
":"
value
:
NumberValue
;
NumMaxStepsEntry
implements
ConfigEntry
=
name
:
"num_max_steps"
":"
value
:
IntegerValue
;
...
...
@@ -109,11 +110,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
AgentNameEntry
implements
ConfigEntry
=
name
:
"agent_name"
":"
value
:
StringValue
;
UseDoubleDQNEntry
implements
ConfigEntry
=
name
:
"use_double_dqn"
":"
value
:
BooleanValue
;
RewardFunctionEntry
implements
ConfigEntry
=
name
:
"reward_function"
":"
value
:
ComponentNameValue
;
// Configuration entry of the form  critic : <ComponentNameValue>
CriticNetworkEntry
implements
ConfigEntry
=
name
:
"critic"
":"
value
:
ComponentNameValue
;
ComponentNameValue
implements
ConfigValue
=
Name
(
"."
Name
)*;
LearningMethodValue
implements
ConfigValue
=
(
supervisedLearning
:
"supervised"
|
reinforcement
:
"reinforcement"
);
// Accepted values of an rl_algorithm entry: either "dqn-algorithm" or "ddpg-algorithm"
RLAlgorithmValue
implements
ConfigValue
=
(
dqn
:
"dqn-algorithm"
|
ddpg
:
"ddpg-algorithm"
);
interface
MultiParamConfigEntry
extends
ConfigEntry
;
//
Replay
Memory
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
View file @
907f1be8
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration
;
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
View file @
907f1be8
...
...
@@ -33,7 +33,8 @@ public class CNNTrainCocos {
.
addCoCo
(
new
CheckFixTargetNetworkRequiresInterval
())
.
addCoCo
(
new
CheckReinforcementRequiresEnvironment
())
.
addCoCo
(
new
CheckLearningParameterCombination
())
.
addCoCo
(
new
CheckRosEnvironmentRequiresRewardFunction
());
.
addCoCo
(
new
CheckRosEnvironmentRequiresRewardFunction
())
.
addCoCo
(
new
CheckDdpgRequiresCriticNetwork
());
}
public
static
void
checkAll
(
CNNTrainCompilationUnitSymbol
compilationUnit
){
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java
0 → 100644
View file @
907f1be8
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnntrain._cocos
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTCriticNetworkEntry
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry
;
import
de.monticore.lang.monticar.cnntrain.helper.ErrorCodes
;
import
de.se_rwth.commons.logging.Log
;
public
class
CheckDdpgRequiresCriticNetwork
implements
CNNTrainASTConfigurationCoCo
{
@Override
public
void
check
(
ASTConfiguration
node
)
{
boolean
isDdpg
=
node
.
getEntriesList
().
stream
()
.
anyMatch
(
e
->
e
instanceof
ASTRLAlgorithmEntry
&&
((
ASTRLAlgorithmEntry
)
e
).
getValue
().
isPresentDdpg
());
boolean
hasCriticEntry
=
node
.
getEntriesList
().
stream
()
.
anyMatch
(
e
->
((
e
instanceof
ASTCriticNetworkEntry
)
&&
!((
ASTCriticNetworkEntry
)
e
).
getValue
().
getNameList
().
isEmpty
()));
if
(
isDdpg
&&
!
hasCriticEntry
)
{
ASTRLAlgorithmEntry
algorithmEntry
=
node
.
getEntriesList
().
stream
()
.
filter
(
e
->
e
instanceof
ASTRLAlgorithmEntry
)
.
map
(
e
->
(
ASTRLAlgorithmEntry
)
e
)
.
findFirst
()
.
orElseThrow
(()
->
new
IllegalStateException
(
"ASTRLAlgorithmEntry entry must be available"
));
Log
.
error
(
"0"
+
ErrorCodes
.
REQUIRED_PARAMETER_MISSING
+
" DDPG learning algorithm requires critc"
+
" network entry"
,
algorithmEntry
.
get_SourcePositionStart
());
}
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java
View file @
907f1be8
...
...
@@ -61,6 +61,8 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
);
private
final
static
List
<
Class
>
ALLOWED_REINFORCEMENT_LEARNING
=
Lists
.
newArrayList
(
ASTTrainContextEntry
.
class
,
ASTRLAlgorithmEntry
.
class
,
ASTCriticNetworkEntry
.
class
,
ASTOptimizerEntry
.
class
,
ASTRewardFunctionEntry
.
class
,
ASTMinimumLearningRateEntry
.
class
,
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
View file @
907f1be8
...
...
@@ -277,6 +277,19 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
return
value
;
}
/**
 * Creates a {@link ValueSymbol} whose value is the list of name parts of the
 * given component name.
 *
 * @param astComponentNameValue the component name node to read the name list from
 * @return a new value symbol holding that name list
 */
private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) {
    ValueSymbol symbol = new ValueSymbol();
    List<String> nameParts = astComponentNameValue.getNameList();
    symbol.setValue(nameParts);
    return symbol;
}
/**
 * Creates a {@link ValueSymbol} whose value is the component name written as a
 * single string, with its name parts joined by {@code "."}.
 *
 * @param astComponentNameValue the component name node to read the name list from
 * @return a new value symbol holding the dot-separated name
 */
private ValueSymbol getValueSymbolForComponentNameAsString(ASTComponentNameValue astComponentNameValue) {
    ValueSymbol symbol = new ValueSymbol();
    String dottedName = String.join(".", astComponentNameValue.getNameList());
    symbol.setValue(dottedName);
    return symbol;
}
/**
 * Returns the value of the string literal carried by the given node.
 *
 * @param value the string value node
 * @return the result of {@code getStringLiteral().getValue()} on that node
 */
private String getStringFromStringValue(ASTStringValue value) {
    return value.getStringLiteral().getValue();
}
...
...
@@ -310,6 +323,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration
.
getEntryMap
().
put
(
node
.
getName
(),
entry
);
}
/**
 * Registers an entry symbol for an {@code rl_algorithm} entry. Its value is
 * {@link RLAlgorithm#DDPG} when the node's value has DDPG present and
 * {@link RLAlgorithm#DQN} otherwise.
 *
 * @param node the visited rl_algorithm entry
 */
@Override
public void visit(ASTRLAlgorithmEntry node) {
    EntrySymbol algorithmEntry = new EntrySymbol(node.getName());
    ValueSymbol algorithmValue = new ValueSymbol();
    algorithmValue.setValue(node.getValue().isPresentDdpg() ? RLAlgorithm.DDPG : RLAlgorithm.DQN);
    algorithmEntry.setValue(algorithmValue);

    addToScopeAndLinkWithNode(algorithmEntry, node);
    configuration.getEntryMap().put(node.getName(), algorithmEntry);
}
@Override
public
void
visit
(
ASTNumEpisodesEntry
node
)
{
EntrySymbol
entry
=
new
EntrySymbol
(
node
.
getName
());
...
...
@@ -390,6 +419,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration
.
getEntryMap
().
put
(
node
.
getName
(),
entry
);
}
/**
 * Registers an entry symbol for a {@code critic} entry. Its value is the
 * critic's component name as a dot-separated string.
 *
 * @param node the visited critic network entry
 */
@Override
public void visit(ASTCriticNetworkEntry node) {
    EntrySymbol criticEntry = new EntrySymbol(node.getName());
    ValueSymbol criticName = getValueSymbolForComponentNameAsString(node.getValue());
    criticEntry.setValue(criticName);

    addToScopeAndLinkWithNode(criticEntry, node);
    configuration.getEntryMap().put(node.getName(), criticEntry);
}
@Override
public
void
visit
(
ASTReplayMemoryEntry
node
)
{
processMultiParamConfigVisit
(
node
,
node
.
getValue
().
getName
());
...
...
src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
View file @
907f1be8
...
...
@@ -21,6 +21,7 @@
package
de.monticore.lang.monticar.cnntrain._symboltable
;
import
com.google.common.collect.Lists
;
import
de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture
;
import
de.monticore.symboltable.CommonScopeSpanningSymbol
;
import
java.util.*
;
...
...
@@ -30,12 +31,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private
Map
<
String
,
EntrySymbol
>
entryMap
=
new
HashMap
<>();
private
OptimizerSymbol
optimizer
;
private
RewardFunctionSymbol
rlRewardFunctionSymbol
;
private
TrainedArchitecture
trainedArchitecture
;
public
static
final
ConfigurationSymbolKind
KIND
=
new
ConfigurationSymbolKind
();
/**
 * Creates a configuration symbol with an empty name and kind {@link #KIND}.
 * The reward function symbol and the trained architecture start out unset.
 */
public ConfigurationSymbol() {
    super("", KIND);
    this.rlRewardFunctionSymbol = null;
    this.trainedArchitecture = null;
}
public
OptimizerSymbol
getOptimizer
()
{
...
...
@@ -54,6 +57,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return
Optional
.
ofNullable
(
this
.
rlRewardFunctionSymbol
);
}
/**
 * @return the trained architecture, or an empty optional when none has been set
 */
public Optional<TrainedArchitecture> getTrainedArchitecture() {
    return Optional.ofNullable(this.trainedArchitecture);
}
/**
 * Stores the trained architecture for this configuration.
 *
 * @param trainedArchitecture the architecture to store; {@code null} leaves it unset
 */
public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) {
    this.trainedArchitecture = trainedArchitecture;
}
/**
 * @return the map from entry name to entry symbol; this is the internal map
 *         itself, not a copy
 */
public Map<String, EntrySymbol> getEntryMap() {
    return this.entryMap;
}
...
...
@@ -66,4 +77,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return
this
.
entryMap
.
containsKey
(
"learning_method"
)
?
(
LearningMethod
)
this
.
entryMap
.
get
(
"learning_method"
).
getValue
().
getValue
()
:
LearningMethod
.
SUPERVISED
;
}
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
0 → 100644
View file @
907f1be8
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnntrain._symboltable
;
/**
 * Reinforcement learning algorithms that can be stored as the value of an
 * rl_algorithm entry. {@link #toString()} yields a lower-case label for each
 * constant.
 */
public enum RLAlgorithm {
    DQN("dqn"),
    DDPG("ddpg");

    /** Text returned by {@link #toString()}. */
    private final String label;

    RLAlgorithm(String label) {
        this.label = label;
    }

    /**
     * @return {@code "dqn"} for {@link #DQN} and {@code "ddpg"} for {@link #DDPG}
     */
    @Override
    public String toString() {
        return label;
    }
}
src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
0 → 100644
View file @
907f1be8
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnntrain.annotations
;
import
java.util.Optional
;
/**
 * Immutable numeric interval whose lower and upper bound can each be either a
 * finite value or unbounded ("infinity"). Instances are obtained through the
 * static {@code with...} factory methods.
 */
public class Range {
    private final boolean lowerLimitIsInfinity;
    private final boolean upperLimitIsInfinity;
    // A limit is null when the corresponding side is unbounded.
    private final Double lowerLimit;
    private final Double upperLimit;

    private Range(boolean lowerLimitIsInfinity, boolean upperLimitIsInfinity,
                  Double lowerLimit, Double upperLimit) {
        this.lowerLimit = lowerLimit;
        this.upperLimit = upperLimit;
        this.lowerLimitIsInfinity = lowerLimitIsInfinity;
        this.upperLimitIsInfinity = upperLimitIsInfinity;
    }

    /**
     * @param lowerLimit finite lower bound
     * @param upperLimit finite upper bound
     * @return a range bounded on both sides
     */
    public static Range withLimits(double lowerLimit, double upperLimit) {
        return new Range(false, false, lowerLimit, upperLimit);
    }

    /**
     * @return a range that is unbounded on both sides
     */
    public static Range withInfinityLimits() {
        return new Range(true, true, null, null);
    }

    /**
     * @param lowerLimit finite lower bound
     * @return a range with the given lower bound and no upper bound
     */
    public static Range withUpperInfinityLimit(double lowerLimit) {
        return new Range(false, true, lowerLimit, null);
    }

    /**
     * @param upperLimit finite upper bound
     * @return a range with the given upper bound and no lower bound
     */
    public static Range withLowerInfinityLimit(double upperLimit) {
        return new Range(true, false, null, upperLimit);
    }

    /**
     * @return the lower bound, or an empty optional when none is stored
     */
    public Optional<Double> getLowerLimit() {
        if (lowerLimit == null) {
            return Optional.<Double>empty();
        }
        return Optional.of(lowerLimit);
    }

    /**
     * @return the upper bound, or an empty optional when none is stored
     */
    public Optional<Double> getUpperLimit() {
        if (upperLimit == null) {
            return Optional.<Double>empty();
        }
        return Optional.of(upperLimit);
    }

    /**
     * @return {@code true} when the range has no lower bound
     */
    public boolean isLowerLimitInfinity() {
        return lowerLimitIsInfinity;
    }

    /**
     * @return {@code true} when the range has no upper bound
     */
    public boolean isUpperLimitInfinity() {
        return upperLimitIsInfinity;
    }
}
src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java
0 → 100644
View file @
907f1be8
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnntrain.annotations
;
import
java.util.List
;
import
java.util.Map
;
/**
 * Description of an architecture that can be attached to a training
 * configuration. Implementations expose the architecture's inputs and outputs
 * together with per-name dimension, range and type information.
 *
 * NOTE(review): the map keys are presumably the input/output names returned by
 * {@link #getInputs()} and {@link #getOutputs()} - confirm against implementations.
 */
public interface TrainedArchitecture {

    /** @return the inputs of the architecture */
    List<String> getInputs();

    /** @return the outputs of the architecture */
    List<String> getOutputs();

    /** @return a list of integer dimensions per name */
    Map<String, List<Integer>> getDimensions();

    /** @return a value {@link Range} per name */
    Map<String, Range> getRanges();

    /** @return a type, given as string, per name */
    Map<String, String> getTypes();
}
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
View file @
907f1be8
...
...
@@ -41,6 +41,7 @@ public class AllCoCoTest extends AbstractCoCoTest{
checkValid
(
"valid_tests"
,
"FullConfig2"
);
checkValid
(
"valid_tests"
,
"ReinforcementConfig"
);
checkValid
(
"valid_tests"
,
"ReinforcementConfig2"
);
checkValid
(
"valid_tests"
,
"DdpgConfig"
);
}
@Test
...
...
src/test/resources/valid_tests/DdpgConfig.cnnt
0 → 100644
View file @
907f1be8
configuration DdpgConfig {
learning_method : reinforcement
rl_algorithm : ddpg-algorithm
critic : path.to.component
environment : gym { name:"CartPole-v1" }
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment