Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
CNNArch2Gluon
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
b4645437
Commit
b4645437
authored
5 years ago
by
Nicola Gatto
Browse files
Options
Downloads
Patches
Plain Diff
Use constants defined in CNNTrain
parent
53330fa7
No related branches found
No related tags found
3 merge requests
!20
Implemented layer variables and RNN layer
,
!19
Integrate TD3 Algorithm and Gaussian Noise
,
!18
Integrate TD3 Algorithm and Gaussian Noise
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
+47
-74
47 additions, 74 deletions
...nnarch/gluongenerator/ReinforcementConfigurationData.java
with
47 additions
and
74 deletions
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
+
47
−
74
View file @
b4645437
...
...
@@ -5,141 +5,114 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import
de.monticore.lang.monticar.cnntrain._symboltable.*
;
import
de.monticore.lang.monticar.cnntrain.annotations.Range
;
import
static
de
.
monticore
.
lang
.
monticar
.
cnntrain
.
helper
.
ConfigEntryNameConstants
.*;
import
java.util.*
;
public
class
ReinforcementConfigurationData
extends
ConfigurationData
{
private
static
final
String
AST_ENTRY_LEARNING_METHOD
=
"learning_method"
;
private
static
final
String
AST_ENTRY_NUM_EPISODES
=
"num_episodes"
;
private
static
final
String
AST_ENTRY_DISCOUNT_FACTOR
=
"discount_factor"
;
private
static
final
String
AST_ENTRY_NUM_MAX_STEPS
=
"num_max_steps"
;
private
static
final
String
AST_ENTRY_TARGET_SCORE
=
"target_score"
;
private
static
final
String
AST_ENTRY_TRAINING_INTERVAL
=
"training_interval"
;
private
static
final
String
AST_ENTRY_USE_FIX_TARGET_NETWORK
=
"use_fix_target_network"
;
private
static
final
String
AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL
=
"target_network_update_interval"
;
private
static
final
String
AST_ENTRY_SNAPSHOT_INTERVAL
=
"snapshot_interval"
;
private
static
final
String
AST_ENTRY_AGENT_NAME
=
"agent_name"
;
private
static
final
String
AST_ENTRY_USE_DOUBLE_DQN
=
"use_double_dqn"
;
private
static
final
String
AST_ENTRY_LOSS
=
"loss"
;
private
static
final
String
AST_ENTRY_RL_ALGORITHM
=
"rl_algorithm"
;
private
static
final
String
AST_ENTRY_REPLAY_MEMORY
=
"replay_memory"
;
private
static
final
String
AST_ENTRY_STRATEGY
=
"strategy"
;
private
static
final
String
AST_ENTRY_ENVIRONMENT
=
"environment"
;
private
static
final
String
AST_ENTRY_START_TRAINING_AT
=
"start_training_at"
;
private
static
final
String
AST_SOFT_TARGET_UPDATE_RATE
=
"soft_target_update_rate"
;
private
static
final
String
AST_EVALUATION_SAMPLES
=
"evaluation_samples"
;
private
static
final
String
AST_ENTRY_POLICY_NOISE
=
"policy_noise"
;
private
static
final
String
AST_ENTRY_NOISE_CLIP
=
"noise_clip"
;
private
static
final
String
AST_ENTRY_POLICY_DELAY
=
"policy_delay"
;
private
static
final
String
ENVIRONMENT_PARAM_REWARD_TOPIC
=
"reward_topic"
;
private
static
final
String
ENVIRONMENT_ROS
=
"ros_interface"
;
private
static
final
String
ENVIRONMENT_GYM
=
"gym"
;
private
static
final
String
STRATEGY_ORNSTEIN_UHLENBECK
=
"ornstein_uhlenbeck"
;
public
ReinforcementConfigurationData
(
ConfigurationSymbol
configuration
,
String
instanceName
)
{
super
(
configuration
,
instanceName
);
}
public
Boolean
isSupervisedLearning
()
{
if
(
configurationContainsKey
(
AST_ENTRY_
LEARNING_METHOD
))
{
return
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
LEARNING_METHOD
)
if
(
configurationContainsKey
(
LEARNING_METHOD
))
{
return
retrieveConfigurationEntryValueByKey
(
LEARNING_METHOD
)
.
equals
(
LearningMethod
.
SUPERVISED
);
}
return
true
;
}
public
Boolean
isReinforcementLearning
()
{
return
configurationContainsKey
(
AST_ENTRY_
LEARNING_METHOD
)
&&
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
LEARNING_METHOD
).
equals
(
LearningMethod
.
REINFORCEMENT
);
return
configurationContainsKey
(
LEARNING_METHOD
)
&&
retrieveConfigurationEntryValueByKey
(
LEARNING_METHOD
).
equals
(
LearningMethod
.
REINFORCEMENT
);
}
public
Integer
getNumEpisodes
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
NUM_EPISODES
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
NUM_EPISODES
);
return
!
configurationContainsKey
(
NUM_EPISODES
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
NUM_EPISODES
);
}
public
Double
getDiscountFactor
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
DISCOUNT_FACTOR
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
DISCOUNT_FACTOR
);
return
!
configurationContainsKey
(
DISCOUNT_FACTOR
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
DISCOUNT_FACTOR
);
}
public
Integer
getNumMaxSteps
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
NUM_MAX_STEPS
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
NUM_MAX_STEPS
);
return
!
configurationContainsKey
(
NUM_MAX_STEPS
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
NUM_MAX_STEPS
);
}
public
Double
getTargetScore
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
TARGET_SCORE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
TARGET_SCORE
);
return
!
configurationContainsKey
(
TARGET_SCORE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
TARGET_SCORE
);
}
public
Integer
getTrainingInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
TRAINING_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
TRAINING_INTERVAL
);
return
!
configurationContainsKey
(
TRAINING_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
TRAINING_INTERVAL
);
}
public
Boolean
getUseFixTargetNetwork
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
USE_FIX_TARGET_NETWORK
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
USE_FIX_TARGET_NETWORK
);
return
!
configurationContainsKey
(
USE_FIX_TARGET_NETWORK
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
USE_FIX_TARGET_NETWORK
);
}
public
Integer
getTargetNetworkUpdateInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
TARGET_NETWORK_UPDATE_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
TARGET_NETWORK_UPDATE_INTERVAL
);
return
!
configurationContainsKey
(
TARGET_NETWORK_UPDATE_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
TARGET_NETWORK_UPDATE_INTERVAL
);
}
public
Integer
getSnapshotInterval
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
SNAPSHOT_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
SNAPSHOT_INTERVAL
);
return
!
configurationContainsKey
(
SNAPSHOT_INTERVAL
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
SNAPSHOT_INTERVAL
);
}
public
String
getAgentName
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
AGENT_NAME
)
?
null
:
(
String
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
AGENT_NAME
);
return
!
configurationContainsKey
(
AGENT_NAME
)
?
null
:
(
String
)
retrieveConfigurationEntryValueByKey
(
AGENT_NAME
);
}
public
Boolean
getUseDoubleDqn
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
USE_DOUBLE_DQN
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
USE_DOUBLE_DQN
);
return
!
configurationContainsKey
(
USE_DOUBLE_DQN
)
?
null
:
(
Boolean
)
retrieveConfigurationEntryValueByKey
(
USE_DOUBLE_DQN
);
}
public
Double
getSoftTargetUpdateRate
()
{
return
!
configurationContainsKey
(
AST_
SOFT_TARGET_UPDATE_RATE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_
SOFT_TARGET_UPDATE_RATE
);
return
!
configurationContainsKey
(
SOFT_TARGET_UPDATE_RATE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
SOFT_TARGET_UPDATE_RATE
);
}
public
Integer
getStartTrainingAt
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
START_TRAINING_AT
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
START_TRAINING_AT
);
return
!
configurationContainsKey
(
START_TRAINING_AT
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
START_TRAINING_AT
);
}
public
Integer
getEvaluationSamples
()
{
return
!
configurationContainsKey
(
AST_
EVALUATION_SAMPLES
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_
EVALUATION_SAMPLES
);
return
!
configurationContainsKey
(
EVALUATION_SAMPLES
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
EVALUATION_SAMPLES
);
}
public
Double
getPolicyNoise
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
POLICY_NOISE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
POLICY_NOISE
);
return
!
configurationContainsKey
(
POLICY_NOISE
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
POLICY_NOISE
);
}
public
Double
getNoiseClip
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
NOISE_CLIP
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
NOISE_CLIP
);
return
!
configurationContainsKey
(
NOISE_CLIP
)
?
null
:
(
Double
)
retrieveConfigurationEntryValueByKey
(
NOISE_CLIP
);
}
public
Integer
getPolicyDelay
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
POLICY_DELAY
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
POLICY_DELAY
);
return
!
configurationContainsKey
(
POLICY_DELAY
)
?
null
:
(
Integer
)
retrieveConfigurationEntryValueByKey
(
POLICY_DELAY
);
}
public
RLAlgorithm
getRlAlgorithm
()
{
if
(!
isReinforcementLearning
())
{
return
null
;
}
return
!
configurationContainsKey
(
AST_ENTRY_
RL_ALGORITHM
)
?
RLAlgorithm
.
DQN
:
(
RLAlgorithm
)
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
RL_ALGORITHM
);
return
!
configurationContainsKey
(
RL_ALGORITHM
)
?
RLAlgorithm
.
DQN
:
(
RLAlgorithm
)
retrieveConfigurationEntryValueByKey
(
RL_ALGORITHM
);
}
public
String
getInputNameOfTrainedArchitecture
()
{
...
...
@@ -179,18 +152,18 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public
String
getLoss
()
{
return
!
configurationContainsKey
(
AST_ENTRY_
LOSS
)
?
null
:
retrieveConfigurationEntryValueByKey
(
AST_ENTRY_
LOSS
).
toString
();
return
!
configurationContainsKey
(
LOSS
)
?
null
:
retrieveConfigurationEntryValueByKey
(
LOSS
).
toString
();
}
public
Map
<
String
,
Object
>
getReplayMemory
()
{
return
getMultiParamEntry
(
AST_ENTRY_
REPLAY_MEMORY
,
"method"
);
return
getMultiParamEntry
(
REPLAY_MEMORY
,
"method"
);
}
public
Map
<
String
,
Object
>
getStrategy
()
{
assert
isReinforcementLearning
():
"Strategy parameter only for reinforcement learning but called in a "
+
" non reinforcement learning context"
;
Map
<
String
,
Object
>
strategyParams
=
getMultiParamEntry
(
AST_ENTRY_
STRATEGY
,
"method"
);
Map
<
String
,
Object
>
strategyParams
=
getMultiParamEntry
(
STRATEGY
,
"method"
);
assert
getConfiguration
().
getTrainedArchitecture
().
isPresent
():
"Architecture not present,"
+
" but reinforcement training"
;
NNArchitectureSymbol
trainedArchitecture
=
getConfiguration
().
getTrainedArchitecture
().
get
();
...
...
@@ -218,7 +191,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public
Map
<
String
,
Object
>
getEnvironment
()
{
return
getMultiParamEntry
(
AST_ENTRY_
ENVIRONMENT
,
"environment"
);
return
getMultiParamEntry
(
ENVIRONMENT
,
"environment"
);
}
public
Boolean
hasRewardFunction
()
{
...
...
@@ -312,12 +285,12 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public
boolean
hasRosRewardTopic
()
{
Map
<
String
,
Object
>
environmentParameters
=
getMultiParamEntry
(
AST_ENTRY_
ENVIRONMENT
,
"environment"
);
Map
<
String
,
Object
>
environmentParameters
=
getMultiParamEntry
(
ENVIRONMENT
,
"environment"
);
if
(
environmentParameters
==
null
||
!
environmentParameters
.
containsKey
(
"environment"
))
{
return
false
;
}
return
environmentParameters
.
containsKey
(
ENVIRONMENT_
PARAM_
REWARD_TOPIC
);
return
environmentParameters
.
containsKey
(
ENVIRONMENT_REWARD_TOPIC
);
}
private
Map
<
String
,
Object
>
getMultiParamEntry
(
final
String
key
,
final
String
valueName
)
{
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment