Commit b43861f8 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

changes for GAN training

parent 5ad5d879
This diff is collapsed.
......@@ -84,8 +84,6 @@ public class CNNArch2Gluon extends CNNArchGenerator {
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
archTc.getStreamOutputDomains(archTc.getArchitecture().getStreams().get(0));
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
......
......@@ -2,9 +2,12 @@
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingPortChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -78,9 +81,9 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
try {
Iterator var6 = fileContents.keySet().iterator();
while(var6.hasNext()) {
String fileName = (String)var6.next();
genCPP.generateFile(new FileContent((String)fileContents.get(fileName), fileName));
while (var6.hasNext()) {
String fileName = (String) var6.next();
genCPP.generateFile(new FileContent((String) fileContents.get(fileName), fileName));
}
} catch (IOException var8) {
Log.error("CNNTrainer file could not be generated" + var8.getMessage());
......@@ -117,16 +120,16 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
} else if(configData.isGan()) {
} else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if(!configuration.getDiscriminatorNetwork().isPresent()) {
if (!configuration.getDiscriminatorNetwork().isPresent()) {
Log.error("No architecture model for discriminator available but is required for chosen " +
"GAN");
}
NNArchitectureSymbol genericDisArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol disArchitectureSymbol
= ((ArchitectureAdapter)genericDisArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericDisArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
......@@ -147,7 +150,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configuration.hasQNetwork()) {
NNArchitectureSymbol genericQArchitectureSymbol = configuration.getQNetwork().get();
ArchitectureSymbol qArchitectureSymbol
= ((ArchitectureAdapter)genericQArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericQArchitectureSymbol).getArchitectureSymbol();
Map<String, String> qArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(qArchitectureSymbol, true);
......@@ -179,7 +182,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
if (!configuration.getCriticNetwork().isPresent()) {
Log.error("No architecture model for critic available but is required for chosen " +
......@@ -187,7 +190,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
......@@ -201,8 +204,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticInstanceName);
......@@ -215,7 +218,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
} else {
Log.error("No architecture model for the trained neural network but is required for " +
"reinforcement learning configuration.");
"reinforcement learning configuration.");
}
}
......@@ -234,7 +237,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
......@@ -271,7 +274,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile){
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile) {
final File brokenFile = pathToBrokenFile.toFile();
if (brokenFile.exists()) {
try {
......@@ -301,19 +304,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
final String reinforcementStrategyContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/Strategy.ftl");
ftlContext, "reinforcement/agent/Strategy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py", reinforcementStrategyContent);
final String replayMemoryContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
final String environmentContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/environment/Environment.ftl");
ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
final String utilContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/util/Util.ftl");
ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
final String initContent = "";
......@@ -322,3 +325,4 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return fileContentMap;
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing;
import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter;
import de.monticore.lang.monticar.cnntrain.annotations.RewardFunctionParameter;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
*
*/
public class PreprocessingComponentParameterAdapter implements PreprocessingComponentParameter {
private final ComponentPortInformation adaptee;
private String outputParameterName;
private String inputStateParameterName;
private String inputTerminalParameterName;
public PreprocessingComponentParameterAdapter(final ComponentPortInformation componentPortInformation) {
this.adaptee = componentPortInformation;
}
@Override
public List<String> getInputNames() {
return this.adaptee.getAllInputs().stream()
.map(PortVariable::getVariableName)
.collect(Collectors.toList());
}
@Override
public List<String> getOutputNames() {
return this.adaptee.getAllOutputs().stream()
.map(PortVariable::getVariableName)
.collect(Collectors.toList());
}
@Override
public Optional<String> getTypeOfInputPort(String portName) {
return this.adaptee.getAllInputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(port -> port.getEmadlType().toString())
.findFirst();
}
@Override
public Optional<String> getTypeOfOutputPort(String portName) {
return this.adaptee.getAllOutputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(port -> port.getEmadlType().toString())
.findFirst();
}
@Override
public Optional<List<Integer>> getInputPortDimensionOfPort(String portName) {
return this.adaptee.getAllInputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(PortVariable::getDimension)
.findFirst();
}
@Override
public Optional<List<Integer>> getOutputPortDimensionOfPort(String portName) {
return this.adaptee.getAllOutputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(PortVariable::getDimension)
.findFirst();
}
public Optional<String> getOutputParameterName() {
if (this.outputParameterName == null) {
if (this.getOutputNames().size() == 1) {
this.outputParameterName = this.getOutputNames().get(0);
} else {
return Optional.empty();
}
}
return Optional.of(this.outputParameterName);
}
private boolean isBooleanScalar(final PortVariable portVariable) {
return portVariable.getEmadlType().equals(EmadlType.B)
&& portVariable.getDimension().size() == 1
&& portVariable.getDimension().get(0) == 1;
}
private boolean determineInputNames() {
if (this.getInputNames().size() != 2) {
return false;
}
Optional<String> terminalInput = this.adaptee.getAllInputs()
.stream()
.filter(this::isBooleanScalar)
.map(PortVariable::getVariableName)
.findFirst();
if (terminalInput.isPresent()) {
this.inputTerminalParameterName = terminalInput.get();
} else {
return false;
}
Optional<String> stateInput = this.adaptee.getAllInputs().stream()
.filter(portVariable -> !portVariable.getVariableName().equals(this.inputTerminalParameterName))
.filter(portVariable -> !isBooleanScalar(portVariable))
.map(PortVariable::getVariableName)
.findFirst();
if (stateInput.isPresent()) {
this.inputStateParameterName = stateInput.get();
} else {
this.inputTerminalParameterName = null;
return false;
}
return true;
}
public Optional<String> getInputStateParameterName() {
if (this.inputStateParameterName == null) {
this.determineInputNames();
}
return Optional.ofNullable(this.inputStateParameterName);
}
public Optional<String> getInputTerminalParameter() {
if (this.inputTerminalParameterName == null) {
this.determineInputNames();
}
return Optional.ofNullable(this.inputTerminalParameterName);
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.ListIterator;
import java.util.Set;
/**
 * Static validation of a preprocessing component's port structure:
 * the component must expose equally many input and output ports, and every
 * output port must be named after an input port with an {@code _out} suffix.
 * Violations are reported via {@link Log#error(String)} (no hard exit).
 */
public class PreprocessingPortChecker {

    public PreprocessingPortChecker() { }

    /**
     * Runs all port checks on the given adapter.
     *
     * @param preprocessingComponentParameter adapter to validate; must not be null
     */
    public static void check(final PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        assert preprocessingComponentParameter != null;
        checkEqualNumberofInAndOutPorts(preprocessingComponentParameter);
        checkCorrectPortNames(preprocessingComponentParameter);
    }

    /** Reports an error when input and output port counts differ. */
    private static void checkEqualNumberofInAndOutPorts(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        failIfConditionFails(equalNumberOfInAndOutPorts(preprocessingComponentParameter),
                "The number of in- and output ports of the " +
                        "preprocessing component is not equal");
    }

    private static boolean equalNumberOfInAndOutPorts(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        final int inputs = preprocessingComponentParameter.getInputNames().size();
        final int outputs = preprocessingComponentParameter.getOutputNames().size();
        return inputs == outputs;
    }

    /** Reports an error when the output port names do not mirror the inputs. */
    private static void checkCorrectPortNames(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        failIfConditionFails(correctPortNames(preprocessingComponentParameter),
                "The output ports are not correctly named with \"_out\" appendix");
    }

    /** True iff the output names are exactly the input names with "_out" appended. */
    private static boolean correctPortNames(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        final Set<String> expectedOutputs = new HashSet<String>();
        for (final String inputName : preprocessingComponentParameter.getInputNames()) {
            expectedOutputs.add(inputName + "_out");
        }
        final Set<String> actualOutputs =
                new HashSet<String>(preprocessingComponentParameter.getOutputNames());
        return expectedOutputs.equals(actualOutputs);
    }

    private static void failIfConditionFails(final boolean condition, final String message) {
        if (!condition) {
            fail(message);
        }
    }

    private static void fail(final String message) {
        // Intentionally only logs; generation continues after reporting.
        Log.error(message);
        //System.exit(-1);
    }
}
......@@ -5,7 +5,6 @@ import mxnet as mx
import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
......@@ -79,6 +78,7 @@ class ${tc.fileNameWithoutEnding}:
train_label = {}
data_mean = {}
data_std = {}
train_images = {}
shape_output = self.preprocess_data(instance, inp, 0, train_h5)
train_len = len(train_h5[self._input_names_[0]])
......@@ -141,6 +141,7 @@ class ${tc.fileNameWithoutEnding}:
for output_name in self._output_names_:
test_label[output_name][i] = getattr(shape_output, output_name + "_out")
test_images = {}
if 'images' in test_h5:
test_images = test_h5['images']
......@@ -152,7 +153,7 @@ class ${tc.fileNameWithoutEnding}:
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
for input_name in self._input_names_:
data = data_h5[input_name][0]
data = data_h5[input_name][index]
attr = getattr(input_wrapper, input_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......@@ -160,7 +161,7 @@ class ${tc.fileNameWithoutEnding}:
data = type(attr)(data)
setattr(input_wrapper, input_name, data)
for output_name in self._output_names_:
data = data_h5[output_name][0]
data = data_h5[output_name][index]
attr = getattr(input_wrapper, output_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......
......@@ -12,7 +12,6 @@ class CrossEntropyLoss(gluon.loss.Loss):
self._axis = axis
self._sparse_label = sparse_label
def hybrid_forward(self, F, pred, label, sample_weight=None):
pred = F.log(pred)
if self._sparse_label:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
......@@ -54,16 +53,22 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
"""
# ugly hardcoded
#import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import pyplot
def visualize(img_arr):
plt.imshow((img_arr.asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8).reshape(28,28))
plt.axis('off')
"""
img_np = img_arr.asnumpy().transpose(1, 2, 0)
img_np = ((img_np+1) * 127.5).astype(np.uint8)
if not img_np.shape[2] == 1:
pyplot.axis('off')
pyplot.imshow(img_np)
else:
pyplot.axis('off')
s = img_np.shape
img_np = img_np.reshape((s[0], s[1]))
pyplot.imshow(img_np, cmap = 'Greys')
# ugly hardcoded
def getDataIter(ctx, batch_size=64, Z=100):
img_number = 500
mnist_train = mx.gluon.data.vision.datasets.MNIST(train=True)
......@@ -92,12 +97,6 @@ def getDataIter(ctx, batch_size=64, Z=100):
X = np.tile(X, (1, 1, 1, 1))
data = mx.nd.array(X)
"""
for i in range(4):
plt.subplot(1,4,i+1)
visualize(data[i])
plt.show()
"""
image_iter = mx.io.NDArrayIter(data, batch_size=batch_size)
return image_iter
......@@ -119,8 +118,6 @@ class ${tc.fileNameWithoutEnding}:
def train(self, batch_size=64,
num_epoch=10,
eval_metric='acc',
loss ='softmax_cross_entropy',
loss_params={},
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
......@@ -136,8 +133,12 @@ class ${tc.fileNameWithoutEnding}:
preprocessing = False,
k_value = 1,
generator_loss = None,
conditional_input = None,
noise_input = None):
generator_target_name = "",
noise_input = "",
gen_loss_weight = 1,
dis_loss_weight = 1,
log_period = 50,
print_images = False):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -151,20 +152,29 @@ class ${tc.fileNameWithoutEnding}:
dis_input_names = list(self._net_creator_dis.getInputs().keys())
dis_input_names = [name[:-1] for name in dis_input_names]
if self.use_qnet:
qnet_input_names = list(self._net_creator_qnet.getInputs().keys())
qnet_input_names = list(self._net_creator_qnet.getOutputs().keys())
qnet_input_names = [name[:-1] for name in qnet_input_names]
dis_real_input = list(self._net_creator_gen.getOutputs().keys())[0][:-1]
gen_output_name = list(self._net_creator_gen.getOutputs().keys())[0][:-1]
if self.use_qnet:
cGAN_input_names = set(gen_input_names).difference(qnet_input_names)
cGAN_input_names.discard(noise_input)
cGAN_input_names = list(cGAN_input_names)
else:
cGAN_input_names = set(gen_input_names)
cGAN_input_names.discard(noise_input)
cGAN_input_names = list(cGAN_input_names)
if preprocessing:
preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNGanTrainer_")}_executor"
self._data_loader._output_names_ = []
if self.use_qnet:
dataloader_inputs = set(gen_input_names + dis_input_names).difference(qnet_input_names)
dataloader_inputs.discard(noise_input)
if not generator_target_name == "":
self._data_loader._input_names_ = cGAN_input_names + [gen_output_name] + [generator_target_name]
else:
dataloader_inputs = set(gen_input_names + dis_input_names)
dataloader_inputs.discard(noise_input)
self._data_loader._input_names_ = list(dataloader_inputs)
self._data_loader._input_names_ = cGAN_input_names + [gen_output_name]
# train_iter = getDataIter(mx_context, batch_size, 100)
if preprocessing:
......@@ -255,21 +265,12 @@ class ${tc.fileNameWithoutEnding}:
if self.use_qnet:
qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), discriminator_optimizer, discriminator_optimizer_params)
if loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'log_cosh':
loss_function = LogCoshLoss()
else:
logging.error("Invalid loss parameter.")
dis_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
if not generator_loss == None:
if generator_loss == "L2":
if generator_loss == "l2":
generator_loss_func = mx.gluon.loss.L2Loss()
elif generator_loss == "L1":
elif generator_loss == "l1":
generator_loss_func = mx.gluon.loss.L1Loss()
else:
logging.error("Invalid generator loss parameter")
......@@ -278,34 +279,34 @@ class ${tc.fileNameWithoutEnding}:
metric_gen = mx.metric.create(eval_metric)
<#include "gan/InputGenerator.ftl">
speed_period = 100
tic = None
for epoch in range(begin_epoch, begin_epoch + num_epoch):
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
real_data = batch.data[traindata_to_index[dis_input_names[0]]].as_in_context(mx_context)
real_data = batch.data[traindata_to_index[dis_real_input + "_"]].as_in_context(mx_context)
dis_conditional_input = create_discriminator_input(batch)
gen_input, exp_qnet_output = create_generator_input(batch)
fake_labels = mx.nd.zeros((batch_size), ctx=mx_context)
real_labels = mx.nd.ones((batch_size), ctx=mx_context)
with autograd.record():
fake_data = gen_net(*gen_input)
fake_data.detach()
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
loss_resultF = loss_function(discriminated_fake_dis, fake_labels)
fake_labels = mx.nd.zeros(discriminated_fake_dis.shape, ctx=mx_context)
real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context)
loss_resultF = dis_loss(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = loss_function(discriminated_real_dis, real_labels)
loss_resultR = dis_loss(discriminated_real_dis, real_labels)
loss_resultD = loss_resultR + loss_resultF
loss_resultD.backward()
loss_resultD = dis_loss_weight * (loss_resultR + loss_resultF)
loss_resultD.backward()
dis_trainer.step(batch_size)
if batch_i % k_value == 0:
......@@ -314,15 +315,15 @@ class ${tc.fileNameWithoutEnding}:
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = loss_function(discriminated_fake_gen, real_labels)
loss_resultG = dis_loss(discriminated_fake_gen, real_labels)
if not generator_loss == None:
condition = batch.data[traindata_to_index[conditional_input + "_"]]
loss_resultG = loss_resultG + generator_loss_func(fake_data, condition)
condition = batch.data[traindata_to_index[generator_target_name + "_"]]
loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition)
if self.use_qnet:
qnet_discriminated = [q_net(features)]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
loss_resultG.backward()
gen_trainer.step(batch_size)