Commit 102c104c authored by Nicola Gatto, committed by Evgeny Kusmenko

Integrate TD3 Algorithm and Gaussian Noise

parent e680d6bb
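Background for reviewers (not part of the diff): TD3 ("Twin Delayed DDPG") extends DDPG with three changes: clipped double-Q learning over two critics, delayed actor/target updates (the policy_delay parameter introduced below), and Gaussian smoothing noise on target actions. The following minimal NumPy sketch of the TD3 critic-target computation is illustrative only; all names (actor_target, critic1_target, ...) are hypothetical and not part of this codebase:

import numpy as np

def td3_critic_targets(batch, actor_target, critic1_target, critic2_target,
                       gamma=0.99, policy_noise=0.2, noise_clip=0.5,
                       action_low=-1.0, action_high=1.0):
    # Target policy smoothing: perturb the target action with clipped Gaussian noise
    noise = np.clip(
        np.random.normal(0.0, policy_noise, size=batch['actions'].shape),
        -noise_clip, noise_clip)
    next_actions = np.clip(
        actor_target(batch['next_states']) + noise, action_low, action_high)
    # Clipped double-Q learning: back up the minimum of the two target critics
    q_next = np.minimum(critic1_target(batch['next_states'], next_actions),
                        critic2_target(batch['next_states'], next_actions))
    return batch['rewards'] + gamma * (1.0 - batch['terminals']) * q_next

The actor and the target networks are then updated only every policy_delay critic updates.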
......@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
......@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
......@@ -44,13 +44,7 @@ BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- cat target/site/jacoco/index.html
except:
- master
PythonWrapperIntegrationTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.2-SNAPSHOT</version>
<version>0.2.6-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,10 +16,10 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.2-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
......@@ -29,4 +29,4 @@ public class CNNArch2GluonCli {
GenericCNNArchCli cli = new GenericCNNArchCli(generator);
cli.run(args);
}
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -10,15 +11,12 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log;
import java.io.File;
......@@ -54,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
// Generate Reward function if necessary
if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
&& configurationSymbol.getRlRewardFunction().isPresent()) {
generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
}
return configurationSymbol;
}
......@@ -93,17 +84,25 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
}
public void generate(Path modelsDirPath, String rootModelName, TrainedArchitecture trainedArchitecture) {
public void generate(Path modelsDirPath,
String rootModelName,
NNArchitectureSymbol trainedArchitecture,
NNArchitectureSymbol criticNetwork) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
configurationSymbol.setCriticNetwork(criticNetwork);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
generate(modelsDirPath, rootModelName, trainedArchitecture, null);
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
ReinforcementConfigurationData configData = new ReinforcementConfigurationData(configuration, getInstanceName());
GluonConfigurationData configData = new GluonConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
......@@ -119,24 +118,39 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
if (getRootProjectModelsDir().isPresent()) {
criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get());
} else {
Log.error("No root model dir set");
if (rlAlgorithm.equals(RLAlgorithm.DDPG)
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
if (!configuration.getCriticNetwork().isPresent()) {
Log.error("No architecture model for critic available but is required for chosen " +
"actor-critic algorithm");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
CriticNetworkGenerationPair criticNetworkResult
= criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration);
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap(
final String creatorName = architectureFileContentMap.keySet().iterator().next();
final String criticInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName());
ftlContext.put("criticInstanceName", criticInstanceName);
}
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
}
ftlContext.put("trainerName", trainerName);
......@@ -164,8 +178,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
setRootProjectModelsDir(modelsDirPath.toString());
}
rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
rewardFunctionRootModel, rewardFunctionOutputPath);
final TaggingResolver taggingResolver
= rewardFunctionSourceGenerator.createTaggingResolver(getRootProjectModelsDir().get());
final EMAComponentInstanceSymbol emaSymbol
= rewardFunctionSourceGenerator.resolveSymbol(taggingResolver, rewardFunctionRootModel);
rewardFunctionSourceGenerator.generate(emaSymbol, taggingResolver, rewardFunctionOutputPath);
fixArmadilloEmamGenerationOfFile(Paths.get(rewardFunctionOutputPath, String.join("_", fullNameOfComponent) + ".h"));
String pythonWrapperOutputPath = Paths.get(rewardFunctionOutputPath, "pylib").toString();
......@@ -175,12 +192,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (pythonWrapperApi.checkIfPythonModuleBuildAvailable()) {
final String rewardModuleOutput
= Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(getRootProjectModelsDir().get(),
rewardFunctionRootModel, pythonWrapperOutputPath, rewardModuleOutput);
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(emaSymbol,
pythonWrapperOutputPath, rewardModuleOutput);
} else {
Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
componentPortInformation = pythonWrapperApi.generate(getRootProjectModelsDir().get(), rewardFunctionRootModel,
pythonWrapperOutputPath);
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
......
......@@ -2,8 +2,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.annotations;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.symboltable.CommonSymbol;
......@@ -14,11 +14,13 @@ import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkNotNull;
public class ArchitectureAdapter implements TrainedArchitecture {
public class ArchitectureAdapter extends NNArchitectureSymbol {
private ArchitectureSymbol architectureSymbol;
public ArchitectureAdapter(final ArchitectureSymbol architectureSymbol) {
public ArchitectureAdapter(final String name,
final ArchitectureSymbol architectureSymbol) {
super(name);
checkNotNull(architectureSymbol);
this.architectureSymbol = architectureSymbol;
}
......@@ -55,6 +57,10 @@ public class ArchitectureAdapter implements TrainedArchitecture {
s -> s.getDefinition().getType().getDomain().getName()));
}
public ArchitectureSymbol getArchitectureSymbol() {
return this.architectureSymbol;
}
private Range astRangeToTrainRange(final ASTRange range) {
if (range == null || (range.hasNoLowerLimit() && range.hasNoUpperLimit())) {
return Range.withInfinityLimits();
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
/**
 * Generates target-language source code for the reward function component of a
 * reinforcement learning configuration.
 */
public interface RewardFunctionSourceGenerator {
void generate(String modelPath, String qualifiedName, String targetPath);
TaggingResolver createTaggingResolver(String modelPath);
EMAComponentInstanceSymbol resolveSymbol(TaggingResolver taggingResolver, String rootModel);
void generate(String modelPath, String rootModel, String targetPath);
void generate(EMAComponentInstanceSymbol componentInstanceSymbol, TaggingResolver taggingResolver, String targetPath);
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
public class CriticNetworkGenerationException extends RuntimeException {
public CriticNetworkGenerationException(String s) {
super("Generation of critic network failed: " + s);
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
import java.util.Map;
public class CriticNetworkGenerationPair {
private String criticNetworkName;
private Map<String, String> fileContent;
public CriticNetworkGenerationPair(String criticNetworkName, Map<String, String> fileContent) {
this.criticNetworkName = criticNetworkName;
this.fileContent = fileContent;
}
public String getCriticNetworkName() {
return criticNetworkName;
}
public Map<String, String> getFileContent() {
return fileContent;
}
}
<#setting number_format="computer">
<#assign config = configurations[0]>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent")>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
from ${rlFrameworkModule}.agent import ${rlAgentType}
from ${rlFrameworkModule}.util import AgentSignalHandler
from ${rlFrameworkModule}.cnnarch_logger import ArchLogger
<#if config.rlAlgorithm=="ddpg">
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
from ${rlFrameworkModule}.CNNCreator_${criticInstanceName} import CNNCreator_${criticInstanceName}
</#if>
import ${rlFrameworkModule}.environment
......@@ -137,8 +137,10 @@ if __name__ == "__main__":
</#if>
<#if (config.rlAlgorithm == "dqn")>
<#include "params/DqnAgentParams.ftl">
<#else>
<#elseif config.rlAlgorithm == "ddpg">
<#include "params/DdpgAgentParams.ftl">
<#else>
<#include "params/Td3AgentParams.ftl">
</#if>
}
......@@ -168,7 +170,7 @@ if __name__ == "__main__":
if train_successful:
<#if (config.rlAlgorithm == "dqn")>
agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
<#else>
agent.save_best_network(actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
</#if>
\ No newline at end of file
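For a DDPG or TD3 configuration, the <#else> branch above renders to Python along these lines (creator names depend on the trained instance and are shown here as placeholders):

if train_successful:
    agent.export_best_network(
        path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest',
        epoch=0)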
......@@ -13,18 +13,21 @@ class StrategyBuilder(object):
epsilon_decay_method='no',
epsilon_decay=0.0,
epsilon_decay_start=0,
epsilon_decay_per_step=False,
action_dim=None,
action_low=None,
action_high=None,
mu=0.0,
theta=0.5,
sigma=0.3
sigma=0.3,
noise_variance=0.1
):
if epsilon_decay_method == 'linear':
decay = LinearDecay(
eps_decay=epsilon_decay, min_eps=min_epsilon,
decay_start=epsilon_decay_start)
decay_start=epsilon_decay_start,
decay_per_step=epsilon_decay_per_step)
else:
decay = NoDecay()
......@@ -44,6 +47,13 @@ class StrategyBuilder(object):
return OrnsteinUhlenbeckStrategy(
action_dim, action_low, action_high, epsilon, mu, theta,
sigma, decay)
elif method == 'gaussian':
assert action_dim is not None
assert action_low is not None
assert action_high is not None
assert noise_variance is not None
return GaussianNoiseStrategy(action_dim, action_low, action_high,
epsilon, noise_variance, decay)
else:
assert action_dim is not None
assert len(action_dim) == 1
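A hedged usage sketch of the new 'gaussian' branch, assuming the surrounding factory method is the builder's build_by_params (its name is truncated above this hunk) and that GaussianNoiseStrategy is defined later in this file; all values are illustrative:

builder = StrategyBuilder()
strategy = builder.build_by_params(
    method='gaussian',
    epsilon=1.0,
    min_epsilon=0.1,
    epsilon_decay_method='linear',
    epsilon_decay=0.01,
    epsilon_decay_per_step=True,
    action_dim=(2,),
    action_low=-1.0,
    action_high=1.0,
    noise_variance=0.1)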
......@@ -70,17 +80,27 @@ class NoDecay(BaseDecay):
class LinearDecay(BaseDecay):
def __init__(self, eps_decay, min_eps=0, decay_start=0):
def __init__(self, eps_decay, min_eps=0, decay_start=0, decay_per_step=False):
super(LinearDecay, self).__init__()
self.eps_decay = eps_decay
self.min_eps = min_eps
self.decay_start = decay_start
self.decay_per_step = decay_per_step
self.last_episode = -1
def decay(self, cur_eps, episode):
if episode < self.decay_start:
return cur_eps
def do_decay(self, episode):
if self.decay_per_step:
do = (episode >= self.decay_start)
else:
do = ((self.last_episode != episode) and (episode >= self.decay_start))
self.last_episode = episode
return do
def decay(self, cur_eps, episode):
if self.do_decay(episode):
return max(cur_eps - self.eps_decay, self.min_eps)
else:
return cur_eps
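The new decay_per_step flag controls whether decay is applied on every call or only once per episode; a small illustration with arbitrary values:

# Per-episode (default): repeated calls within one episode decay only once
d = LinearDecay(eps_decay=0.1, min_eps=0.0, decay_start=0, decay_per_step=False)
eps = d.decay(1.0, episode=0)   # 0.9
eps = d.decay(eps, episode=0)   # still 0.9, same episode
# Per-step: every call decays
d = LinearDecay(eps_decay=0.1, min_eps=0.0, decay_start=0, decay_per_step=True)
eps = d.decay(1.0, episode=0)   # 0.9
eps = d.decay(eps, episode=0)   # 0.8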
class BaseStrategy(object):
......@@ -170,3 +190,29 @@ class OrnsteinUhlenbeckStrategy(BaseStrategy):
noise = self._evolve_state()
action = (1.0 - self.cur_eps) * values + (self.cur_eps * noise)
return np.clip(action, self._action_low, self._action_high)
class GaussianNoiseStrategy(BaseStrategy):
def __init__(
self,
action_dim,
action_low,
action_high,
eps,
noise_variance,
decay=NoDecay()
):
super(GaussianNoiseStrategy, self).__init__(decay)
self.eps = eps
self.cur_eps = eps
self._action_dim = action_dim
self._action_low = action_low
self._action_high = action_high
self._noise_variance = noise_variance
def select_action(self, values):
noise = np.random.normal(loc=0.0, scale=self._noise_variance, size=self._action_dim)
action = values + self.cur_eps * noise
return np.clip(action, self._action_low, self._action_high)
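A short usage sketch (shapes and bounds are illustrative). Note that noise_variance is passed to np.random.normal as its scale argument, so it acts as a standard deviation despite its name:

import numpy as np

strategy = GaussianNoiseStrategy(action_dim=(2,), action_low=-1.0,
                                 action_high=1.0, eps=1.0, noise_variance=0.1)
greedy_action = np.array([0.3, -0.7])
noisy_action = strategy.select_action(greedy_action)  # clipped to [-1, 1]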
architecture ${architectureName}() {
def input ${stateType}<#if stateRange??>(<#if stateRange.isLowerLimitInfinity()>-oo<#else>${stateRange.lowerLimit.get()}</#if>:<#if stateRange.isUpperLimitInfinity()>oo<#else>${stateRange.upperLimit.get()}</#if>)</#if>^{<#list stateDimension as d>${d}<#if d?has_next>,</#if></#list>} state
def input ${actionType}<#if actionRange??>(<#if actionRange.isLowerLimitInfinity()>-oo<#else>${actionRange.lowerLimit.get()}</#if>:<#if actionRange.isUpperLimitInfinity()>oo<#else>${actionRange.upperLimit.get()}</#if>)</#if>^{<#list actionDimension as d>${d}<#if d?has_next>,</#if></#list>} action
def output Q(-oo:oo)^{1} qvalue
${implementation}->FullyConnected(units=1)->qvalue;
}
\ No newline at end of file
......@@ -151,7 +151,6 @@ class RosEnvironment(Environment):
def reset(self):
self.__in_reset = True
time.sleep(0.5)
reset_message = Bool()
reset_message.data = True
self.__waiting_for_state_update = True
......@@ -187,7 +186,8 @@ class RosEnvironment(Environment):
next_state = self.__last_received_state
terminal = self.__last_received_terminal
reward = <#if config.hasRosRewardTopic()>self.__last_received_reward<#else>self.__calc_reward(next_state, terminal)</#if>
rospy.logdebug('Calculated reward: {}'.format(reward))
logger.debug('Transition: ({}, {}, {}, {})'.format(action, reward, next_state, terminal))
return next_state, reward, terminal, 0
......@@ -206,25 +206,24 @@ class RosEnvironment(Environment):
else:
rospy.logerr("Timeout 3 times in a row: Terminate application")
exit()
time.sleep(100/1000)
time.sleep(0.002)  # 2 ms; 1/500 would be integer division (0) under Python 2
def close(self):
rospy.signal_shutdown('Program ended!')
def __state_callback(self, data):
self.__last_received_state = np.array(data.data, dtype='float32').reshape((<#list config.stateDim as d>${d},</#list>))
rospy.logdebug('Received state: {}'.format(self.__last_received_state))
logger.debug('Received state: {}'.format(self.__last_received_state))
self.__waiting_for_state_update = False
def __terminal_state_callback(self, data):
self.__last_received_terminal = data.data
rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
self.__last_received_terminal = np.bool(data.data)
logger.debug('Received terminal flag: {}'.format(self.__last_received_terminal))
self.__waiting_for_terminal_update = False
<#if config.hasRosRewardTopic()>
def __reward_callback(self, data):
self.__last_received_reward = float(data.data)
self.__last_received_reward = np.float32(data.data)
logger.debug('Received reward: {}'.format(self.__last_received_reward))
self.__waiting_for_reward_update = False
<#else>
......
'actor': actor_creator.net,
'critic': critic_creator.net,
'actor': actor_creator.networks[0],
'critic': critic_creator.networks[0],
<#if (config.softTargetUpdateRate)??>
'soft_target_update_rate': ${config.softTargetUpdateRate},
</#if>
......
'qnet':qnet_creator.net,
'qnet':qnet_creator.networks[0],
<#if (config.useFixTargetNetwork)?? && config.useFixTargetNetwork>
'use_fix_target': True,
'target_update_interval': ${config.targetNetworkUpdateInterval},
......@@ -6,7 +6,7 @@
'use_fix_target': False,
</#if>
<#if (config.configuration.loss)??>
'loss': '${config.lossName}',
'loss_function': '${config.lossName}',
<#if (config.lossParams)??>
'loss_params': {
<#list config.lossParams?keys as param>
......
......@@ -10,5 +10,5 @@
'method': 'online',
</#if>
'state_dtype': 'float32',
'action_dtype': <#if config.rlAlgorithm=="DQN">'uint8'<#else>'float32'</#if>,
'action_dtype': <#if config.rlAlgorithm=="dqn">'uint8'<#else>'float32'</#if>,
'rewards_dtype': 'float32'
......@@ -16,7 +16,15 @@
<#if (config.strategy.epsilon_decay_start)??>
'epsilon_decay_start': ${config.strategy.epsilon_decay_start},
</#if>
<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck")>
<#if (config.strategy.epsilon_decay_per_step)??>
'epsilon_decay_per_step': ${config.strategy.epsilon_decay_per_step?string('True', 'False')},
</#if>
<#if (config.strategy.method)?? && (config.strategy.method=="gaussian")>
<#if (config.strategy.noise_variance)??>
'noise_variance': ${config.strategy.noise_variance},
</#if>
</#if>
<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck" || config.strategy.method=="gaussian")>
<#if (config.strategy.action_low)?? >
'action_low': ${config.strategy.action_low},
<#else>
......@@ -27,6 +35,8 @@
<#else>
'action_high' : np.infty,
</#if>
</#if>
<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck")>
<#if (config.strategy.mu)??>
'mu': [<#list config.strategy.mu as m>${m}<#if m?has_next>, </#if></#list>],
</#if>
......
<#include "DdpgAgentParams.ftl">
<#if (config.policyNoise)??>
'policy_noise': ${config.policyNoise},
</#if>
<#if (config.noiseClip)??>
'noise_clip': ${config.noiseClip},
</#if>
<#if (config.policyDelay)??>
'policy_delay': ${config.policyDelay},
</#if>
\ No newline at end of file
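With hypothetical configuration values policyNoise=0.2, noiseClip=0.5, policyDelay=2 (illustrative, not defaults), the template above would expand, on top of the included DDPG parameters, to:

'policy_noise': 0.2,   # std of the smoothing noise added to target actions
'noise_clip': 0.5,     # clipping bound for that noise
'policy_delay': 2,     # critic updates per actor/target-network update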
......@@ -11,8 +11,8 @@ import cnnarch_logger
LOSS_FUNCTIONS = {
'l1': gluon.loss.L1Loss(),
'euclidean': gluon.loss.L2Loss(),
'huber_loss': gluon.loss.HuberLoss(),
'l2': gluon.loss.L2Loss(),
'huber': gluon.loss.HuberLoss(),
'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
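After this rename ('huber' and 'l2' replace 'huber_loss' and 'euclidean'), trainer code resolving a loss by the configured key might look like the following sketch, assuming the LOSS_FUNCTIONS table above is in scope:

import mxnet as mx

loss = LOSS_FUNCTIONS['huber']        # gluon.loss.HuberLoss instance
pred = mx.nd.array([0.1, 0.9])
label = mx.nd.array([0.0, 1.0])
print(loss(pred, label))              # per-sample Huber loss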
......
......@@ -20,12 +20,9 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.util.TrainedArchitectureMockFactory;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.cnnarch.gluongenerator.util.NNArchitectureMockFactory;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
......@@ -37,15 +34,9 @@ import java.io.IOException;
import java.nio.file.Path;