Commit be8eb31c authored by Julian Dierkes's avatar Julian Dierkes

added features for gan and InfoGan

parent 3953ada5
......@@ -84,6 +84,8 @@ public class CNNArch2Gluon extends CNNArchGenerator {
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
archTc.getStreamOutputDomains(archTc.getArchitecture().getStreams().get(0));
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
......
......@@ -7,7 +7,7 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import java.io.Writer;
import java.util.*;
......@@ -117,6 +117,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream, outputAsArray).keySet();
}
public List<String> get(Map<String, List<String>> map, String name) {
    // Template helper: FreeMarker cannot index a Java Map with a variable key
    // directly, so templates call this to fetch the entry stored under `name`.
    final List<String> entry = map.get(name);
    return entry;
}
public List<String> getUnrollInputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
......@@ -134,6 +138,26 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream, false).values();
}
public Collection<List<String>> getStreamOutputDimensions(SerialCompositeElementSymbol stream) {
    // Dimension lists of every stream output, keyed lookup discarded;
    // array-style output naming is disabled (second argument false).
    Map<String, List<String>> outputs = getStreamOutputs(stream, false);
    return outputs.values();
}
public Collection<List<String>> getStreamInputInformation(SerialCompositeElementSymbol stream) {
    // For each stream input, collects one list of the form
    // [name, type, min, max, dim_0, dim_1, ...] combining the input's domain
    // (from getStreamInputDomains) with its dimensions (from getStreamInputs).
    // BUG FIX: the original built `newEntry` with only the name, never added it
    // to `information`, and then returned null — every caller would NPE.
    Map<String, List<String>> dimensions = getStreamInputs(stream, false);
    Map<String, List<String>> domains = getStreamInputDomains(stream);
    // LinkedHashSet keeps the declaration order of the inputs.
    Collection<List<String>> information = new LinkedHashSet<List<String>>();
    for (String name : dimensions.keySet()) {
        List<String> newEntry = new ArrayList<String>();
        newEntry.add(name);
        if (domains.containsKey(name)) {
            newEntry.addAll(domains.get(name));
        }
        newEntry.addAll(dimensions.get(name));
        information.add(newEntry);
    }
    return information;
}
/*public Collection<List<String>> getStreamOutputsWithTypes(SerialCompositeElementSymbol stream) {
}*/
// Returns the name of the first architecture output symbol, with any
// trailing index suffix stripped via getNameWithoutIndex.
public String getOutputName() {
return getNameWithoutIndex(getName(getArchitectureOutputSymbols().get(0)));
}
......@@ -258,7 +282,75 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
/**
 * Maps each first atomic input/output element of the stream to a 3-entry list
 * describing its domain for the generated Python code:
 * [python type name, lower bound, upper bound].
 * Infinite bounds are rewritten to Python's float('-inf') / float('inf').
 */
public Map<String, List<String>> getStreamInputDomains(SerialCompositeElementSymbol stream) {
    Map<String, List<String>> inputTypes = new LinkedHashMap<>();
    for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
        if (element.isInput() || element.isOutput()) {
            ASTElementType type = element.getOutputTypes().get(0).getDomain();
            HashMap<String, String> ranges = element.getOutputTypes().get(0).getElementRange();
            // BUG FIX: compare String contents with equals(), not ==.
            // Reference identity only matches interned literals, so the
            // infinity rewrite silently never fired for runtime strings.
            if ("-inf".equals(ranges.get("min"))) {
                ranges.put("min", "float('-inf')");
            }
            if ("inf".equals(ranges.get("max"))) {
                ranges.put("max", "float('inf')");
            }
            // Translate the model's element type to a Python type name;
            // stays "" if none of the known categories matches.
            String typeAsString = "";
            if (type.isBoolean()) {
                typeAsString = "bool";
            } else if (type.isComplex()) {
                typeAsString = "complex";
            } else if (type.isNaturalNumber() || type.isWholeNumber()) {
                typeAsString = "int";
            } else if (type.isRational()) {
                typeAsString = "float";
            }
            List<String> domain = new ArrayList<>();
            domain.add(typeAsString);
            domain.add(ranges.get("min"));
            domain.add(ranges.get("max"));
            inputTypes.put(getName(element), domain);
        }
    }
    return inputTypes;
}
/**
 * Maps each last atomic input/output element of the stream to a 3-entry list
 * describing its domain for the generated Python code:
 * [python type name, lower bound, upper bound].
 * Mirrors getStreamInputDomains but reads the element's INPUT types, since a
 * stream output's domain is the type flowing into it.
 */
public Map<String, List<String>> getStreamOutputDomains(SerialCompositeElementSymbol stream) {
    Map<String, List<String>> outputTypes = new LinkedHashMap<>();
    for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
        if (element.isInput() || element.isOutput()) {
            ASTElementType type = element.getInputTypes().get(0).getDomain();
            HashMap<String, String> ranges = element.getInputTypes().get(0).getElementRange();
            // BUG FIX: compare String contents with equals(), not ==.
            // Reference identity only matches interned literals, so the
            // infinity rewrite silently never fired for runtime strings.
            if ("-inf".equals(ranges.get("min"))) {
                ranges.put("min", "float('-inf')");
            }
            if ("inf".equals(ranges.get("max"))) {
                ranges.put("max", "float('inf')");
            }
            // Translate the model's element type to a Python type name;
            // stays "" if none of the known categories matches.
            String typeAsString = "";
            if (type.isBoolean()) {
                typeAsString = "bool";
            } else if (type.isComplex()) {
                typeAsString = "complex";
            } else if (type.isNaturalNumber() || type.isWholeNumber()) {
                typeAsString = "int";
            } else if (type.isRational()) {
                typeAsString = "float";
            }
            List<String> domain = new ArrayList<>();
            domain.add(typeAsString);
            domain.add(ranges.get("min"));
            domain.add(ranges.get("max"));
            outputTypes.put(getName(element), domain);
        }
    }
    return outputTypes;
}
public Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
......@@ -288,10 +380,42 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
inputs.putAll(getStreamLayerVariableMembers(stream, false));
return inputs;
}
/**
 * Maps each output name of the stream to its dimensions (as strings).
 * Dimensions are taken from the predecessor element's output type; constant
 * elements contribute a single dimension of "1". When outputAsArray is set,
 * IO variable outputs get the array-style name from getNameAsArray.
 */
public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
    Map<String, List<String>> outputs = new LinkedHashMap<>();
    for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
        if (element.isInput() || element.isOutput()) {
            List<String> dims = new ArrayList<>();
            for (Integer dim : element.getPrevious().get(0).getOutputTypes().get(0).getDimensions()) {
                dims.add(String.valueOf(dim));
            }
            String outputName = getName(element);
            if (outputAsArray && element.isOutput() && element instanceof VariableSymbol) {
                VariableSymbol variable = (VariableSymbol) element;
                if (variable.getType() == VariableSymbol.Type.IO) {
                    outputName = getNameAsArray(outputName);
                }
            }
            outputs.put(outputName, dims);
        } else if (element instanceof ConstantSymbol) {
            outputs.put(getName(element), Arrays.asList("1"));
        }
    }
    // Layer variable members (e.g. RNN state) are appended as additional outputs.
    outputs.putAll(getStreamLayerVariableMembers(stream, false));
    return outputs;
}
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, boolean includeOutput) {
Map<String, List<String>> members = new LinkedHashMap<>();
......
......@@ -15,6 +15,7 @@ import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
......@@ -123,26 +124,46 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
"GAN");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
NNArchitectureSymbol genericDisArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol disArchitectureSymbol
= ((ArchitectureAdapter)genericDisArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), GAN_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
final String creatorName = architectureFileContentMap.keySet().iterator().next();
final String discriminatorInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
Map<String, String> disArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(disArchitectureSymbol, true);
final String disCreatorName = disArchitectureFileContentMap.keySet().iterator().next();
final String discriminatorInstanceName = disCreatorName.substring(
disCreatorName.indexOf('_') + 1, disCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
fileContentMap.putAll(disArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
if (configuration.hasQNetwork()) {
NNArchitectureSymbol genericQArchitectureSymbol = configuration.getQNetwork().get();
ArchitectureSymbol qArchitectureSymbol
= ((ArchitectureAdapter)genericQArchitectureSymbol).getArchitectureSymbol();
Map<String, String> qArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(qArchitectureSymbol, true);
final String qCreatorName = qArchitectureFileContentMap.keySet().iterator().next();
final String qNetworkInstanceName = qCreatorName.substring(
qCreatorName.indexOf('_') + 1, qCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(qArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("qNetworkInstanceName", qNetworkInstanceName);
}
ftlContext.put("ganFrameworkModule", GAN_LEARNING_FRAMEWORK_MODULE);
ftlContext.put("discriminatorInstanceName", discriminatorInstanceName);
ftlContext.put("trainerName", trainerName);
......@@ -153,9 +174,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
//final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "gan/StartTrainer.ftl");
//fileContentMap.put("start_training.sh", startTrainerScriptContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
......@@ -55,7 +55,7 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context)<#sep>, </#list>)
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((1, ${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#list>
if not os.path.exists(self._model_dir_):
......
......@@ -6,6 +6,7 @@ import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
class ${tc.fileNameWithoutEnding}:
......@@ -110,6 +111,102 @@ class ${tc.fileNameWithoutEnding}:
return train_iter, test_iter, data_mean, data_std
def load_preprocessed_data(self, batch_size, preproc_lib):
    """Load train/test HDF5 data, run every sample through the generated
    preprocessing wrapper, and return NDArrayIters plus per-input mean/std.

    batch_size:  batch size of the returned mx.io.NDArrayIter objects
    preproc_lib: module name of the generated preprocessing wrapper
                 (imported dynamically via importlib)
    """
    train_h5, test_h5 = self.load_h5_files()

    # Instantiate the generated wrapper and its input struct.
    wrapper = importlib.import_module(preproc_lib)
    instance = getattr(wrapper, preproc_lib)()
    instance.init()
    lib_head, _sep, _tail = preproc_lib.rpartition('_')
    inp = getattr(wrapper, lib_head + "_input")()

    train_data = {}
    train_label = {}
    data_mean = {}
    data_std = {}

    # Probe one sample to determine the preprocessed output shapes.
    shape_output = self.preprocess_data(instance, inp, 0, train_h5)
    train_len = len(train_h5[self._input_names_[0]])

    for input_name in self._input_names_:
        if type(getattr(shape_output, input_name + "_out")) == np.ndarray:
            cur_shape = (train_len,) + getattr(shape_output, input_name + "_out").shape
        else:
            cur_shape = (train_len, 1)
        train_data[input_name] = mx.nd.zeros(cur_shape)
    for output_name in self._output_names_:
        # BUG FIX: compare against np.ndarray. The original compared against
        # nd.array, which is a constructor function, so the check was never
        # true and ndarray-valued labels got a wrong (train_len, 1) buffer.
        if type(getattr(shape_output, output_name + "_out")) == np.ndarray:
            cur_shape = (train_len,) + getattr(shape_output, output_name + "_out").shape
        else:
            cur_shape = (train_len, 1)
        train_label[output_name] = mx.nd.zeros(cur_shape)

    for i in range(train_len):
        output = self.preprocess_data(instance, inp, i, train_h5)
        for input_name in self._input_names_:
            train_data[input_name][i] = getattr(output, input_name + "_out")
        for output_name in self._output_names_:
            # BUG FIX: read the current sample's output; the original read
            # `shape_output`, duplicating sample 0's label for all rows.
            train_label[output_name][i] = getattr(output, output_name + "_out")

    for input_name in self._input_names_:
        data_mean[input_name + '_'] = nd.array(train_data[input_name][:].mean(axis=0))
        # + 1e-5 avoids division by zero for constant features downstream.
        data_std[input_name + '_'] = nd.array(train_data[input_name][:].asnumpy().std(axis=0) + 1e-5)

    train_iter = mx.io.NDArrayIter(data=train_data,
                                   label=train_label,
                                   batch_size=batch_size)

    test_data = {}
    test_label = {}

    shape_output = self.preprocess_data(instance, inp, 0, test_h5)
    test_len = len(test_h5[self._input_names_[0]])

    for input_name in self._input_names_:
        if type(getattr(shape_output, input_name + "_out")) == np.ndarray:
            cur_shape = (test_len,) + getattr(shape_output, input_name + "_out").shape
        else:
            cur_shape = (test_len, 1)
        test_data[input_name] = mx.nd.zeros(cur_shape)
    for output_name in self._output_names_:
        # BUG FIX: np.ndarray comparison (see train section above).
        if type(getattr(shape_output, output_name + "_out")) == np.ndarray:
            cur_shape = (test_len,) + getattr(shape_output, output_name + "_out").shape
        else:
            cur_shape = (test_len, 1)
        test_label[output_name] = mx.nd.zeros(cur_shape)

    for i in range(test_len):
        output = self.preprocess_data(instance, inp, i, test_h5)
        for input_name in self._input_names_:
            test_data[input_name][i] = getattr(output, input_name + "_out")
        for output_name in self._output_names_:
            # BUG FIX: current sample's output, not the shape probe.
            test_label[output_name][i] = getattr(output, output_name + "_out")

    test_iter = mx.io.NDArrayIter(data=test_data,
                                  label=test_label,
                                  batch_size=batch_size)

    return train_iter, test_iter, data_mean, data_std
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
    """Copy sample `index` from data_h5 into the wrapper's input struct and
    run the generated preprocessing on it.

    instance_wrapper: generated preprocessing executor (has .execute(inp))
    input_wrapper:    struct whose attributes are the input/output fields
    index:            sample index within data_h5
    data_h5:          mapping of field name -> per-sample data
    Returns the executor's output struct.

    BUG FIX: the original indexed data_h5[...][0], ignoring `index` and
    preprocessing sample 0 for every call.
    """
    for input_name in self._input_names_:
        data = data_h5[input_name][index]
        attr = getattr(input_wrapper, input_name)
        if type(data) == np.ndarray:
            # The wrapper expects Fortran-ordered arrays of the declared dtype.
            data = np.asfortranarray(data).astype(attr.dtype)
        else:
            # Scalars are coerced to the type of the struct's current value.
            data = type(attr)(data)
        setattr(input_wrapper, input_name, data)
    for output_name in self._output_names_:
        data = data_h5[output_name][index]
        attr = getattr(input_wrapper, output_name)
        if type(data) == np.ndarray:
            data = np.asfortranarray(data).astype(attr.dtype)
        else:
            data = type(attr)(data)
        setattr(input_wrapper, output_name, data)
    return instance_wrapper.execute(input_wrapper)
def load_h5_files(self):
train_h5 = None
test_h5 = None
......
......@@ -39,7 +39,7 @@ def getDataIter(ctx, batch_size=64, Z=100):
X = np.asarray([cv2.resize(x, (64,64)) for x in X])
X = X.astype(np.float32, copy=False)/(255.0/2) - 1.0
X = X.reshape((img_number, 1, 64, 64))
X = np.tile(X, (1, 3, 1, 1))
X = np.tile(X, (1, 1, 1, 1))
data = mx.nd.array(X)
for i in range(4):
......@@ -53,11 +53,17 @@ def getDataIter(ctx, batch_size=64, Z=100):
class ${tc.fileNameWithoutEnding}:
def __init__(self, data_loader, net_constructor_gen, net_constructor_dis):
def __init__(self, data_loader, net_constructor_gen, net_constructor_dis, net_constructor_qnet = None):
self._data_loader = data_loader
self._net_creator_gen = net_constructor_gen
self._net_creator_dis = net_constructor_dis
if net_constructor_qnet == None:
self.use_qnet = False
else:
self._net_creator_qnet = net_constructor_qnet
self.use_qnet = True
def train(self, batch_size=64,
num_epoch=10,
eval_metric='acc',
......@@ -71,7 +77,8 @@ class ${tc.fileNameWithoutEnding}:
normalize=True,
img_resize=(64,64),
noise_distribution='gaussian',
noise_distribution_params=(('mean_value', 0),('spread_value', 1),)):
noise_distribution_params=(('mean_value', 0),('spread_value', 1),),
preprocessing = False):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -80,8 +87,34 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
#train_iter = getDataIter(mx_context, batch_size, 100)
train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size, img_resize)
if self.use_qnet:
self._net_creator_qnet.construct(mx_context)
if load_checkpoint:
self._net_creator_qnet.load(mx_context)
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
try:
os.makedirs(self._net_creator_qnet._model_dir_)
except OSError:
if not (os.path.isdir(self._net_creator_qnet._model_dir_)):
raise
q_net = self._net_creator_qnet.networks[0]
qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), optimizer, optimizer_params)
g_input = self._data_loader._input_names_
q_input = [name[:-1] for name in q_net.getOutputs()]
new_inputs = [name for name in g_input if (name not in q_input)]
self._data_loader._input_names_ = new_inputs
if preprocessing:
preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNGanTrainer_")}_executor"
train_iter = getDataIter(mx_context, batch_size, 100)
# if preprocessing:
# train_iter, test_iter, data_mean, data_std = self._data_loader.load_preprocessed_data(batch_size, preproc_lib)
# else:
# train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size, img_resize)
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -108,7 +141,7 @@ class ${tc.fileNameWithoutEnding}:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator_dis.load(mx_context)
begin_epoch = self._net_creator_gen.load(mx_context)
self._net_creator_gen.load(mx_context)
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
......@@ -133,54 +166,53 @@ class ${tc.fileNameWithoutEnding}:
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
activation_name = 'sigmoid'
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
input_shape = <#list tc.getStreamInputDimensions(stream) as dimensions>${tc.join(dimensions, ",")}</#list>
</#if>
</#list>
shape_list = list(input_shape)
shape_list[0] = batch_size
input_shape = tuple(shape_list)
metric_dis = mx.metric.create(eval_metric)
metric_gen = mx.metric.create(eval_metric)
<#include "gan/InputGenerator.ftl">
if noise_distribution == "gaussian":
random_distributor = lambda : mx.ndarray.random.normal(noise_distribution_params["mean_value"],
noise_distribution_params["spread_value"],
shape=input_shape, ctx=mx_context)
speed_period = 5
speed_period = 100
tic = None
for epoch in range(begin_epoch, begin_epoch + num_epoch):
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
real_data = batch.data[0].as_in_context(mx_context)
rbatch = random_distributor()
gen_input, exp_qnet_output = create_generator_input()
fake_labels = mx.nd.zeros((batch_size), ctx=mx_context)
real_labels = mx.nd.ones((batch_size), ctx=mx_context)
with autograd.record():
fake_data = gen_net(rbatch)
fake_data = gen_net(*gen_input)
fake_data.detach()
discriminated_fake_dis = dis_net(fake_data)
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
loss_resultF = loss_function(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data)
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = loss_function(discriminated_real_dis, real_labels)
loss_resultD = loss_resultR + loss_resultF
loss_resultD.backward()
dis_trainer.step(batch_size)
with autograd.record():
fake_data = gen_net(rbatch)
fake_data = gen_net(*gen_input)
discriminated_fake_gen = dis_net(fake_data)
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = loss_function(discriminated_fake_gen, real_labels)
if self.use_qnet:
qnet_discriminated = list(q_net(features))
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
gen_trainer.step(batch_size)
if self.use_qnet:
qnet_trainer.step(batch_size)
if tic is None:
tic = time.time()
......@@ -210,9 +242,9 @@ class ${tc.fileNameWithoutEnding}:
tic = time.time()
# ugly start
#if batch_i % 20 == 0:
#if batch_i % 200 == 0:
# fake_data[0].asnumpy()
if batch_i % 50 == 0:
if batch_i % 500 == 0:
#gen_net.save_parameters(self.parameter_path_gen() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
#gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
#dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
......@@ -239,4 +271,3 @@ class ${tc.fileNameWithoutEnding}:
def parameter_path_dis(self):
return self._net_creator_dis._model_dir_ + self._net_creator_dis._model_prefix_ + '_' + str(0)
<#-- (c) https://github.com/MontiCore/monticore -->
import mxnet as mx
import numpy as np
import math
from mxnet import gluon
......@@ -102,5 +103,30 @@ ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
<#else>
return ${tc.join(tc.getStreamOutputNames(networkInstruction.body, false), ", ")}
</#if>
</#list>
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")})
input_domains = (${tc.join(domains[name], ",")})
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")})
output_domains = (${tc.join(domains[name], ",")})
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
gen_inputs = gen_net.getInputs()
qnet_outputs = []
if self.use_qnet:
qnet_outputs = q_net.getOutputs()
qnet_losses = []
generators = {}
if self.use_qnet:
for name in qnet_outputs:
domain = gen_inputs[name]
min = domain[1]
max = domain[2]
if domain[0] == float:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.uniform(min,max,
shape=(batch_size,)+domain[3],
dtype=domain[0], ctx=mx_context,), dtype="float32")
qnet_losses += [mx.gluon.loss.L2Loss()]
elif domain[0] == int:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.randint(low=int(min),
high=int(max)+1, shape=(batch_size,)+domain[3],
ctx=mx_context), dtype="float32")
qnet_losses += [lambda pred, labels: mx.gluon.loss.SoftmaxCrossEntropyLoss()(pred, labels.reshape(batch_size))]
for name in gen_inputs:
if not name in qnet_outputs:
domain = gen_inputs[name]
min = domain[1]
max = domain[2]
if noise_distribution == "gaussian":
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.normal(noise_distribution_params["mean_value"],
noise_distribution_params["spread_value"],
shape=(batch_size,)+domain[3], dtype=domain[0],
ctx=mx_context), dtype="float32")
def create_generator_input():
expected_output_qnet = []
input_to_gen = []
for name in gen_inputs:
if not name in qnet_outputs:
input_to_gen += [generators[name]()]
for name in qnet_outputs:
expected_output_qnet += [generators[name]()]
input_to_gen += [generators[name]()]
return input_to_gen, expected_output_qnet
......@@ -14,6 +14,9 @@ import CNNDataLoader_${config.instanceName}
import CNNGanTrainer_${config.instanceName}
from ${ganFrameworkModule}.CNNCreator_${discriminatorInstanceName} import CNNCreator_${discriminatorInstanceName}
<#if (qNetworkInstanceName)??>
from ${ganFrameworkModule}.CNNCreator_${qNetworkInstanceName} import CNNCreator_${qNetworkInstanceName}
</#if>
if __name__ == "__main__":
......@@ -26,11 +29,17 @@ if __name__ == "__main__":
gen_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
dis_creator = CNNCreator_${discriminatorInstanceName}()
<#if (qNetworkInstanceName)??>
qnet_creator = CNNCreator_${qNetworkInstanceName}()
</#if>
${config.instanceName}_trainer = CNNGanTrainer_${config.instanceName}.CNNGanTrainer_${config.instanceName}(
data_loader,
gen_creator,
dis_creator
dis_creator,
<#if (qNetworkInstanceName)??>
qnet_creator
</#if>
)
${config.instanceName}_trainer.train(
......@@ -49,6 +58,9 @@ if __name__ == "__main__":
<#if (config.normalize)??>
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.preprocessingName)??>
preprocessing=${config.preprocessingName???string("True","False")},
</#if>
<#if (config.imgResizeWidth)??>
<#if (config.imgResizeHeight)??>
img_resize=(${config.imgResizeWidth}, ${config.imgResizeHeight}),
......
<#setting number_format="computer">
<#assign config = configurations[0]>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
from ${ganFrameworkModule}.CNNCreator_${discriminatorInstanceName} import CNNCreator_${discriminatorInstanceName}
import CNNCreator_${config.instanceName}
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd
<#-- Entry point of the generated trainer script. -->
<#-- BUG FIX: `if __name__ = "__main__"` was a SyntaxError in the generated
     Python; equality comparison requires `==`. -->
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
    logger.addHandler(handler)

<#if (config.context)??>
    context = mx.${config.context}()
<#else>
    context = mx.cpu()
</#if>

    generator_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    generator_creator.construct(context)

    discriminator_creator = CNNCreator_${discriminatorInstanceName}()
    discriminator_creator.construct(context)
<#if (config.batchSize)??>
batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>