Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNArch2Gluon
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Package Registry
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
f9465a60
Commit
f9465a60
authored
May 16, 2019
by
Nicola Gatto
Committed by
Evgeny Kusmenko
May 16, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Implement reinforcement learning
parent
2dd5aae6
Changes
51
Show whitespace changes
Inline
Side-by-side
Showing
51 changed files
with
5279 additions
and
57 deletions
+5279
-57
.gitignore
.gitignore
+1
-0
.gitlab-ci.yml
.gitlab-ci.yml
+9
-3
pom.xml
pom.xml
+16
-2
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
...e/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
+4
-2
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
.../lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+161
-7
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
...nnarch/gluongenerator/ReinforcementConfigurationData.java
+216
-0
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
...luongenerator/reinforcement/FunctionParameterChecker.java
+109
-0
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionParameterAdapter.java
...nerator/reinforcement/RewardFunctionParameterAdapter.java
+136
-0
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionSourceGenerator.java
...enerator/reinforcement/RewardFunctionSourceGenerator.java
+8
-0
src/main/resources/templates/gluon/CNNCreator.ftl
src/main/resources/templates/gluon/CNNCreator.ftl
+4
-1
src/main/resources/templates/gluon/CNNPredictor.ftl
src/main/resources/templates/gluon/CNNPredictor.ftl
+8
-6
src/main/resources/templates/gluon/reinforcement/StartTrainer.ftl
.../resources/templates/gluon/reinforcement/StartTrainer.ftl
+3
-0
src/main/resources/templates/gluon/reinforcement/Trainer.ftl
src/main/resources/templates/gluon/reinforcement/Trainer.ftl
+182
-0
src/main/resources/templates/gluon/reinforcement/agent/ActionPolicy.ftl
...rces/templates/gluon/reinforcement/agent/ActionPolicy.ftl
+73
-0
src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
...n/resources/templates/gluon/reinforcement/agent/Agent.ftl
+506
-0
src/main/resources/templates/gluon/reinforcement/agent/ReplayMemory.ftl
...rces/templates/gluon/reinforcement/agent/ReplayMemory.ftl
+155
-0
src/main/resources/templates/gluon/reinforcement/environment/Environment.ftl
...templates/gluon/reinforcement/environment/Environment.ftl
+217
-0
src/main/resources/templates/gluon/reinforcement/util/Util.ftl
...ain/resources/templates/gluon/reinforcement/util/Util.ftl
+140
-0
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
.../lang/monticar/cnnarch/gluongenerator/GenerationTest.java
+35
-3
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/IntegrationPythonWrapperTest.java
.../cnnarch/gluongenerator/IntegrationPythonWrapperTest.java
+53
-0
src/test/resources/target_code/CNNCreator_Alexnet.py
src/test/resources/target_code/CNNCreator_Alexnet.py
+4
-1
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
...esources/target_code/CNNCreator_CifarClassifierNetwork.py
+4
-1
src/test/resources/target_code/CNNCreator_VGG16.py
src/test/resources/target_code/CNNCreator_VGG16.py
+4
-1
src/test/resources/target_code/CNNPredictor_Alexnet.h
src/test/resources/target_code/CNNPredictor_Alexnet.h
+5
-10
src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
...sources/target_code/CNNPredictor_CifarClassifierNetwork.h
+5
-10
src/test/resources/target_code/CNNPredictor_VGG16.h
src/test/resources/target_code/CNNPredictor_VGG16.h
+5
-10
src/test/resources/target_code/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
...e/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
+101
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/__init__.py
...e/ReinforcementConfig1/reinforcement_learning/__init__.py
+0
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/action_policy.py
...nforcementConfig1/reinforcement_learning/action_policy.py
+73
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/agent.py
...code/ReinforcementConfig1/reinforcement_learning/agent.py
+506
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/environment.py
...einforcementConfig1/reinforcement_learning/environment.py
+144
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/replay_memory.py
...nforcementConfig1/reinforcement_learning/replay_memory.py
+155
-0
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/util.py
..._code/ReinforcementConfig1/reinforcement_learning/util.py
+140
-0
src/test/resources/target_code/ReinforcementConfig1/start_training.sh
...ources/target_code/ReinforcementConfig1/start_training.sh
+2
-0
src/test/resources/target_code/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
...e/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
+106
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/__init__.py
...e/ReinforcementConfig2/reinforcement_learning/__init__.py
+0
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/action_policy.py
...nforcementConfig2/reinforcement_learning/action_policy.py
+73
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/agent.py
...code/ReinforcementConfig2/reinforcement_learning/agent.py
+506
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/environment.py
...einforcementConfig2/reinforcement_learning/environment.py
+71
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/replay_memory.py
...nforcementConfig2/reinforcement_learning/replay_memory.py
+155
-0
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/util.py
..._code/ReinforcementConfig2/reinforcement_learning/util.py
+140
-0
src/test/resources/target_code/ReinforcementConfig2/start_training.sh
...ources/target_code/ReinforcementConfig2/start_training.sh
+2
-0
src/test/resources/target_code/reinforcement_learning/__init__.py
.../resources/target_code/reinforcement_learning/__init__.py
+0
-0
src/test/resources/target_code/reinforcement_learning/action_policy.py
...urces/target_code/reinforcement_learning/action_policy.py
+73
-0
src/test/resources/target_code/reinforcement_learning/agent.py
...est/resources/target_code/reinforcement_learning/agent.py
+503
-0
src/test/resources/target_code/reinforcement_learning/environment.py
...sources/target_code/reinforcement_learning/environment.py
+67
-0
src/test/resources/target_code/reinforcement_learning/replay_memory.py
...urces/target_code/reinforcement_learning/replay_memory.py
+155
-0
src/test/resources/target_code/reinforcement_learning/util.py
...test/resources/target_code/reinforcement_learning/util.py
+134
-0
src/test/resources/valid_tests/ReinforcementConfig1.cnnt
src/test/resources/valid_tests/ReinforcementConfig1.cnnt
+45
-0
src/test/resources/valid_tests/ReinforcementConfig2.cnnt
src/test/resources/valid_tests/ReinforcementConfig2.cnnt
+50
-0
src/test/resources/valid_tests/reward/RewardFunction.emadl
src/test/resources/valid_tests/reward/RewardFunction.emadl
+15
-0
No files found.
.gitignore
View file @
f9465a60
...
...
@@ -8,3 +8,4 @@ nppBackup
*.iml
.vscode
.gitlab-ci.yml
View file @
f9465a60
...
...
@@ -27,7 +27,7 @@ masterJobLinux:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
-Dtest=\!Integration*
-
cat target/site/jacoco/index.html
-
mvn package sonar:sonar -s settings.xml
only
:
...
...
@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows
:
stage
:
windows
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-Dtest=\!Integration*
tags
:
-
Windows10
...
...
@@ -44,7 +44,13 @@ BranchJobLinux:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-Dtest=\!Integration*
-
cat target/site/jacoco/index.html
except
:
-
master
PythonWrapperIntegrationTest
:
stage
:
linux
image
:
registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
pom.xml
View file @
f9465a60
...
...
@@ -8,7 +8,7 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-gluon-generator
</artifactId>
<version>
0.
1.6
</version>
<version>
0.
2.0-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
...
@@ -16,9 +16,10 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.0-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.
2.6
</CNNTrain.version>
<CNNTrain.version>
0.
3.0-SNAPSHOT
</CNNTrain.version>
<CNNArch2MXNet.version>
0.2.14-SNAPSHOT
</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>
0.0.1
</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>
18.0
</guava.version>
...
...
@@ -100,6 +101,12 @@
<version>
${embedded-montiarc-math-opt-generator}
</version>
</dependency>
<dependency>
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
embedded-montiarc-emadl-pythonwrapper-generator
</artifactId>
<version>
${EMADL2PythonWrapper.version}
</version>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
...
...
@@ -109,6 +116,13 @@
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
org.mockito
</groupId>
<artifactId>
mockito-core
</artifactId>
<version>
1.10.19
</version>
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
ch.qos.logback
</groupId>
<artifactId>
logback-classic
</artifactId>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
View file @
f9465a60
...
...
@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp
=
archTc
.
process
(
"CNNNet"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
if
(
architecture
.
getDataPath
()
!=
null
)
{
temp
=
archTc
.
process
(
"CNNDataLoader"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
}
temp
=
archTc
.
process
(
"CNNCreator"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
f9465a60
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
com.google.common.collect.Maps
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod
;
import
de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.File
;
import
java.io.IOException
;
import
java.nio.charset.Charset
;
import
java.nio.charset.StandardCharsets
;
import
java.nio.file.Files
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.*
;
public
class
CNNTrain2Gluon
extends
CNNTrain2MxNet
{
public
CNNTrain2Gluon
()
{
private
static
final
String
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
=
"reinforcement_learning"
;
private
final
RewardFunctionSourceGenerator
rewardFunctionSourceGenerator
;
private
String
rootProjectModelsDir
;
/**
 * Returns the models directory of the root project.
 *
 * @return the directory set via {@link #setRootProjectModelsDir(String)},
 *         or an empty {@link Optional} when none has been set
 */
public Optional<String> getRootProjectModelsDir() {
    final String modelsDir = this.rootProjectModelsDir;
    return Optional.ofNullable(modelsDir);
}
/**
 * Sets the models directory of the root project.
 *
 * @param rootProjectModelsDir the directory to store; {@code null} leaves the
 *        value unset from the point of view of {@link #getRootProjectModelsDir()}
 */
public void setRootProjectModelsDir(String rootProjectModelsDir) {
    this.rootProjectModelsDir = rootProjectModelsDir;
}
/**
 * Creates a generator that delegates reward function source generation to the
 * given collaborator.
 *
 * @param rewardFunctionSourceGenerator generator invoked when a configuration
 *        declares a reward function
 */
public CNNTrain2Gluon(RewardFunctionSourceGenerator rewardFunctionSourceGenerator) {
    // The superclass no-arg constructor is invoked implicitly.
    this.rewardFunctionSourceGenerator = rewardFunctionSourceGenerator;
}
/**
 * Resolves the configuration symbol via the superclass and, for reinforcement
 * learning configurations that declare a reward function, generates that
 * reward function before returning.
 *
 * @param modelsDirPath directory containing the models
 * @param rootModelName name of the root model to resolve
 * @return the resolved configuration symbol
 */
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
    final ConfigurationSymbol symbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);

    // A reward function is only generated for reinforcement learning configurations.
    if (!symbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)) {
        return symbol;
    }
    if (symbol.getRlRewardFunction().isPresent()) {
        generateRewardFunction(symbol.getRlRewardFunction().get(), modelsDirPath);
    }
    return symbol;
}
/**
 * Generates all trainer files for the given model and writes them to the
 * generation target path.
 *
 * <p>An {@link IOException} while writing is reported through {@code Log.error}
 * rather than propagated.
 *
 * @param modelsDirPath directory containing the models
 * @param rootModelName name of the root model to generate for
 */
@Override
public void generate(Path modelsDirPath, String rootModelName) {
    ConfigurationSymbol configuration = this.getConfigurationSymbol(modelsDirPath, rootModelName);
    Map<String, String> fileContents = this.generateStrings(configuration);

    GeneratorCPP genCPP = new GeneratorCPP();
    genCPP.setGenerationTargetPath(this.getGenerationTargetPath());

    try {
        // Typed iteration replaces the raw Iterator and the unchecked casts.
        for (Map.Entry<String, String> file : fileContents.entrySet()) {
            genCPP.generateFile(new FileContent(file.getValue(), file.getKey()));
        }
    } catch (IOException e) {
        // Separate the fixed text from the cause so the message stays readable.
        Log.error("CNNTrainer file could not be generated: " + e.getMessage());
    }
}
@Override
public
Map
<
String
,
String
>
generateStrings
(
ConfigurationSymbol
configuration
)
{
TemplateConfiguration
templateConfiguration
=
new
GluonTemplateConfiguration
();
ConfigurationData
configData
=
new
ConfigurationData
(
configuration
,
getInstanceName
());
ReinforcementConfigurationData
configData
=
new
Reinforcement
ConfigurationData
(
configuration
,
getInstanceName
());
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
configDataList
.
add
(
configData
);
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
Map
<
String
,
Object
>
ftlContext
=
Maps
.
newHashMap
();
ftlContext
.
put
(
"configurations"
,
configDataList
);
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
if
(
configData
.
isSupervisedLearning
())
{
String
cnnTrainTemplateContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
fileContentMap
.
put
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
cnnTrainTemplateContent
);
String
cnnSupervisedTrainerContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNSupervisedTrainer.ftl"
);
fileContentMap
.
put
(
"supervised_trainer.py"
,
cnnSupervisedTrainerContent
);
}
else
if
(
configData
.
isReinforcementLearning
())
{
final
String
trainerName
=
"CNNTrainer_"
+
getInstanceName
();
ftlContext
.
put
(
"trainerName"
,
trainerName
);
Map
<
String
,
String
>
rlFrameworkContentMap
=
constructReinforcementLearningFramework
(
templateConfiguration
,
ftlContext
);
fileContentMap
.
putAll
(
rlFrameworkContentMap
);
final
String
reinforcementTrainerContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/Trainer.ftl"
);
fileContentMap
.
put
(
trainerName
+
".py"
,
reinforcementTrainerContent
);
final
String
startTrainerScriptContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/StartTrainer.ftl"
);
fileContentMap
.
put
(
"start_training.sh"
,
startTrainerScriptContent
);
}
return
fileContentMap
;
}
/**
 * Generates the sources and the Python wrapper for the reward function
 * component, validates the resulting port information and stores it on the
 * reward function symbol.
 *
 * <p>When no root project models directory has been set, {@code modelsDirPath}
 * is used and stored as that directory.
 *
 * @param rewardFunctionSymbol symbol naming the reward function component
 * @param modelsDirPath fallback models directory
 */
private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
    final GeneratorPythonWrapperStandaloneApi wrapperApi = new GeneratorPythonWrapperStandaloneApi();

    final List<String> componentNameParts = rewardFunctionSymbol.getRewardFunctionComponentName();
    final String rootModel = String.join(".", componentNameParts);
    final String rewardOutputDir = Paths.get(this.getGenerationTargetPath(), "reward").toString();

    if (!getRootProjectModelsDir().isPresent()) {
        setRootProjectModelsDir(modelsDirPath.toString());
    }
    final String modelsRoot = getRootProjectModelsDir().get();

    rewardFunctionSourceGenerator.generate(modelsRoot, rootModel, rewardOutputDir);

    // The generated header is named after the component with '_' as separator.
    final String headerFileName = String.join("_", componentNameParts) + ".h";
    fixArmadilloEmamGenerationOfFile(Paths.get(rewardOutputDir, headerFileName));

    final String wrapperOutputDir = Paths.get(rewardOutputDir, "pylib").toString();
    Log.info("Generating reward function python wrapper...", "CNNTrain2Gluon");

    final ComponentPortInformation portInformation;
    if (!wrapperApi.checkIfPythonModuleBuildAvailable()) {
        Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
        portInformation = wrapperApi.generate(modelsRoot, rootModel, wrapperOutputDir);
    } else {
        final String moduleOutputDir =
                Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
        portInformation = wrapperApi.generateAndTryBuilding(
                modelsRoot, rootModel, wrapperOutputDir, moduleOutputDir);
    }

    final RewardFunctionParameterAdapter adapter = new RewardFunctionParameterAdapter(portInformation);
    new FunctionParameterChecker().check(adapter);
    rewardFunctionSymbol.setRewardFunctionParameter(adapter);
}
/**
 * Rewrites the given generated file in place, replacing every occurrence of
 * {@code armadillo.h} with {@code armadillo}.
 *
 * <p>Does nothing when the file does not exist. I/O failures are logged as a
 * warning and otherwise ignored, so generation continues.
 *
 * @param pathToBrokenFile path of the generated file to patch
 */
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile) {
    final File brokenFile = pathToBrokenFile.toFile();
    if (!brokenFile.exists()) {
        return;
    }

    try {
        final Charset charset = StandardCharsets.UTF_8;
        final String originalContent = new String(Files.readAllBytes(pathToBrokenFile), charset);
        final String fixedContent = originalContent.replace("armadillo.h", "armadillo");
        // Encode with the same charset used for reading; the no-arg getBytes()
        // uses the platform default and can corrupt non-ASCII content.
        Files.write(pathToBrokenFile, fixedContent.getBytes(charset));
    } catch (IOException e) {
        Log.warn("Cannot fix wrong armadillo library in " + pathToBrokenFile.toString()
                + ": " + e.getMessage());
    }
}
/**
 * Renders the files of the reinforcement learning framework module.
 *
 * <p>Adds the entry {@code rlFrameworkModule} to {@code ftlContext} before any
 * template is processed.
 *
 * @param templateConfiguration template engine configuration used for rendering
 * @param ftlContext mutable template context; modified by this method
 * @return a map from module-relative file name to rendered content, including
 *         an empty {@code __init__.py}
 */
private Map<String, String> constructReinforcementLearningFramework(
        final TemplateConfiguration templateConfiguration, final Map<String, Object> ftlContext) {
    final Map<String, String> frameworkFiles = Maps.newHashMap();
    ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);

    // Template path -> file name inside the framework module, in processing order.
    final String[][] templateToModuleFile = {
        {"reinforcement/agent/Agent.ftl", "/agent.py"},
        {"reinforcement/agent/ActionPolicy.ftl", "/action_policy.py"},
        {"reinforcement/agent/ReplayMemory.ftl", "/replay_memory.py"},
        {"reinforcement/environment/Environment.ftl", "/environment.py"},
        {"reinforcement/util/Util.ftl", "/util.py"},
    };
    for (final String[] mapping : templateToModuleFile) {
        final String rendered = templateConfiguration.processTemplate(ftlContext, mapping[0]);
        frameworkFiles.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + mapping[1], rendered);
    }

    // The package marker has no template; it is written as an empty file.
    frameworkFiles.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/__init__.py", "");
    return frameworkFiles;
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
0 → 100644
View file @
f9465a60
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData
;
import
de.monticore.lang.monticar.cnntrain._symboltable.*
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.Optional
;
/**
 * Extension of {@link ConfigurationData} that exposes the reinforcement
 * learning entries of a training configuration.
 *
 * <p>Getters for single entries return {@code null} when the entry is not part
 * of the configuration. Getters for reward function details return
 * {@code null} when no reward function (or the requested detail) is available.
 */
public class ReinforcementConfigurationData extends ConfigurationData {
    private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
    private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
    private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
    private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
    private static final String AST_ENTRY_TARGET_SCORE = "target_score";
    private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
    private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
    private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
    private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
    private static final String AST_ENTRY_AGENT_NAME = "agent_name";
    private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
    private static final String AST_ENTRY_LOSS = "loss";
    private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
    private static final String AST_ENTRY_ACTION_SELECTION = "action_selection";
    private static final String AST_ENTRY_ENVIRONMENT = "environment";

    /**
     * @param configuration the training configuration to wrap
     * @param instanceName the instance name passed on to {@link ConfigurationData}
     */
    public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
        super(configuration, instanceName);
    }

    /**
     * @return {@code true} when no learning method is configured or the
     *         configured one equals {@link LearningMethod#SUPERVISED}
     */
    public Boolean isSupervisedLearning() {
        return !configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
                || retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
                        .equals(LearningMethod.SUPERVISED);
    }

    /**
     * @return {@code true} only when a learning method is configured and it
     *         equals {@link LearningMethod#REINFORCEMENT}
     */
    public Boolean isReinforcementLearning() {
        if (!configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
            return false;
        }
        return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
                .equals(LearningMethod.REINFORCEMENT);
    }

    /** @return the {@code num_episodes} entry, or {@code null} when absent */
    public Integer getNumEpisodes() {
        return (Integer) entryValueOrNull(AST_ENTRY_NUM_EPISODES);
    }

    /** @return the {@code discount_factor} entry, or {@code null} when absent */
    public Double getDiscountFactor() {
        return (Double) entryValueOrNull(AST_ENTRY_DISCOUNT_FACTOR);
    }

    /** @return the {@code num_max_steps} entry, or {@code null} when absent */
    public Integer getNumMaxSteps() {
        return (Integer) entryValueOrNull(AST_ENTRY_NUM_MAX_STEPS);
    }

    /** @return the {@code target_score} entry, or {@code null} when absent */
    public Double getTargetScore() {
        return (Double) entryValueOrNull(AST_ENTRY_TARGET_SCORE);
    }

    /** @return the {@code training_interval} entry, or {@code null} when absent */
    public Integer getTrainingInterval() {
        return (Integer) entryValueOrNull(AST_ENTRY_TRAINING_INTERVAL);
    }

    /** @return the {@code use_fix_target_network} entry, or {@code null} when absent */
    public Boolean getUseFixTargetNetwork() {
        return (Boolean) entryValueOrNull(AST_ENTRY_USE_FIX_TARGET_NETWORK);
    }

    /** @return the {@code target_network_update_interval} entry, or {@code null} when absent */
    public Integer getTargetNetworkUpdateInterval() {
        return (Integer) entryValueOrNull(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
    }

    /** @return the {@code snapshot_interval} entry, or {@code null} when absent */
    public Integer getSnapshotInterval() {
        return (Integer) entryValueOrNull(AST_ENTRY_SNAPSHOT_INTERVAL);
    }

    /** @return the {@code agent_name} entry, or {@code null} when absent */
    public String getAgentName() {
        return (String) entryValueOrNull(AST_ENTRY_AGENT_NAME);
    }

    /** @return the {@code use_double_dqn} entry, or {@code null} when absent */
    public Boolean getUseDoubleDqn() {
        return (Boolean) entryValueOrNull(AST_ENTRY_USE_DOUBLE_DQN);
    }

    /** @return the string form of the {@code loss} entry, or {@code null} when absent */
    public String getLoss() {
        if (!configurationContainsKey(AST_ENTRY_LOSS)) {
            return null;
        }
        return retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
    }

    /** @return the {@code replay_memory} entry as a map keyed by {@code method} plus its parameters, or {@code null} */
    public Map<String, Object> getReplayMemory() {
        return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
    }

    /** @return the {@code action_selection} entry as a map keyed by {@code method} plus its parameters, or {@code null} */
    public Map<String, Object> getActionSelection() {
        return getMultiParamEntry(AST_ENTRY_ACTION_SELECTION, "method");
    }

    /** @return the {@code environment} entry as a map keyed by {@code environment} plus its parameters, or {@code null} */
    public Map<String, Object> getEnvironment() {
        return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
    }

    /** @return whether the configuration declares a reward function */
    public Boolean hasRewardFunction() {
        return this.getConfiguration().getRlRewardFunction().isPresent();
    }

    /**
     * @return the reward function component name joined with {@code _},
     *         or {@code null} when no reward function is declared
     */
    public String getRewardFunctionName() {
        if (this.getConfiguration().getRlRewardFunction().isPresent()) {
            final RewardFunctionSymbol rewardFunction = this.getConfiguration().getRlRewardFunction().get();
            return String.join("_", rewardFunction.getRewardFunctionComponentName());
        }
        return null;
    }

    /**
     * Looks up the parameter adapter attached to the reward function, if any.
     */
    private Optional<RewardFunctionParameterAdapter> getRlRewardFunctionParameter() {
        if (!this.getConfiguration().getRlRewardFunction().isPresent()) {
            return Optional.empty();
        }
        final RewardFunctionSymbol rewardFunction = this.getConfiguration().getRlRewardFunction().get();
        if (!rewardFunction.getRewardFunctionParameter().isPresent()) {
            return Optional.empty();
        }
        return Optional.ofNullable(
                (RewardFunctionParameterAdapter) rewardFunction.getRewardFunctionParameter().orElse(null));
    }

    /**
     * @return description ({@code name}, {@code dtype}, {@code isMultiDimensional})
     *         of the reward function's state input, or {@code null} when unavailable
     */
    public Map<String, Object> getRewardFunctionStateParameter() {
        final Optional<RewardFunctionParameterAdapter> adapter = getRlRewardFunctionParameter();
        if (!adapter.isPresent() || !adapter.get().getInputStateParameterName().isPresent()) {
            return null;
        }
        return getInputParameterWithName(adapter.get().getInputStateParameterName().get());
    }

    /**
     * @return description ({@code name}, {@code dtype}, {@code isMultiDimensional})
     *         of the reward function's terminal input, or {@code null} when unavailable
     */
    public Map<String, Object> getRewardFunctionTerminalParameter() {
        final Optional<RewardFunctionParameterAdapter> adapter = getRlRewardFunctionParameter();
        if (!adapter.isPresent() || !adapter.get().getInputTerminalParameter().isPresent()) {
            return null;
        }
        return getInputParameterWithName(adapter.get().getInputTerminalParameter().get());
    }

    /** @return the reward function's output parameter name, or {@code null} when unavailable */
    public String getRewardFunctionOutputName() {
        final Optional<RewardFunctionParameterAdapter> adapter = getRlRewardFunctionParameter();
        if (adapter.isPresent()) {
            return adapter.get().getOutputParameterName().orElse(null);
        }
        return null;
    }

    /**
     * Flattens a multi-parameter entry into one map: the entry's own value is
     * stored under {@code valueName}, followed by all of its parameters.
     */
    private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
        if (!configurationContainsKey(key)) {
            return null;
        }

        final MultiParamValueSymbol entryValue =
                (MultiParamValueSymbol) this.getConfiguration().getEntryMap().get(key).getValue();

        final Map<String, Object> flattened = new HashMap<>();
        flattened.put(valueName, entryValue.getValue());
        flattened.putAll(entryValue.getParameters());
        return flattened;
    }

    /** Returns whether the configuration's entry map contains {@code key}. */
    private Boolean configurationContainsKey(final String key) {
        return this.getConfiguration().getEntryMap().containsKey(key);
    }

    /** Returns the raw value of the entry stored under {@code key}. */
    private Object retrieveConfigurationEntryValueByKey(final String key) {
        return this.getConfiguration().getEntry(key).getValue().getValue();
    }

    /** Returns the raw value of the entry under {@code key}, or {@code null} when the key is absent. */
    private Object entryValueOrNull(final String key) {
        if (!configurationContainsKey(key)) {
            return null;
        }
        return retrieveConfigurationEntryValueByKey(key);
    }

    /**
     * Describes the reward function input port with the given name.
     *
     * @return a map with {@code name}, {@code dtype} ({@code double}, {@code int},
     *         {@code bool}, or {@code null} for any other port type) and
     *         {@code isMultiDimensional}; {@code null} when type or dimension
     *         of the port is unknown
     */
    private Map<String, Object> getInputParameterWithName(final String parameterName) {
        final Optional<RewardFunctionParameterAdapter> adapter = getRlRewardFunctionParameter();
        if (!adapter.isPresent()
                || !adapter.get().getTypeOfInputPort(parameterName).isPresent()
                || !adapter.get().getInputPortDimensionOfPort(parameterName).isPresent()) {
            return null;
        }

        final Map<String, Object> description = new HashMap<>();

        final String portType = adapter.get().getTypeOfInputPort(parameterName).get();
        final List<Integer> dimension = adapter.get().getInputPortDimensionOfPort(parameterName).get();

        String dtype;
        switch (portType) {
            case "Q":
                dtype = "double";
                break;
            case "Z":
                dtype = "int";
                break;
            case "B":
                dtype = "bool";
                break;
            default:
                dtype = null;
                break;
        }

        // More than one axis, or a single axis with more than one element.
        Boolean isMultiDimensional = dimension.size() > 1
                || (dimension.size() == 1 && dimension.get(0) > 1);

        description.put("name", parameterName);
        description.put("dtype", dtype);
        description.put("isMultiDimensional", isMultiDimensional);
        return description;
    }
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
0 → 100644
View file @
f9465a60
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
de.se_rwth.commons.logging.Log
;
/**
*
*/
public
class
FunctionParameterChecker
{
private
String
inputStateParameterName
;
private
String
inputTerminalParameterName
;
private
String
outputParameterName
;
private
RewardFunctionParameterAdapter
rewardFunctionParameter
;
/**
 * Creates a checker without any state; the adapter to validate is supplied
 * to {@code check}.
 */
public FunctionParameterChecker() {
}
/**
 * Runs all reward function parameter checks against the given adapter.
 *
 * <p>The adapter is stored on this instance and the parameter names are
 * retrieved before the individual checks run, in the order listed below.
 *
 * @param rewardFunctionParameter adapter describing the reward function's ports
 */
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter) {
    // Stored in a field; presumably read by the individual check methods below.
    this.rewardFunctionParameter = rewardFunctionParameter;
    retrieveParameterNames();
    checkHasExactlyTwoInputs();
    checkHasExactlyOneOutput();
    checkHasStateAndTerminalInput();
    checkInputStateDimension();
    checkInputTerminalTypeAndDimension();
    checkOutputDimension();
}
/**
 * Reports a failure through {@code failIfConditionFails} unless the reward
 * function has two inputs.
 */
private void checkHasExactlyTwoInputs() {
    final String message = "Reward function must have exactly two input parameters: "
            + "One input needs to represents the environment's state and another input needs to be a "
            + "boolean value which expresses whether the environment's state is terminal or not";
    failIfConditionFails(functionHasTwoInputs(), message);
}
/**
 * Reports a failure through {@code failIfConditionFails} unless the reward
 * function has one output.
 */
private void checkHasExactlyOneOutput() {
    final String message = "Reward function must have exactly one output";
    failIfConditionFails(functionHasOneOutput(), message);
}
private
void
checkHasStateAndTerminalInput
()
{
failIfConditionFails
(
inputParametersArePresent
(),