Commit ab2e2987 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Add reinforcement learning

parent 7de7b242
......@@ -3,11 +3,10 @@
# CNNTrain
CNNTrain is a domain specific language for describing training parameters of a feedforward neural network.
CNNTrain files must have .cnnt extension. Training configuration starts with a `configuration` word, followed by the configuration name and a list of parameters. The available parameters are batch size, number of epochs, loading previous checkpoint as well as an optimizer with its parameters. All these parameters are optional.
CNNTrain is a domain specific language for describing the training parameters of a feedforward neural network. CNNTrain files must have the `.cnnt` extension. A training configuration starts with the keyword `configuration`, followed by the configuration name and a list of parameters. The available parameters are the batch size, the number of epochs, loading a previous checkpoint, and an optimizer with its parameters. All of these parameters are optional.
An example of a config:
```
configuration FullConfig{
num_epoch : 5
......@@ -29,9 +28,11 @@ configuration FullConfig{
}
}
```
See CNNTrain.mc4 for the full grammar definition.
Using the CNNTrainGenerator class, a Python file can be generated, which looks as follows (for the example above):
```python
batch_size = 100,
num_epoch = 5,
......@@ -51,7 +52,140 @@ optimizer_params = {
'learning_rate_decay': 0.9,
'step_size': 1000}
```
## Reinforcement Learning
CNNTrain can be used to describe training parameters for supervised learning methods as well as for reinforcement learning methods. If reinforcement learning is selected, the network is trained with the Deep Q-Network (DQN) algorithm (Mnih et al., "Playing Atari with Deep Reinforcement Learning").
An example of a supervised learning configuration can be seen above. The following is an example configuration for reinforcement learning:
```CNNTrainLang
configuration ReinforcementConfig {
learning_method : reinforcement
agent_name : "reinforcement-agent"
environment : gym { name:"CartPole-v1" }
context : cpu
num_episodes : 300
num_max_steps : 9999
discount_factor : 0.998
target_score : 1000
training_interval : 10
loss : huber_loss
use_fix_target_network : true
target_network_update_interval : 100
use_double_dqn : true
replay_memory : buffer{
memory_size : 1000000
sample_size : 64
}
action_selection : epsgreedy{
epsilon : 1.0
min_epsilon : 0.01
epsilon_decay_method: linear
epsilon_decay : 0.0001
}
optimizer : rmsprop{
learning_rate : 0.001
learning_rate_minimum : 0.00001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : step
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
gamma1 : 0.9
gamma2 : 0.9
epsilon : 0.000001
centered : true
clip_weights : 10
}
}
```
### Available Parameters for Reinforcement Learning
| Parameter | Value | Default | Required | Description |
|------------|--------|---------|----------|-------------|
|learning_method| reinforcement, supervised | supervised | No | Determines whether this CNNTrain configuration describes a reinforcement learning or a supervised learning setup |
| agent_name | String | "agent" | No | Names the agent (e.g. for logging output) |
|environment | gym, ros_interface | / | Yes | If *ros_interface* is selected, the agent and the environment communicate via [ROS](http://www.ros.org/). The *gym* environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
| context | cpu, gpu | cpu | No | Determines whether the CPU or the GPU is used during training |
| num_episodes | Integer | 50 | No | Number of episodes the agent is trained for. An episode is one full pass through the environment, from an initial state to a terminal state.|
| num_max_steps | Integer | 99999 | No | Maximum number of steps within an episode before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) |
|discount_factor | Float | 0.9 | No | Discount factor |
| target_score | Float | None | No | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. |
| training_interval | Integer | 1 | No | Number of steps between two training updates |
| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | Selects the loss function |
| use_fix_target_network | bool | false | No | If set, an extra network with fixed parameters is used to estimate the Q values |
| target_network_update_interval | Integer | / | Yes, if *use_fix_target_network* is true | If *use_fix_target_network* is set, determines the number of steps after which the target network is updated (Mnih et al., "Human-level Control through Deep Reinforcement Learning")|
| use_double_dqn | bool | false | No | If set, two value functions are used to determine the action values (van Hasselt et al., "Deep Reinforcement Learning with Double Q-learning") |
| replay_memory | buffer, online, combined | buffer | No | Determines the behaviour of the replay memory |
| action_selection | epsgreedy | epsgreedy | No | Determines the action selection policy during the training |
| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
#### Environment
##### Option: ros_interface
If selected, the communication between the environment and the agent is done via ROS. Additional parameters (see the example after this list):
- **state_topic**: Topic on which the state is published
- **action_topic**: Topic on which the action is published
- **reset_topic**: Topic on which the reset command is published
- **terminal_state_topic**: Topic on which the terminal flag is published
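For illustration, the following is a minimal sketch of a *ros_interface* configuration. The topic names and the reward component `myproject.agent.Reward` are placeholders and depend on the concrete project; note that *ros_interface* also requires the *reward_function* parameter (see the parameter table above).
```CNNTrainLang
configuration RosExampleConfig {
    learning_method : reinforcement
    agent_name : "ros-agent"
    environment : ros_interface {
        state_topic : "/environment/state"
        action_topic : "/environment/action"
        reset_topic : "/environment/reset"
        terminal_state_topic : "/environment/terminal"
    }
    reward_function : myproject.agent.Reward
    num_episodes : 300
}
```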
##### Option: gym
The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/). Additional parameters:
- **name**: Name (see https://gym.openai.com/) of the environment
#### Replay Buffer
Different buffer behaviours can be selected for the training. For more information about the buffer behaviour, see "A Deeper Look at Experience Replay" by Zhang and Sutton.
##### Option: buffer
A simple buffer which stores the SARS (**S**tate, **A**ction, **R**eward, next **S**tate) tuples. Additional parameters:
- **memory_size**: Determines the size of the buffer
- **sample_size**: Number of samples that are used for each training step
##### Option: online
No buffer is used. Only the current SARS tuple is used for training.
##### Option: combined
Combination of *online* and *buffer*. Both the current SARS tuple as well as a sample from the buffer are used for each training step. Parameters are the same as *buffer*.
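As a sketch, a *combined* replay memory is configured like the *buffer* variant; the values below are illustrative only:
```CNNTrainLang
replay_memory : combined {
    memory_size : 500000
    sample_size : 32
}
```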
### Action Selection
Determines the behaviour when selecting an action based on the values. (Currently, only epsilon greedy is available.)
#### Option: epsgreedy
Selects an action based on an epsilon-greedy policy: with probability epsilon a random action is chosen; otherwise the action with the highest value is selected. Additional parameters:
- **epsilon**: Probability of choosing an action randomly
- **epsilon_decay_method**: Method which determines how epsilon decreases after each step. Can be *linear* for linear decrease or *no* for no decrease.
- **epsilon_decay**: The amount by which epsilon is decreased after each step.
- **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further.
## Generation
To execute generation in your project, use the following code to generate a separate Config file:
```java
import de.monticore.lang.monticar.cnntrain.generator.CNNTrainGenerator;
...
......@@ -61,6 +195,7 @@ cnnTrainGenerator.generate(modelPath, cnnt_filename);
```
Use the following code to get file contents as a map ( `fileContents.getValue()` contains the generated code):
```java
import de.monticore.lang.monticar.cnntrain.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
......@@ -69,10 +204,6 @@ CNNTrainGenerator cnnTrainGenerator = new CNNTrainGenerator();
ModelPath mp = new ModelPath(Paths.get("path/to/cnnt/file"));
GlobalScope trainScope = new GlobalScope(mp, new CNNTrainLanguage());
Map.Entry<String, String> fileContents = cnnTrainGenerator.generateFileContent( trainScope, cnnt_filename );
```
CNNTrain can be used together with [CNNArch](https://github.com/EmbeddedMontiArc/CNNArchLang) language, which describes architecture of a NN.
[EmbeddedMontiArcDL](https://github.com/EmbeddedMontiArc/EmbeddedMontiArcDL) uses both languages.
CNNTrain can be used together with the [CNNArch](https://github.com/EmbeddedMontiArc/CNNArchLang) language, which describes the architecture of a NN. [EmbeddedMontiArcDL](https://github.com/EmbeddedMontiArc/EmbeddedMontiArcDL) uses both languages.
\ No newline at end of file
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.2.6</version>
<version>0.3.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -33,7 +33,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LossValue implements ConfigValue =(euclidean:"euclidean" | crossEntropy:"cross_entropy");
LossValue implements ConfigValue =(euclidean:"euclidean"
| l1: "l1"
| crossEntropy:"cross_entropy"
| huberLoss: "huber_loss");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......@@ -47,6 +50,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
DataVariable implements VariableReference = Name&;
IntegerValue implements ConfigValue = NumberWithUnit;
NumberValue implements ConfigValue = NumberWithUnit;
StringValue implements ConfigValue = StringLiteral;
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
interface OptimizerValue extends ConfigValue;
......@@ -90,4 +94,72 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry = name:"rho" ":" value:NumberValue;
// Reinforcement Extensions
interface MultiParamValue extends ConfigValue;
LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue;
NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue;
DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue;
NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue;
TargetScoreEntry implements ConfigEntry = name:"target_score" ":" value:NumberValue;
TrainingIntervalEntry implements ConfigEntry = name:"training_interval" ":" value:IntegerValue;
UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue;
TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue;
SnapshotIntervalEntry implements ConfigEntry = name:"snapshot_interval" ":" value:IntegerValue;
AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue;
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue;
ComponentNameValue implements ConfigValue = Name ("."Name)*;
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
interface MultiParamConfigEntry extends ConfigEntry;
// Replay Memory
ReplayMemoryEntry implements MultiParamConfigEntry = name:"replay_memory" ":" value:ReplayMemoryValue;
interface ReplayMemoryValue extends MultiParamValue;
interface ReplayMemoryBufferEntry extends Entry;
ReplayMemoryBufferValue implements ReplayMemoryValue = name:"buffer" ("{" params:ReplayMemoryBufferEntry* "}")?;
ReplayMemoryOnlineValue implements ReplayMemoryValue = name:"online";
interface ReplayMemoryCombinedEntry extends Entry;
ReplayMemoryCombinedValue implements ReplayMemoryValue = name:"combined" ("{" params:ReplayMemoryCombinedEntry* "}")?;
interface GeneralReplayMemoryEntry extends ReplayMemoryBufferEntry, ReplayMemoryCombinedEntry;
MemorySizeEntry implements GeneralReplayMemoryEntry = name:"memory_size" ":" value:IntegerValue;
SampleSizeEntry implements GeneralReplayMemoryEntry = name:"sample_size" ":" value:IntegerValue;
// Action Selection
ActionSelectionEntry implements MultiParamConfigEntry = name:"action_selection" ":" value:ActionSelectionValue;
interface ActionSelectionValue extends MultiParamValue;
interface ActionSelectionEpsGreedyEntry extends Entry;
ActionSelectionEpsGreedyValue implements ActionSelectionValue = name:"epsgreedy" ("{" params:ActionSelectionEpsGreedyEntry* "}")?;
GreedyEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon" ":" value:NumberValue;
MinEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue;
EpsilonDecayMethodEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue;
EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no");
EpsilonDecayEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue;
// Environment
EnvironmentEntry implements MultiParamConfigEntry = name:"environment" ":" value:EnvironmentValue;
interface EnvironmentValue extends MultiParamValue;
interface GymEnvironmentEntry extends Entry;
GymEnvironmentValue implements EnvironmentValue = name:"gym" ("{" params:GymEnvironmentEntry* "}");
GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue;
interface RosEnvironmentEntry extends Entry;
RosEnvironmentValue implements EnvironmentValue = name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
RosEnvironmentGreetingTopicEntry implements RosEnvironmentEntry = name:"greeting_topic" ":" value:StringValue;
RosEnvironmentMetaTopicEntry implements RosEnvironmentEntry = name:"meta_topic" ":" value:StringValue;
RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue;
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.List;
/**
*
*/
public interface ASTMultiParamValue extends ASTMultiParamValueTOP {
String getName();
default List<? extends ASTEntry> getParamsList() {
return new ArrayList<>();
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
class ASTConfigurationUtils {
static boolean isReinforcementLearning(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e ->
(e instanceof ASTLearningMethodEntry)
&& ((ASTLearningMethodEntry)e).getValue().isPresentReinforcement());
}
static boolean hasEnvironment(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTEnvironmentEntry);
}
}
......@@ -29,7 +29,11 @@ public class CNNTrainCocos {
public static CNNTrainCoCoChecker createChecker() {
return new CNNTrainCoCoChecker()
.addCoCo(new CheckEntryRepetition())
.addCoCo(new CheckInteger());
.addCoCo(new CheckInteger())
.addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
......@@ -37,5 +41,4 @@ public class CNNTrainCocos {
int findings = Log.getFindings().size();
createChecker().checkAll(node);
}
}
}
\ No newline at end of file
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTGreedyEpsilonEntry;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -33,13 +34,17 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
@Override
public void check(ASTEntry node) {
if (entryNameSet.contains(node.getName())){
String parameterPrefix = "";
if (node instanceof ASTGreedyEpsilonEntry) {
parameterPrefix = "greedy_";
}
if (entryNameSet.contains(parameterPrefix + node.getName())){
Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE +" The parameter '" + node.getName() + "' has multiple values. " +
"Multiple assignments of the same parameter are not allowed",
node.get_SourcePositionStart());
}
else {
entryNameSet.add(node.getName());
entryNameSet.add(parameterPrefix + node.getName());
}
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.Map;
/**
*
*/
public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigurationCoCo {
private static final String PARAMETER_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
@Override
public void check(ASTConfiguration node) {
boolean useFixTargetNetwork = node.getEntriesList().stream()
.anyMatch(e -> e instanceof ASTUseFixTargetNetworkEntry
&& ((ASTUseFixTargetNetworkEntry)e).getValue().isPresentTRUE());
boolean hasTargetNetworkUpdateInterval = node.getEntriesList().stream()
.anyMatch(e -> (e instanceof ASTTargetNetworkUpdateIntervalEntry));
if (useFixTargetNetwork && !hasTargetNetworkUpdateInterval) {
ASTUseFixTargetNetworkEntry useFixTargetNetworkEntry = node.getEntriesList().stream()
.filter(e -> e instanceof ASTUseFixTargetNetworkEntry)
.map(e -> (ASTUseFixTargetNetworkEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTUseFixTargetNetwork entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + Boolean.toString(useFixTargetNetwork)
+ " requires parameter " + PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL,
useFixTargetNetworkEntry.get_SourcePositionStart());
} else if (!useFixTargetNetwork && hasTargetNetworkUpdateInterval) {
ASTTargetNetworkUpdateIntervalEntry targetNetworkUpdateIntervalEntry = node.getEntriesList().stream()
.filter(e -> e instanceof ASTTargetNetworkUpdateIntervalEntry)
.map(e -> (ASTTargetNetworkUpdateIntervalEntry)e)
.findFirst()
.orElseThrow(
() -> new IllegalStateException("ASTTargetNetworkUpdateInterval entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter "
+ targetNetworkUpdateIntervalEntry.getName() + " requires parameter "
+ PARAMETER_USE_FIX_TARGET_NETWORK + " to be true.",
targetNetworkUpdateIntervalEntry.get_SourcePositionStart());
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
*
*/
public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private final static List<Class> ALLOWED_SUPERVISED_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTBatchSizeEntry.class,
ASTOptimizerEntry.class,
ASTLearningRateEntry.class,
ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class,
ASTLossEntry.class,
ASTNormalizeEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTStepSizeEntry.class,
ASTRescaleGradEntry.class,
ASTClipGradEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTBeta1Entry.class,
ASTBeta2Entry.class,
ASTNumEpochEntry.class
);
private final static List<Class> ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTOptimizerEntry.class,
ASTRewardFunctionEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTClipGradEntry.class,
ASTRescaleGradEntry.class,
ASTStepSizeEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTLearningRateEntry.class,
ASTDiscountFactorEntry.class,
ASTNumMaxStepsEntry.class,
ASTTargetScoreEntry.class,
ASTTrainingIntervalEntry.class,
ASTUseFixTargetNetworkEntry.class,
ASTTargetNetworkUpdateIntervalEntry.class,
ASTSnapshotIntervalEntry.class,
ASTAgentNameEntry.class,
ASTGymEnvironmentNameEntry.class,
ASTEnvironmentEntry.class,
ASTUseDoubleDQNEntry.class,
ASTLossEntry.class,
ASTReplayMemoryEntry.class,
ASTMemorySizeEntry.class,
ASTSampleSizeEntry.class,
ASTActionSelectionEntry.class,
ASTGreedyEpsilonEntry.class,
ASTMinEpsilonEntry.class,
ASTEpsilonDecayEntry.class,
ASTEpsilonDecayMethodEntry.class,
ASTNumEpisodesEntry.class,
ASTRosEnvironmentActionTopicEntry.class,
ASTRosEnvironmentStateTopicEntry.class,
ASTRosEnvironmentMetaTopicEntry.class,
ASTRosEnvironmentResetTopicEntry.class,
ASTRosEnvironmentTerminalStateTopicEntry.class,
ASTRosEnvironmentGreetingTopicEntry.class
);
private Set<ASTEntry> allEntries;
private Boolean learningMethodKnown;
private LearningMethod learningMethod;
public CheckLearningParameterCombination() {
this.allEntries = new HashSet<>();