Commit 2dec1bac authored by Evgeny Kusmenko

Merge branch 'develop' into 'master'

Added description on how to export internal network layers, e.g. attention

See merge request !26
parents ae909ef1 86b9d59f
<!-- (c) https://github.com/MontiCore/monticore -->
![pipeline](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/CNNArch2MXNet/badges/master/build.svg)
![coverage](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/CNNArch2MXNet/badges/master/coverage.svg)
## How to export inner network layers, e.g. an attention matrix
In order to visualize attention from an attention network, the data from the attention layer has to be returned by the network. Two steps are necessary for that:

1. The layer that should be exported has to be defined as a `VariableSymbol`. To do this, prefix the respective layer with the keyword `layer`. The layer has to be defined before it is actually used in the network.
   For example, one could define an attention layer as `layer FullyConnected(units = 1, flatten=false) attention;` at the beginning of a network. By routing data through this layer, e.g.

   ```
   input ->
   ... ->
   attention ->
   ...
   ```

   the data in this layer is kept until the end of the network iteration.
2. In order to make the network return the saved data, the network's name must be added to the `AllAttentionModels` class in this project. Furthermore, the layer must either be named `attention`, or the CNNNet.ftl template has to be adjusted to return layers with other names; a sketch of how the exported data can then be consumed follows below.
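Once both steps are done, the exported data can be read from the network's return values at prediction time. A minimal sketch in Python, assuming a generated Gluon network class `Net_0` whose forward pass returns the attention data alongside the regular output (module, class, and file names are placeholders, not the exact generated names):

```python
import mxnet as mx

from CNNNet_myModel import Net_0  # placeholder module/class names

net = Net_0(data_mean=None, data_std=None)
net.load_parameters("model/myModel/newest-0000.params", ctx=mx.cpu())

data = mx.nd.random.uniform(shape=(1, 3, 224, 224))
# For networks registered in AllAttentionModels, the generated forward
# pass also returns the content of the exported `attention` layer.
output, attention = net(data)
print(attention.shape)
```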
@@ -20,7 +20,7 @@
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
......
@@ -84,6 +84,8 @@ public class CNNArch2Gluon extends CNNArchGenerator {
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
archTc.getStreamOutputDomains(archTc.getArchitecture().getStreams().get(0));
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
......
@@ -7,7 +7,7 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import java.io.Writer;
import java.util.*;
@@ -117,6 +117,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream, outputAsArray).keySet();
}
public ArrayList<String> getStreamInputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
ArrayList<String> inputVariableNames = new ArrayList<String>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
if (element.isInput()) {
inputVariableNames.add(getName(element));
}
}
return inputVariableNames;
}
public List<String> get(Map<String, List<String>> map, String name) {
return map.get(name);
}
public List<String> getUnrollInputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
@@ -134,6 +148,36 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream, false).values();
}
public Collection<List<String>> getStreamOutputDimensions(SerialCompositeElementSymbol stream) {
return getStreamOutputs(stream, false).values();
}
public ArrayList<String> getStreamOutputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
ArrayList<String> outputVariableNames = new ArrayList<String>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
if (element.isOutput()) {
outputVariableNames.add(getName(element));
}
}
return outputVariableNames;
}
public Collection<List<String>> getStreamInputInformation(SerialCompositeElementSymbol stream) {
    Map<String, List<String>> dimensions = getStreamInputs(stream, false);
    Map<String, List<String>> domains = getStreamInputDomains(stream);
    Collection<List<String>> information = new ArrayList<List<String>>();
    for (String name : dimensions.keySet()) {
        // combine name, domain, and dimensions into one entry per input
        List<String> newEntry = new ArrayList<String>();
        newEntry.add(name);
        newEntry.addAll(domains.get(name));
        newEntry.addAll(dimensions.get(name));
        information.add(newEntry);
    }
    return information;
}
/*public Collection<List<String>> getStreamOutputsWithTypes(SerialCompositeElementSymbol stream) {
}*/
public String getOutputName() {
return getNameWithoutIndex(getName(getArchitectureOutputSymbols().get(0)));
}
@@ -258,7 +302,75 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
public Map<String, List<String>> getStreamInputDomains(SerialCompositeElementSymbol stream) {
Map<String, List<String>> inputTypes = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
if (element.isInput() || element.isOutput()) {
ASTElementType type = element.getOutputTypes().get(0).getDomain();
HashMap<String,String> ranges = element.getOutputTypes().get(0).getElementRange();
if (ranges.get("min") == "-inf")
ranges.put("min", "float('-inf')");
if (ranges.get("max") == "inf")
ranges.put("max", "float('inf')");
String typeAsString = new String();
if(type.isBoolean())
typeAsString = "bool";
else if (type.isComplex())
typeAsString = "complex";
else if (type.isNaturalNumber() || type.isWholeNumber())
typeAsString = "int";
else if (type.isRational())
typeAsString = "float";
String name = getName(element);
ArrayList<String> domain = new ArrayList<String>();
domain.add(typeAsString);
domain.add(ranges.get("min"));
domain.add(ranges.get("max"));
inputTypes.put(name, domain);
}
}
return inputTypes;
}
public Map<String, List<String>> getStreamOutputDomains(SerialCompositeElementSymbol stream) {
Map<String, List<String>> outputTypes = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
if (element.isInput() || element.isOutput()) {
ASTElementType type = element.getInputTypes().get(0).getDomain();
HashMap<String,String> ranges = element.getInputTypes().get(0).getElementRange();
if (ranges.get("min") == "-inf")
ranges.put("min", "float('-inf')");
if (ranges.get("max") == "inf")
ranges.put("max", "float('inf')");
String typeAsString = new String();
if(type.isBoolean())
typeAsString = "bool";
else if (type.isComplex())
typeAsString = "complex";
else if (type.isNaturalNumber() || type.isWholeNumber())
typeAsString = "int";
else if (type.isRational())
typeAsString = "float";
String name = getName(element);
ArrayList<String> domain = new ArrayList<String>();
domain.add(typeAsString);
domain.add(ranges.get("min"));
domain.add(ranges.get("max"));
outputTypes.put(name, domain);
}
}
return outputTypes;
}
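Concretely, each entry these two methods produce maps a variable name to a `[type, min, max]` triple of strings that are later emitted verbatim into the generated Python code. A small illustration of the resulting structure (variable names and ranges here are assumed, not taken from a real model):

```python
# Hypothetical domain maps as they would appear on the Python side:
input_domains = {"data_": ["float", "float('-inf')", "float('inf')"]}
output_domains = {"softmax_": ["float", "0", "1"]}
type_name, lower, upper = input_domains["data_"]
```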
public Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
@@ -288,10 +400,42 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
inputs.putAll(getStreamLayerVariableMembers(stream, false));
return inputs;
}
public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> outputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
if (element.isInput() || element.isOutput()) {
List<Integer> intDimensions = element.getPrevious().get(0).getOutputTypes().get(0).getDimensions();
List<String> dimensions = new ArrayList<>();
for (Integer intDimension : intDimensions) {
dimensions.add(intDimension.toString());
}
String name = getName(element);
if (outputAsArray && element.isOutput() && element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.IO) {
name = getNameAsArray(name);
}
}
outputs.put(name, dimensions);
}
else if (element instanceof ConstantSymbol) {
outputs.put(getName(element), Arrays.asList("1"));
}
}
outputs.putAll(getStreamLayerVariableMembers(stream, false));
return outputs;
}
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, boolean includeOutput) {
Map<String, List<String>> members = new LinkedHashMap<>();
......
@@ -15,6 +15,7 @@ import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
@@ -123,26 +124,46 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
"GAN");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
NNArchitectureSymbol genericDisArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol disArchitectureSymbol
= ((ArchitectureAdapter)genericDisArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), GAN_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
final String creatorName = architectureFileContentMap.keySet().iterator().next();
final String discriminatorInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
Map<String, String> disArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(disArchitectureSymbol, true);
final String disCreatorName = disArchitectureFileContentMap.keySet().iterator().next();
final String discriminatorInstanceName = disCreatorName.substring(
disCreatorName.indexOf('_') + 1, disCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
fileContentMap.putAll(disArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
if (configuration.hasQNetwork()) {
NNArchitectureSymbol genericQArchitectureSymbol = configuration.getQNetwork().get();
ArchitectureSymbol qArchitectureSymbol
= ((ArchitectureAdapter)genericQArchitectureSymbol).getArchitectureSymbol();
Map<String, String> qArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(qArchitectureSymbol, true);
final String qCreatorName = qArchitectureFileContentMap.keySet().iterator().next();
final String qNetworkInstanceName = qCreatorName.substring(
qCreatorName.indexOf('_') + 1, qCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(qArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("qNetworkInstanceName", qNetworkInstanceName);
}
ftlContext.put("ganFrameworkModule", GAN_LEARNING_FRAMEWORK_MODULE);
ftlContext.put("discriminatorInstanceName", discriminatorInstanceName);
ftlContext.put("trainerName", trainerName);
@@ -153,9 +174,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
//final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "gan/StartTrainer.ftl");
//fileContentMap.put("start_training.sh", startTrainerScriptContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
@@ -169,19 +169,14 @@ public class GluonConfigurationData extends ConfigurationData {
return getMultiParamEntry(NOISE_DISTRIBUTION, "name");
}
public String getImgResizeWidth() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_width")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_width").getValue());
public Map<String, Map<String, Object>> getConstraintDistributions() {
return getMultiParamMapEntry(CONSTRAINT_DISTRIBUTION, "name");
}
public String getImgResizeHeight() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_height")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_height").getValue());
public Map<String, Map<String, Object>> getConstraintLosses() {
return getMultiParamMapEntry(CONSTRAINT_LOSS, "name");
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter is only available for reinforcement learning " +
"but was called in a non-reinforcement-learning context";
......
@@ -6,6 +6,7 @@ import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
class ${tc.fileNameWithoutEnding}:
@@ -15,7 +16,7 @@ class ${tc.fileNameWithoutEnding}:
def __init__(self):
self._data_dir = "${tc.dataPath}/"
def load_data(self, train_batch_size, test_batch_size):
def load_data(self, batch_size, shuffle=False):
train_h5, test_h5 = self.load_h5_files()
train_data = {}
@@ -39,11 +40,8 @@ class ${tc.fileNameWithoutEnding}:
train_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=train_batch_size)
train_test_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=test_batch_size)
batch_size=batch_size,
shuffle=shuffle)
test_iter = None
@@ -64,51 +62,112 @@ class ${tc.fileNameWithoutEnding}:
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=test_batch_size)
batch_size=batch_size)
return train_iter, train_test_iter, test_iter, data_mean, data_std, train_images, test_images
return train_iter, test_iter, data_mean, data_std, train_images, test_images
def load_data_img(self, batch_size, img_size):
def load_preprocessed_data(self, batch_size, preproc_lib, shuffle=False):
train_h5, test_h5 = self.load_h5_files()
width = img_size[0]
height = img_size[1]
comb_data = {}
wrapper = importlib.import_module(preproc_lib)
instance = getattr(wrapper, preproc_lib)()
instance.init()
lib_head, _sep, tail = preproc_lib.rpartition('_')
inp = getattr(wrapper, lib_head + "_input")()
train_data = {}
train_label = {}
data_mean = {}
data_std = {}
shape_output = self.preprocess_data(instance, inp, 0, train_h5)
train_len = len(train_h5[self._input_names_[0]])
for input_name in self._input_names_:
train_data = train_h5[input_name][:]
test_data = test_h5[input_name][:]
if type(getattr(shape_output, input_name + "_out")) == np.ndarray:
cur_shape = (train_len,) + getattr(shape_output, input_name + "_out").shape
else:
cur_shape = (train_len, 1)
train_data[input_name] = mx.nd.zeros(cur_shape)
for output_name in self._output_names_:
if type(getattr(shape_output, output_name + "_out")) == np.ndarray:
cur_shape = (train_len,) + getattr(shape_output, output_name + "_out").shape
else:
cur_shape = (train_len, 1)
train_label[output_name] = mx.nd.zeros(cur_shape)
train_shape = train_data.shape
test_shape = test_data.shape
for i in range(train_len):
output = self.preprocess_data(instance, inp, i, train_h5)
for input_name in self._input_names_:
train_data[input_name][i] = getattr(output, input_name + "_out")
for output_name in self._output_names_:
train_label[output_name][i] = getattr(output, output_name + "_out")
comb_data[input_name] = mx.nd.zeros((train_shape[0]+test_shape[0], train_shape[1], width, height))
for i, img in enumerate(train_data):
img = img.transpose(1,2,0)
comb_data[input_name][i] = cv2.resize(img, (width, height)).reshape((train_shape[1],width,height))
for i, img in enumerate(test_data):
img = img.transpose(1, 2, 0)
comb_data[input_name][i+train_shape[0]] = cv2.resize(img, (width, height)).reshape((train_shape[1], width, height))
for input_name in self._input_names_:
data_mean[input_name + '_'] = nd.array(train_data[input_name][:].mean(axis=0))
data_std[input_name + '_'] = nd.array(train_data[input_name][:].asnumpy().std(axis=0) + 1e-5)
data_mean[input_name + '_'] = nd.array(comb_data[input_name][:].mean(axis=0))
data_std[input_name + '_'] = nd.array(comb_data[input_name][:].asnumpy().std(axis=0) + 1e-5)
if 'images' in train_h5:
train_images = train_h5['images']
train_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=batch_size,
shuffle=shuffle)
comb_label = {}
test_data = {}
test_label = {}
shape_output = self.preprocess_data(instance, inp, 0, test_h5)
test_len = len(test_h5[self._input_names_[0]])
for input_name in self._input_names_:
if type(getattr(shape_output, input_name + "_out")) == np.ndarray:
cur_shape = (test_len,) + getattr(shape_output, input_name + "_out").shape
else:
cur_shape = (test_len, 1)
test_data[input_name] = mx.nd.zeros(cur_shape)
for output_name in self._output_names_:
train_labels = train_h5[output_name][:]
test_labels = test_h5[output_name][:]
comb_label[output_name] = np.append(train_labels, test_labels, axis=0)
if type(getattr(shape_output, output_name + "_out")) == np.ndarray:
cur_shape = (test_len,) + getattr(shape_output, output_name + "_out").shape
else:
cur_shape = (test_len, 1)
test_label[output_name] = mx.nd.zeros(cur_shape)
for i in range(test_len):
output = self.preprocess_data(instance, inp, i, test_h5)
for input_name in self._input_names_:
test_data[input_name][i] = getattr(output, input_name + "_out")
for output_name in self._output_names_:
test_label[output_name][i] = getattr(output, output_name + "_out")
train_iter = mx.io.NDArrayIter(data=comb_data,
label=comb_label,
if 'images' in test_h5:
test_images = test_h5['images']
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=batch_size)
test_iter = None
return train_iter, test_iter, data_mean, data_std, train_images, test_images
return train_iter, test_iter, data_mean, data_std
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
    # Copy the sample at `index` (not just the first sample) into the
    # wrapper object, converting numpy data to the expected dtype.
    for input_name in self._input_names_:
        data = data_h5[input_name][index]
        attr = getattr(input_wrapper, input_name)
        if type(data) == np.ndarray:
            data = np.asfortranarray(data).astype(attr.dtype)
        else:
            data = type(attr)(data)
        setattr(input_wrapper, input_name, data)
    for output_name in self._output_names_:
        data = data_h5[output_name][index]
        attr = getattr(input_wrapper, output_name)
        if type(data) == np.ndarray:
            data = np.asfortranarray(data).astype(attr.dtype)
        else:
            data = type(attr)(data)
        setattr(input_wrapper, output_name, data)
    return instance_wrapper.execute(input_wrapper)
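For orientation, a sketch of how the preprocessing path might be driven; the loader class name, the executor module name, and the six-value return mirror the generated trainer further below, but are assumptions here:

```python
# Hypothetical usage; CNNDataLoader_myModel and the executor module name
# are placeholders following the generator's naming scheme.
loader = CNNDataLoader_myModel()
train_iter, test_iter, data_mean, data_std, train_images, test_images = \
    loader.load_preprocessed_data(batch_size=64,
                                  preproc_lib="CNNPreprocessor_myModel_executor",
                                  shuffle=True)
```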
def load_h5_files(self):
train_h5 = None
......
<#-- (c) https://github.com/MontiCore/monticore -->
import mxnet as mx
import numpy as np
import math
from mxnet import gluon
@@ -52,10 +53,10 @@ class Reshape(gluon.HybridBlock):
class CustomRNN(gluon.HybridBlock):
def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
super(CustomRNN, self).__init__(**kwargs)
with self.name_scope():
self.rnn = gluon.rnn.RNN(hidden_size=hidden_size, num_layers=num_layers,
self.rnn = gluon.rnn.RNN(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
bidirectional=bidirectional, activation='tanh', layout='NTC')
def hybrid_forward(self, F, data, state0):
@@ -64,10 +65,10 @@ class CustomRNN(gluon.HybridBlock):
class CustomLSTM(gluon.HybridBlock):
def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
super(CustomLSTM, self).__init__(**kwargs)
with self.name_scope():
self.lstm = gluon.rnn.LSTM(hidden_size=hidden_size, num_layers=num_layers,
self.lstm = gluon.rnn.LSTM(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
bidirectional=bidirectional, layout='NTC')
def hybrid_forward(self, F, data, state0, state1):
@@ -76,10 +77,10 @@ class CustomLSTM(gluon.HybridBlock):
class CustomGRU(gluon.HybridBlock):
def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
super(CustomGRU, self).__init__(**kwargs)
with self.name_scope():
self.gru = gluon.rnn.GRU(hidden_size=hidden_size, num_layers=num_layers,
self.gru = gluon.rnn.GRU(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
bidirectional=bidirectional, layout='NTC')
def hybrid_forward(self, F, data, state0):
@@ -102,5 +103,30 @@ ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
<#else>
return ${tc.join(tc.getStreamOutputNames(networkInstruction.body, false), ", ")}
</#if>
</#list>
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")})
input_domains = (${tc.join(domains[name], ",")})
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")})
output_domains = (${tc.join(domains[name], ",")})
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
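Assuming a single image input and a ten-class softmax output, the generated accessors would then return dictionaries like the following (class name, variable names, and all numbers are illustrative):

```python
net = Net_0()  # placeholder for the generated network class
print(net.getInputs())
# e.g. {'data_': (float, float('-inf'), float('inf'), (3, 224, 224))}
print(net.getOutputs())
# e.g. {'softmax_': (float, 0, 1, (10,))}
```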
@@ -181,16 +181,21 @@ class ${tc.fileNameWithoutEnding}:
num_epoch=10,
eval_metric='acc',
eval_metric_params={},
eval_train=False,
loss ='softmax_cross_entropy',
loss_params={},
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context='gpu',
checkpoint_period=5,
log_period=50,
context='gpu',
save_attention_image=False,
use_teacher_forcing=False,
normalize=True):
normalize=True,
shuffle_data=False,
clip_global_grad_norm=None,
preprocessing = False):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
@@ -198,6 +203,12 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if preprocessing:
preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNSupervisedTrainer_")}_executor"
train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
else:
train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
del optimizer_params['weight_decay']
@@ -213,11 +224,6 @@ class ${tc.fileNameWithoutEnding}:
del optimizer_params['step_size']
del optimizer_params['learning_rate_decay']
train_batch_size = batch_size
test_batch_size = batch_size
train_iter, train_test_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(train_batch_size, test_batch_size)
if normalize:
self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
else:
@@ -275,11 +281,20 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Invalid loss parameter.")