Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
736a2256
Commit
736a2256
authored
May 16, 2019
by
Evgeny Kusmenko
Browse files
Merge branch 'implement-reinforcement-learning' into 'master'
Implement reinforcement learning See merge request
!14
parents
2dd5aae6
f9465a60
Pipeline
#144750
passed with stages
in 3 minutes and 55 seconds
Changes
51
Pipelines
2
Expand all
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
736a2256
...
...
@@ -8,3 +8,4 @@ nppBackup
*.iml
.vscode
.gitlab-ci.yml
View file @
736a2256
...
...
@@ -27,7 +27,7 @@ masterJobLinux:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
-Dtest=\!Integration*
-
cat target/site/jacoco/index.html
-
mvn package sonar:sonar -s settings.xml
only
:
...
...
@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows
:
stage
:
windows
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-Dtest=\!Integration*
tags
:
-
Windows10
...
...
@@ -44,7 +44,13 @@ BranchJobLinux:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-Dtest=\!Integration*
-
cat target/site/jacoco/index.html
except
:
-
master
PythonWrapperIntegrationTest
:
stage
:
linux
image
:
registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
pom.xml
View file @
736a2256
...
...
@@ -8,7 +8,7 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-gluon-generator
</artifactId>
<version>
0.
1.6
</version>
<version>
0.
2.0-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
...
@@ -16,9 +16,10 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.0-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.
2.6
</CNNTrain.version>
<CNNTrain.version>
0.
3.0-SNAPSHOT
</CNNTrain.version>
<CNNArch2MXNet.version>
0.2.14-SNAPSHOT
</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>
0.0.1
</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>
18.0
</guava.version>
...
...
@@ -100,6 +101,12 @@
<version>
${embedded-montiarc-math-opt-generator}
</version>
</dependency>
<dependency>
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
embedded-montiarc-emadl-pythonwrapper-generator
</artifactId>
<version>
${EMADL2PythonWrapper.version}
</version>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
...
...
@@ -109,6 +116,13 @@
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
org.mockito
</groupId>
<artifactId>
mockito-core
</artifactId>
<version>
1.10.19
</version>
<scope>
test
</scope>
</dependency>
<dependency>
<groupId>
ch.qos.logback
</groupId>
<artifactId>
logback-classic
</artifactId>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
View file @
736a2256
...
...
@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp
=
archTc
.
process
(
"CNNNet"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
temp
=
archTc
.
process
(
"CNNDataLoader"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
if
(
architecture
.
getDataPath
()
!=
null
)
{
temp
=
archTc
.
process
(
"CNNDataLoader"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
}
temp
=
archTc
.
process
(
"CNNCreator"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
736a2256
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
com.google.common.collect.Maps
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod
;
import
de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.File
;
import
java.io.IOException
;
import
java.nio.charset.Charset
;
import
java.nio.charset.StandardCharsets
;
import
java.nio.file.Files
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.*
;
public
class
CNNTrain2Gluon
extends
CNNTrain2MxNet
{
public
CNNTrain2Gluon
()
{
private
static
final
String
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
=
"reinforcement_learning"
;
private
final
RewardFunctionSourceGenerator
rewardFunctionSourceGenerator
;
private
String
rootProjectModelsDir
;
public
Optional
<
String
>
getRootProjectModelsDir
()
{
return
Optional
.
ofNullable
(
rootProjectModelsDir
);
}
public
void
setRootProjectModelsDir
(
String
rootProjectModelsDir
)
{
this
.
rootProjectModelsDir
=
rootProjectModelsDir
;
}
public
CNNTrain2Gluon
(
RewardFunctionSourceGenerator
rewardFunctionSourceGenerator
)
{
super
();
this
.
rewardFunctionSourceGenerator
=
rewardFunctionSourceGenerator
;
}
@Override
public
ConfigurationSymbol
getConfigurationSymbol
(
Path
modelsDirPath
,
String
rootModelName
)
{
ConfigurationSymbol
configurationSymbol
=
super
.
getConfigurationSymbol
(
modelsDirPath
,
rootModelName
);
// Generate Reward function if necessary
if
(
configurationSymbol
.
getLearningMethod
().
equals
(
LearningMethod
.
REINFORCEMENT
)
&&
configurationSymbol
.
getRlRewardFunction
().
isPresent
())
{
generateRewardFunction
(
configurationSymbol
.
getRlRewardFunction
().
get
(),
modelsDirPath
);
}
return
configurationSymbol
;
}
@Override
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
)
{
ConfigurationSymbol
configuration
=
this
.
getConfigurationSymbol
(
modelsDirPath
,
rootModelName
);
Map
<
String
,
String
>
fileContents
=
this
.
generateStrings
(
configuration
);
GeneratorCPP
genCPP
=
new
GeneratorCPP
();
genCPP
.
setGenerationTargetPath
(
this
.
getGenerationTargetPath
());
try
{
Iterator
var6
=
fileContents
.
keySet
().
iterator
();
while
(
var6
.
hasNext
())
{
String
fileName
=
(
String
)
var6
.
next
();
genCPP
.
generateFile
(
new
FileContent
((
String
)
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
var8
)
{
Log
.
error
(
"CNNTrainer file could not be generated"
+
var8
.
getMessage
());
}
}
@Override
public
Map
<
String
,
String
>
generateStrings
(
ConfigurationSymbol
configuration
)
{
TemplateConfiguration
templateConfiguration
=
new
GluonTemplateConfiguration
();
ConfigurationData
configData
=
new
ConfigurationData
(
configuration
,
getInstanceName
());
Reinforcement
ConfigurationData
configData
=
new
Reinforcement
ConfigurationData
(
configuration
,
getInstanceName
());
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
configDataList
.
add
(
configData
);
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
Map
<
String
,
Object
>
ftlContext
=
Maps
.
newHashMap
();
ftlContext
.
put
(
"configurations"
,
configDataList
);
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
String
cnnTrainTemplateContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
fileContentMap
.
put
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
cnnTrainTemplateContent
);
if
(
configData
.
isSupervisedLearning
())
{
String
cnnTrainTemplateContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
fileContentMap
.
put
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
cnnTrainTemplateContent
);
String
cnnSupervisedTrainerContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNSupervisedTrainer.ftl"
);
fileContentMap
.
put
(
"supervised_trainer.py"
,
cnnSupervisedTrainerContent
);
}
else
if
(
configData
.
isReinforcementLearning
())
{
final
String
trainerName
=
"CNNTrainer_"
+
getInstanceName
();
ftlContext
.
put
(
"trainerName"
,
trainerName
);
Map
<
String
,
String
>
rlFrameworkContentMap
=
constructReinforcementLearningFramework
(
templateConfiguration
,
ftlContext
);
fileContentMap
.
putAll
(
rlFrameworkContentMap
);
final
String
reinforcementTrainerContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/Trainer.ftl"
);
fileContentMap
.
put
(
trainerName
+
".py"
,
reinforcementTrainerContent
);
final
String
startTrainerScriptContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/StartTrainer.ftl"
);
fileContentMap
.
put
(
"start_training.sh"
,
startTrainerScriptContent
);
}
return
fileContentMap
;
}
private
void
generateRewardFunction
(
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
GeneratorPythonWrapperStandaloneApi
pythonWrapperApi
=
new
GeneratorPythonWrapperStandaloneApi
();
List
<
String
>
fullNameOfComponent
=
rewardFunctionSymbol
.
getRewardFunctionComponentName
();
String
rewardFunctionRootModel
=
String
.
join
(
"."
,
fullNameOfComponent
);
String
rewardFunctionOutputPath
=
Paths
.
get
(
this
.
getGenerationTargetPath
(),
"reward"
).
toString
();
if
(!
getRootProjectModelsDir
().
isPresent
())
{
setRootProjectModelsDir
(
modelsDirPath
.
toString
());
}
rewardFunctionSourceGenerator
.
generate
(
getRootProjectModelsDir
().
get
(),
rewardFunctionRootModel
,
rewardFunctionOutputPath
);
fixArmadilloEmamGenerationOfFile
(
Paths
.
get
(
rewardFunctionOutputPath
,
String
.
join
(
"_"
,
fullNameOfComponent
)
+
".h"
));
String
pythonWrapperOutputPath
=
Paths
.
get
(
rewardFunctionOutputPath
,
"pylib"
).
toString
();
Log
.
info
(
"Generating reward function python wrapper..."
,
"CNNTrain2Gluon"
);
ComponentPortInformation
componentPortInformation
;
if
(
pythonWrapperApi
.
checkIfPythonModuleBuildAvailable
())
{
final
String
rewardModuleOutput
=
Paths
.
get
(
getGenerationTargetPath
(),
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
).
toString
();
componentPortInformation
=
pythonWrapperApi
.
generateAndTryBuilding
(
getRootProjectModelsDir
().
get
(),
rewardFunctionRootModel
,
pythonWrapperOutputPath
,
rewardModuleOutput
);
}
else
{
Log
.
warn
(
"Cannot build wrapper automatically: OS not supported. Please build manually before starting training."
);
componentPortInformation
=
pythonWrapperApi
.
generate
(
getRootProjectModelsDir
().
get
(),
rewardFunctionRootModel
,
pythonWrapperOutputPath
);
}
RewardFunctionParameterAdapter
functionParameter
=
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
new
FunctionParameterChecker
().
check
(
functionParameter
);
rewardFunctionSymbol
.
setRewardFunctionParameter
(
functionParameter
);
}
private
void
fixArmadilloEmamGenerationOfFile
(
Path
pathToBrokenFile
){
final
File
brokenFile
=
pathToBrokenFile
.
toFile
();
if
(
brokenFile
.
exists
())
{
try
{
Charset
charset
=
StandardCharsets
.
UTF_8
;
String
fileContent
=
new
String
(
Files
.
readAllBytes
(
pathToBrokenFile
),
charset
);
fileContent
=
fileContent
.
replace
(
"armadillo.h"
,
"armadillo"
);
Files
.
write
(
pathToBrokenFile
,
fileContent
.
getBytes
());
}
catch
(
IOException
e
)
{
Log
.
warn
(
"Cannot fix wrong armadillo library in "
+
pathToBrokenFile
.
toString
());
}
}
}
private
Map
<
String
,
String
>
constructReinforcementLearningFramework
(
final
TemplateConfiguration
templateConfiguration
,
final
Map
<
String
,
Object
>
ftlContext
)
{
Map
<
String
,
String
>
fileContentMap
=
Maps
.
newHashMap
();
ftlContext
.
put
(
"rlFrameworkModule"
,
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
);
final
String
reinforcementAgentContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/agent/Agent.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/agent.py"
,
reinforcementAgentContent
);
final
String
reinforcementPolicyContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/agent/ActionPolicy.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/action_policy.py"
,
reinforcementPolicyContent
);
final
String
replayMemoryContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/agent/ReplayMemory.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/replay_memory.py"
,
replayMemoryContent
);
final
String
environmentContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/environment/Environment.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/environment.py"
,
environmentContent
);
final
String
utilContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/util/Util.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/util.py"
,
utilContent
);
String
cnnSupervisedTrainerContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNSupervisedTrainer.ftl"
)
;
fileContentMap
.
put
(
"supervised_trainer.py"
,
cnnSupervisedTrainer
Content
);
final
String
initContent
=
""
;
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/__init__.py"
,
init
Content
);
return
fileContentMap
;
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
0 → 100644
View file @
736a2256
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData
;
import
de.monticore.lang.monticar.cnntrain._symboltable.*
;
import
java.util.HashMap
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.Optional
;
/**
*
*/
public
class
ReinforcementConfigurationData
extends
ConfigurationData
{
private
static
final
String
AST_ENTRY_LEARNING_METHOD
=
"learning_method"
;
private
static
final
String
AST_ENTRY_NUM_EPISODES
=
"num_episodes"
;
private
static
final
String
AST_ENTRY_DISCOUNT_FACTOR
=
"discount_factor"
;
private
static
final
String
AST_ENTRY_NUM_MAX_STEPS
=
"num_max_steps"
;
private
static
final
String
AST_ENTRY_TARGET_SCORE
=
"target_score"
;
private
static
final
String
AST_ENTRY_TRAINING_INTERVAL
=
"training_interval"
;
private
static
final
String
AST_ENTRY_USE_FIX_TARGET_NETWORK
=
"use_fix_target_network"
;
private
static
final
String
AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL
=
"target_network_update_interval"
;
private
static
final
String
AST_ENTRY_SNAPSHOT_INTERVAL
=
"snapshot_interval"
;
private
static
final
String
AST_ENTRY_AGENT_NAME
=
"agent_name"
;
private
static
final
String
AST_ENTRY_USE_DOUBLE_DQN
=
"use_double_dqn"
;
private
static
final
String
AST_ENTRY_LOSS
=
"loss"
;
private
static
final
String
AST_ENTRY_REPLAY_MEMORY
=
"replay_memory"
;
private
static
final
String
AST_ENTRY_ACTION_SELECTION
=
"action_selection"
;
private
static
final
String
AST_ENTRY_ENVIRONMENT
=
"environment"
;
public
ReinforcementConfigurationData
(
ConfigurationSymbol
configuration
,
String
instanceName
)
{
super
(
configuration
,
instanceName
);
}
public
Boolean
isSupervisedLearning
()
{
if
(
configurationContainsKey
(
AST_ENTRY_LEARNING_METHOD
))
{
return
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_LEARNING_METHOD
)
.
equals
(
LearningMethod
.
SUPERVISED
);
}
return
true
;
}
public
Boolean
isReinforcementLearning
()
{
return
configurationContainsKey
(
AST_ENTRY_LEARNING_METHOD
)
&&
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_LEARNING_METHOD
).
equals
(
LearningMethod
.
REINFORCEMENT
);
}
public
Integer
getNumEpisodes
()
{
return
!
configurationContainsKey
(
AST_ENTRY_NUM_EPISODES
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_NUM_EPISODES
);
}
public
Double
getDiscountFactor
()
{
return
!
configurationContainsKey
(
AST_ENTRY_DISCOUNT_FACTOR
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_DISCOUNT_FACTOR
);
}
public
Integer
getNumMaxSteps
()
{
return
!
configurationContainsKey
(
AST_ENTRY_NUM_MAX_STEPS
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_NUM_MAX_STEPS
);
}
public
Double
getTargetScore
()
{
return
!
configurationContainsKey
(
AST_ENTRY_TARGET_SCORE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_TARGET_SCORE
);
}
public
Integer
getTrainingInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_TRAINING_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_TRAINING_INTERVAL
);
}
public
Boolean
getUseFixTargetNetwork
()
{
return
!
configurationContainsKey
(
AST_ENTRY_USE_FIX_TARGET_NETWORK
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_USE_FIX_TARGET_NETWORK
);
}
public
Integer
getTargetNetworkUpdateInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL
);
}
public
Integer
getSnapshotInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_SNAPSHOT_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_SNAPSHOT_INTERVAL
);
}
public
String
getAgentName
()
{
return
!
configurationContainsKey
(
AST_ENTRY_AGENT_NAME
)
?
null
:
(
String
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_AGENT_NAME
);
}
public
Boolean
getUseDoubleDqn
()
{
return
!
configurationContainsKey
(
AST_ENTRY_USE_DOUBLE_DQN
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_USE_DOUBLE_DQN
);
}
public
String
getLoss
()
{
return
!
configurationContainsKey
(
AST_ENTRY_LOSS
)
?
null
:
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_LOSS
).
toString
();
}
public
Map
<
String
,
Object
>
getReplayMemory
()
{
return
getMultiParamEntry
(
AST_ENTRY_REPLAY_MEMORY
,
"method"
);
}
public
Map
<
String
,
Object
>
getActionSelection
()
{
return
getMultiParamEntry
(
AST_ENTRY_ACTION_SELECTION
,
"method"
);
}
public
Map
<
String
,
Object
>
getEnvironment
()
{
return
getMultiParamEntry
(
AST_ENTRY_ENVIRONMENT
,
"environment"
);
}
public
Boolean
hasRewardFunction
()
{
return
this
.
getConfiguration
().
getRlRewardFunction
().
isPresent
();
}
public
String
getRewardFunctionName
()
{
if
(!
this
.
getConfiguration
().
getRlRewardFunction
().
isPresent
())
{
return
null
;
}
return
String
.
join
(
"_"
,
this
.
getConfiguration
().
getRlRewardFunction
()
.
get
().
getRewardFunctionComponentName
());
}
private
Optional
<
RewardFunctionParameterAdapter
>
getRlRewardFunctionParameter
()
{
if
(!
this
.
getConfiguration
().
getRlRewardFunction
().
isPresent
()
||
!
this
.
getConfiguration
().
getRlRewardFunction
().
get
().
getRewardFunctionParameter
().
isPresent
())
{
return
Optional
.
empty
();
}
return
Optional
.
ofNullable
(
(
RewardFunctionParameterAdapter
)
this
.
getConfiguration
().
getRlRewardFunction
().
get
()
.
getRewardFunctionParameter
().
orElse
(
null
));
}
public
Map
<
String
,
Object
>
getRewardFunctionStateParameter
()
{
if
(!
getRlRewardFunctionParameter
().
isPresent
()
||
!
getRlRewardFunctionParameter
().
get
().
getInputStateParameterName
().
isPresent
())
{
return
null
;
}
return
getInputParameterWithName
(
getRlRewardFunctionParameter
().
get
().
getInputStateParameterName
().
get
());
}
public
Map
<
String
,
Object
>
getRewardFunctionTerminalParameter
()
{
if
(!
getRlRewardFunctionParameter
().
isPresent
()
||
!
getRlRewardFunctionParameter
().
get
().
getInputTerminalParameter
().
isPresent
())
{
return
null
;
}
return
getInputParameterWithName
(
getRlRewardFunctionParameter
().
get
().
getInputTerminalParameter
().
get
());
}
public
String
getRewardFunctionOutputName
()
{
if
(!
getRlRewardFunctionParameter
().
isPresent
())
{
return
null
;
}
return
getRlRewardFunctionParameter
().
get
().
getOutputParameterName
().
orElse
(
null
);
}
private
Map
<
String
,
Object
>
getMultiParamEntry
(
final
String
key
,
final
String
valueName
)
{
if
(!
configurationContainsKey
(
key
))
{
return
null
;
}
Map
<
String
,
Object
>
resultView
=
new
HashMap
<>();
MultiParamValueSymbol
multiParamValue
=
(
MultiParamValueSymbol
)
this
.
getConfiguration
().
getEntryMap
()
.
get
(
key
).
getValue
();
resultView
.
put
(
valueName
,
multiParamValue
.
getValue
());
resultView
.
putAll
(
multiParamValue
.
getParameters
());
return
resultView
;
}
private
Boolean
configurationContainsKey
(
final
String
key
)
{
return
this
.
getConfiguration
().
getEntryMap
().
containsKey
(
key
);
}
private
Object
retrieveConfigurationEntryValueByKey
(
final
String
key
)
{
return
this
.
getConfiguration
().
getEntry
(
key
).
getValue
().
getValue
();
}
private
Map
<
String
,
Object
>
getInputParameterWithName
(
final
String
parameterName
)
{
if
(!
getRlRewardFunctionParameter
().
isPresent
()
||
!
getRlRewardFunctionParameter
().
get
().
getTypeOfInputPort
(
parameterName
).
isPresent
()
||
!
getRlRewardFunctionParameter
().
get
().
getInputPortDimensionOfPort
(
parameterName
).
isPresent
())
{
return
null
;
}
Map
<
String
,
Object
>
functionStateParameter
=
new
HashMap
<>();;
final
String
portType
=
getRlRewardFunctionParameter
().
get
().
getTypeOfInputPort
(
parameterName
).
get
();
final
List
<
Integer
>
dimension
=
getRlRewardFunctionParameter
().
get
().
getInputPortDimensionOfPort
(
parameterName
).
get
();
String
dtype
=
null
;
if
(
portType
.
equals
(
"Q"
))
{
dtype
=
"double"
;
}
else
if
(
portType
.
equals
(
"Z"
))
{
dtype
=
"int"
;
}
else
if
(
portType
.
equals
(
"B"
))
{
dtype
=
"bool"
;
}
Boolean
isMultiDimensional
=
dimension
.
size
()
>
1
||
(
dimension
.
size
()
==
1
&&
dimension
.
get
(
0
)
>
1
);
functionStateParameter
.
put
(
"name"
,
parameterName
);
functionStateParameter
.
put
(
"dtype"
,
dtype
);
functionStateParameter
.
put
(
"isMultiDimensional"
,
isMultiDimensional
);
return
functionStateParameter
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
0 → 100644
View file @
736a2256
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
de.se_rwth.commons.logging.Log
;
/**
*
*/
public
class
FunctionParameterChecker
{
private
String
inputStateParameterName
;
private
String
inputTerminalParameterName
;
private
String
outputParameterName
;
private
RewardFunctionParameterAdapter
rewardFunctionParameter
;
public
FunctionParameterChecker
()
{
}
public
void
check
(
final
RewardFunctionParameterAdapter
rewardFunctionParameter
)
{
this
.
rewardFunctionParameter
=
rewardFunctionParameter
;
retrieveParameterNames
();
checkHasExactlyTwoInputs
();
checkHasExactlyOneOutput
();
checkHasStateAndTerminalInput
();
checkInputStateDimension
();
checkInputTerminalTypeAndDimension
();
checkOutputDimension
();
}
private
void
checkHasExactlyTwoInputs
()
{
failIfConditionFails
(
functionHasTwoInputs
(),
"Reward function must have exactly two input parameters: "
+
"One input needs to represents the environment's state and another input needs to be a "
+
"boolean value which expresses whether the environment's state is terminal or not"
);
}
private
void
checkHasExactlyOneOutput
()
{
failIfConditionFails
(
functionHasOneOutput
(),
"Reward function must have exactly one output"
);
}
private
void
checkHasStateAndTerminalInput
()
{
failIfConditionFails
(
inputParametersArePresent
(),
"Reward function must have exactly two input parameters: "
+
"One input needs to represents the environment's state as a numerical scalar, vector or matrice, "
+
"and another input needs to be a "
+
"boolean value which expresses whether the environment's state is terminal or not"
);
}