Commit 68ee368d authored by Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !30
parents 6c1c21bd d157159f
Pipeline #325057 passed in 4 minutes and 4 seconds
@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.11-SNAPSHOT</version>
<version>0.2.12-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
@@ -17,9 +17,9 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.6-SNAPSHOT</CNNArch2X.version>
<CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.7-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
@@ -71,15 +71,18 @@ public class CNNArch2Gluon extends CNNArchGenerator {
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = controller.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
temp = controller.process("CNNModelLoader", Target.CPP);
fileContentMap.put("CNNModelLoader.h", temp.getValue());
return fileContentMap;
}
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
architecture.processForEpisodicReplayMemory();
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
......
@@ -41,6 +41,10 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
supportedLayerList.add(AllPredefinedLayers.LARGE_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
}
}
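The support check itself reduces to membership in this list. A minimal self-contained sketch of the idea (the layer-name strings and the isSupported helper are illustrative stand-ins, not the real CNNArch2X LayerSupportChecker API):

import java.util.Arrays;
import java.util.List;

public class LayerSupportSketch {
    // Stand-ins for the AllPredefinedLayers constants registered above.
    private static final List<String> SUPPORTED = Arrays.asList(
            "EpisodicMemory", "LargeMemory", "DotProductSelfAttention", "LoadNetwork");

    // Illustrative check: unsupported layers are rejected before generation.
    static boolean isSupported(String layerName) {
        return SUPPORTED.contains(layerName);
    }

    public static void main(String[] args) {
        System.out.println(isSupported("EpisodicMemory")); // true
        System.out.println(isSupported("Crop"));           // false: still commented out above
    }
}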
@@ -87,6 +87,17 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
setCurrentElement(previousElement);
}
public void include(SerialCompositeElementSymbol compositeElement, Integer episodicSubNetIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getEpisodicSubNetworks().get(episodicSubNetIndex)){
include(element, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode);
@@ -106,6 +117,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElementSymbol, Integer episodicSubNetIndex, String netDefinitionMode) {
include(architectureElementSymbol, episodicSubNetIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
@@ -113,10 +128,21 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode);
}
public void include(ArchitectureElementSymbol architectureElement, Integer episodicSubNetIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include((SerialCompositeElementSymbol) architectureElement, episodicSubNetIndex, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
return getStreamInputs(stream, outputAsArray).keySet();
}
public Set<String> getSubnetInputNames(List<ArchitectureElementSymbol> subNet) {
return getSubnetInputs(subNet).keySet();
}
public ArrayList<String> getStreamInputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
ArrayList<String> inputVariableNames = new ArrayList<String>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
@@ -226,6 +252,29 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return outputNames;
}
public Set<String> getSubnetOutputNames(List<ArchitectureElementSymbol> subNet){
Set<String> outputNames = new LinkedHashSet<>();
for (ArchitectureElementSymbol element : subNet.get(subNet.size()-1).getLastAtomicElements()) {
String name = getName(element);
outputNames.add(name);
}
return outputNames;
}
public int getSubnetOutputSize(List<ArchitectureElementSymbol> subNet){
int outputSize = 0;
for (ArchitectureElementSymbol element : subNet.get(subNet.size()-1).getLastAtomicElements()) {
outputSize += element.getOutputTypes().size();
}
if(outputSize == 0){
outputSize = 1;
}
return outputSize;
}
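The counting rule on plain data, as a runnable sketch (List<List<String>> stands in for the output types of the subnet's last atomic elements; the real method walks ArchitectureElementSymbol instances):

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class SubnetOutputSizeSketch {
    // Mirrors getSubnetOutputSize: sum the output counts, fall back to 1 if nothing reports outputs.
    static int outputSize(List<List<String>> outputTypesPerElement) {
        int size = 0;
        for (List<String> types : outputTypesPerElement) {
            size += types.size();
        }
        return size == 0 ? 1 : size;
    }

    public static void main(String[] args) {
        System.out.println(outputSize(Arrays.asList(Arrays.asList("vec", "label"), Arrays.asList("mask")))); // 3
        System.out.println(outputSize(Collections.emptyList())); // 1 (fallback)
    }
}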
public List<String> getUnrollOutputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
@@ -403,6 +452,31 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return inputs;
}
public Map<String, List<String>> getSubnetInputs(List<ArchitectureElementSymbol> subNet) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : subNet.get(0).getFirstAtomicElements()) {
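// constants carry no tensor output type, so they are registered with a scalar placeholder dimension of "1"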
if (element instanceof ConstantSymbol) {
inputs.put(getName(element), Arrays.asList("1"));
}
else {
List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
List<String> dimensions = new ArrayList<>();
for (Integer intDimension : intDimensions) {
dimensions.add(intDimension.toString());
}
String name = getName(element);
inputs.put(name, dimensions);
}
}
return inputs;
}
public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> outputs = new LinkedHashMap<>();
@@ -519,6 +593,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return dimensions;
}
public List<Integer> cutDimensionsInteger(List<Integer> dimensions) {
while (dimensions.size() > 1 && dimensions.get(dimensions.size() - 1).equals(1)) {
dimensions.remove(dimensions.size() - 1);
}
return dimensions;
}
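Trailing singleton dimensions are stripped, but a bare [1] survives; a quick standalone check using the same trimming loop as above:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CutDimensionsCheck {
    // Identical rule to cutDimensionsInteger: drop trailing 1s, keep at least one entry.
    static List<Integer> cut(List<Integer> dimensions) {
        while (dimensions.size() > 1 && dimensions.get(dimensions.size() - 1).equals(1)) {
            dimensions.remove(dimensions.size() - 1);
        }
        return dimensions;
    }

    public static void main(String[] args) {
        System.out.println(cut(new ArrayList<>(Arrays.asList(64, 1, 1))));    // [64]
        System.out.println(cut(new ArrayList<>(Arrays.asList(1, 1))));        // [1]
        System.out.println(cut(new ArrayList<>(Arrays.asList(3, 224, 224)))); // [3, 224, 224]
    }
}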
public boolean hasUnrollInstructions() {
for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) {
if (networkInstruction.isUnroll()) {
......
@@ -130,9 +130,27 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
Map<String, String> fileContentMap = new HashMap<>();
// Context information and optimizer for local adaptation during prediction with the replay memory layer (the latter only applicable to supervised learning)
String cnnTrainLAOptimizerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNLAOptimizer.ftl");
fileContentMap.put("CNNLAOptimizer_" + getInstanceName() + ".h", cnnTrainLAOptimizerTemplateContent);
// Emit the Python AdamW optimizer implementation if it is used for training
if(configuration.getOptimizer() != null) {
String optimizerName = configuration.getOptimizer().getName();
Optional<OptimizerSymbol> criticOptimizer = configuration.getCriticOptimizer();
String criticOptimizerName = "";
if (criticOptimizer.isPresent()) {
criticOptimizerName = criticOptimizer.get().getName();
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")) {
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
}
}
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
String cnnTrainTrainerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTrainerTemplateContent);
} else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if (!configuration.getDiscriminatorNetwork().isPresent()) {
@@ -189,7 +207,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
@@ -6,6 +6,7 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
*/
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
PREDICTION_PARAMETER,
FORWARD_FUNCTION;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
@@ -14,6 +15,8 @@ public enum NetDefinitionMode {
return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION":
return FORWARD_FUNCTION;
case "PREDICTION_PARAMETER":
return PREDICTION_PARAMETER;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
......
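A small round-trip check of the extended parser, grounded in the switch above (assumes the class sits next to the enum in de.monticore.lang.monticar.cnnarch.gluongenerator):

public class NetDefinitionModeCheck {
    public static void main(String[] args) {
        // The new constant parses like the existing ones ...
        System.out.println(NetDefinitionMode.fromString("PREDICTION_PARAMETER") == NetDefinitionMode.PREDICTION_PARAMETER); // true
        // ... and unknown names still fail fast.
        try {
            NetDefinitionMode.fromString("TRAINING");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Unknown Net Definition Mode
        }
    }
}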
<#-- (c) https://github.com/MontiCore/monticore -->
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#include <stdio.h>
#include <iostream>
#include <fstream>
#include <string>
// Read file to buffer
class BufferFile {
public :
std::string file_path_;
int length_;
char* buffer_;
explicit BufferFile(std::string file_path)
:file_path_(file_path) {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
length_ = 0;
buffer_ = NULL;
return;
}
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
buffer_ = new char[sizeof(char) * length_];
ifs.read(buffer_, length_);
ifs.close();
}
int GetLength() {
return length_;
}
char* GetBuffer() {
return buffer_;
}
~BufferFile() {
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};
#endif // CNNBUFFERFILE_H
@@ -3,6 +3,8 @@ import mxnet as mx
import logging
import os
import shutil
import warnings
import inspect
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
@@ -27,6 +29,10 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0]*num_episodic_sub_nets
mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -37,22 +43,77 @@ class ${tc.fileNameWithoutEnding}:
except OSError:
pass
if hasattr(network, 'episodic_sub_nets'):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
except OSError:
pass
for j in range(len(network.episodic_sub_nets)):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
except OSError:
pass
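# scan the model directory for the newest regular checkpoint and, per episodic sub-net, for the newest replay memory dump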
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
if epoch > lastEpoch:
if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
if hasattr(network, 'episodic_sub_nets'):
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading Replay Memory: " + mem_files[j])
mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._model_dir_ + mem_files[j])
if earliestLastEpoch is None or lastEpoch < earliestLastEpoch:
earliestLastEpoch = lastEpoch
if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch:
earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -63,28 +124,56 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0] * num_episodic_sub_nets
mem_files = [None] * num_episodic_sub_nets
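# as in load(): track the newest replay memory dump per episodic sub-net alongside the regular weights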
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
if hasattr(network, 'episodic_sub_nets'):
assert all(lastEpoch == memEpoch for memEpoch in lastMemEpoch)  # weights and replay memory dumps must stem from the same epoch
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading pretrained Replay Memory: " + mem_files[j])
mem_layer = \
[param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._weights_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context)<#sep>, </#list>)
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
<#if networkInstruction.body.episodicSubNetworks?has_content>
self.networks[0].episodicsubnet0_(<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
</#if>
</#list>
if not os.path.exists(self._model_dir_):
......
@@ -184,16 +184,16 @@ class ${tc.fileNameWithoutEnding}:
del discriminator_optimizer_params['learning_rate_decay']
if normalize:
self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std)
else:
self._net_creator_dis.construct(mx_context)
self._net_creator_dis.construct([mx_context])
self._net_creator_gen.construct(mx_context)
self._net_creator_gen.construct([mx_context])
if self.use_qnet:
self._net_creator_qnet.construct(mx_context)
self._net_creator_qnet.construct([mx_context])
if load_checkpoint:
self._net_creator_qnet.load(mx_context)
self._net_creator_qnet.load([mx_context])
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
@@ -206,8 +206,8 @@ class ${tc.fileNameWithoutEnding}:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator_dis.load(mx_context)
self._net_creator_gen.load(mx_context)
begin_epoch = self._net_creator_dis.load([mx_context])
self._net_creator_gen.load([mx_context])
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
@@ -255,9 +255,9 @@ class ${tc.fileNameWithoutEnding}:
gen_input, exp_qnet_output = create_generator_input(batch)
with autograd.record():
fake_data = gen_net(*gen_input)
fake_data = gen_net(*gen_input)[0][0]
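# the generated nets apparently return nested output lists now; [0][0] selects the first output of the first sub-net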
fake_data.detach()
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
@@ -265,7 +265,7 @@ class ${tc.fileNameWithoutEnding}:
real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context)
loss_resultF = dis_loss(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = dis_loss(discriminated_real_dis, real_labels)
@@ -276,8 +276,8 @@ class ${tc.fileNameWithoutEnding}:
if batch_i % k_value == 0:
with autograd.record():
fake_data = gen_net(*gen_input)
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
fake_data = gen_net(*gen_input)[0][0]
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = dis_loss(discriminated_fake_gen, real_labels)
@@ -285,7 +285,7 @@ class ${tc.fileNameWithoutEnding}:
condition = batch.data[traindata_to_index[generator_target_name + "_"]]
loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition)
if self.use_qnet:
qnet_discriminated = [q_net(features)]
qnet_discriminated = [q_net(features)[0][0]]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
......
<#-- (c) https://github.com/MontiCore/monticore -->
<#list configurations as config>
#ifndef CNNLAOPTIMIZER_${config.instanceName?upper_case}
#define CNNLAOPTIMIZER_${config.instanceName?upper_case}
#include <mxnet-cpp/MxNetCpp.h>
#include <string>
#include <vector>
#include <memory>
using namespace mxnet::cpp;
class CNNLAOptimizer_${config.instanceName}{
private:
Optimizer *optimizerHandle;
<#if (config.context)??>
std::string context_name = "${config.context}";
<#else>
std::string context_name = "cpu";
</#if>
public:
explicit CNNLAOptimizer_${config.instanceName}(){
<#if (config.configuration.optimizer)??>
<#if config.optimizerName == "adamw">
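// AdamW is only emitted as a Python implementation (AdamW.py); the C++ local-adaptation optimizer falls back to plain Adam.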
optimizerHandle = OptimizerRegistry::Find("adam");
<#else>
optimizerHandle = OptimizerRegistry::Find("${config.optimizerName}");
</#if>
<#list config.optimizerParams?keys as param>
<#if param == "learning_rate">
optimizerHandle->SetParam("lr", ${config.optimizerParams[param]});
<#elseif param == "weight_decay">
optimizerHandle->SetParam("wd", ${config.optimizerParams[param]});
<#elseif param == "learning_rate_decay">
<#assign learningRateDecay = config.optimizerParams[param]>
<#elseif param == "learning_rate_minimum">
<#assign minLearningRate = config.optimizerParams[param]>
<#elseif param == "step_size">
<#assign stepSize = config.optimizerParams[param]>
<#else>
optimizerHandle->SetParam("${param}", ${config.optimizerParams[param]});
</#if>
</#list>
<#if (learningRateDecay)?? && (stepSize)??>
<#if !(minLearningRate)??>
<#assign minLearningRate = "1e-08">
</#if>
std::unique_ptr<LRScheduler> lrScheduler(new FactorScheduler(${stepSize}, ${learningRateDecay}, ${minLearningRate}));
optimizerHandle->SetLRScheduler(std::move(lrScheduler));
</#if>
<#else>
optimizerHandle = OptimizerRegistry::Find("adam");
optimizerHandle->SetParam("lr", 0.001);
</#if>
<#if (config.clipGlobalGradNorm)??>