Commit 1296641d authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by Thomas Michael Timmermanns

Implemented generation of CNNCreator.

parent 7a403446
......@@ -89,7 +89,7 @@ public class CNNArchGenerator {
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("Network", Target.PYTHON);
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("execute", Target.CPP);
......@@ -98,9 +98,32 @@ public class CNNArchGenerator {
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the generator."
, architecture.getSourcePosition());
}
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
......
......@@ -97,11 +97,11 @@ public class CNNArchTemplateController {
}
public String getArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName();
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName().replaceAll("\\.","_");
}
public String getFullArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName();
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_");
}
public List<String> getCurrentInputs(){
......
import mxnet as mx
import logging
import os
import shutil
import h5py
import sys
class ${tc.fileNameWithoutEnding}:
<#list tc.architectureInputs as input>
${input} = None
</#list>
<#list tc.architectureOutputs as output>
${output} = None
</#list>
outputGroup_ = None
module_ = None
begin_epoch_ = 0
train_iter_ = None
test_iter_ = None
context_ = None
checkpoint_period_ = 1
_data_dir_ = "data/${tc.fullArchitectureName}/"
_model_dir_ = "model/${tc.fullArchitectureName}/"
_model_prefix_ = "${tc.architectureName}"
_input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
def __init__(self, context=mx.gpu()):
self.context_ = context
self.construct()
self.outputGroup_ = mx.symbol.Group([${tc.join(tc.architectureOutputs, ",", "self.", "")}])
self.module_ = mx.mod.Module(symbol=self.outputGroup_,
data_names=self._input_names_,
label_names=self._output_names_,
context=self.context_)
def load(self):
lastEpoch = 0
param_file = None
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
if param_file != None:
logging.info("Loading checkpoint: " + param_file)
self.begin_epoch_ = lastEpoch
self.module_.load(prefix=self._model_dir_ + self._model_prefix_,
epoch=lastEpoch,
data_names=self._input_names_,
label_names=self._output_names_,
context=self.context_)
def getH5ArrayIter(self, batch_size):
train_path = self._data_dir_ + "train.h5"
test_path = self._data_dir_ + "test.h5"
if os.path.isfile(train_path):
train_file = h5py.File(train_path, 'r')
if self._input_names_[0] in train_file and self._output_names_[0] in train_file:
train_iter = mx.io.NDArrayIter(train_file[self._input_names_[0]],
train_file[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
else:
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
test_iter = None
if os.path.isfile(test_path):
test_file = h5py.File(test_path, 'r')
if self._input_names_[0] in test_file and self._output_names_[0] in test_file:
test_iter = mx.io.NDArrayIter(test_file[self._input_names_[0]],
test_file[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
else:
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
else:
logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
return train_iter, test_iter
else:
logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
sys.exit(1)
def train(self, batch_size,
train_iter=None,
test_iter=None,
num_epoch=10,
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True):
if train_iter == None:
train_iter, test_iter = self.getH5ArrayIter(batch_size)
if load_checkpoint:
self.load()
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
try:
os.makedirs(self._model_dir_)
except OSError:
if not os.path.isdir(self._model_dir_):
raise
self.module_.fit(
train_data=train_iter,
eval_data=test_iter,
optimizer=optimizer,
optimizer_params=optimizer_params,
batch_end_callback=mx.callback.Speedometer(batch_size),
epoch_end_callback=mx.callback.do_checkpoint(prefix=self._model_dir_ + self._model_prefix_, period=self.checkpoint_period_),
begin_epoch=self.begin_epoch_,
num_epoch=num_epoch + self.begin_epoch_)
def construct(self):
${tc.include(tc.architecture.body)}
\ No newline at end of file
<#if tc.targetLanguage == ".py">
import mxnet as mx
import logging
import os
import errno
import shutil
import numpy as np
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
logging.basicConfig(level=logging.DEBUG)
class ${tc.fileNameWithoutEnding}:
<#list tc.architectureInputs as input>
${input} = None
</#list>
<#list tc.architectureOutputs as output>
${output} = None
</#list>
Module = None
_checkpoint_dir = 'checkpoints/'
def load(self):
self.Module.load(prefix=self._checkpoint_dir)
self.Module.bind(for_training=False,
data_shapes=[('data', (1,3,224,224))],
label_shapes=self.Module._label_shapes)
def predict(self, image):
# compute the predict probabilities
self.Module.forward(Batch([mx.nd.array(image)]))
prob = self.Module.get_outputs()[0].asnumpy()
# top-5
prob = np.squeeze(prob)
return np.argsort(prob)[::-1]
def train(self, train_iter, test_iter, batch_size, optimizer, num_epoch, checkpoint_period):
shutil.rmtree(self._checkpoint_dir)
try:
os.makedirs(self._checkpoint_dir)
except OSError:
if not os.path.isdir(self._checkpoint_dir):
raise
self.Module.fit(
train_data=train_iter,
eval_data=test_iter,
optimizer=optimizer,
batch_end_callback=mx.callback.Speedometer(batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(prefix=self._checkpoint_dir+'${tc.architecture.name}', period=checkpoint_period),
num_epoch=num_epoch)
def __init__(self, context=mx.gpu()):
${tc.include(tc.architecture.body)}
self.Module = mx.mod.Module(symbol=mx.symbol.Group([${tc.join(tc.architectureOutputs, ",", "self.", "")}]),
data_names=[${tc.join(tc.architectureInputs, ",", "'", "'")}],
label_names=[${tc.join(tc.architectureOutputs, ",", "'", "_label'")}],
context=context)
<#elseif tc.targetLanguage == ".cpp">
#ifndef ${tc.fileNameWithoutEnding?upper_case}
#define ${tc.fileNameWithoutEnding?upper_case}
#include "mxnet-cpp/MxNetCpp.h"
using namespace std;
using namespace mxnet::cpp;
class ${tc.fileNameWithoutEnding}{
bool m_isTrained;
<#list tc.architectureInputs as input>
Symbol m_${input}; //change to ${input}
</#list>
<#list tc.architectureOutputs as output>
Symbol m_${output};
</#list>
Module m_module;
public:
${tc.fileNameWithoutEnding}(Context context = Context::gpu());
void predict();
void train();
<#list tc.architectureInputs as input>
Symbol get${input?capitalize}();
</#list>
<#list tc.architectureOutputs as output>
Symbol get${output?capitalize}();
</#list>
Module getModule();
};
<#list tc.architectureInputs as input>
Symbol ${tc.fileNameWithoutEnding}::get${input?capitalize}(){
return m_${input};
}
</#list>
<#list tc.architectureOutputs as output>
Symbol ${tc.fileNameWithoutEnding}::get${output?capitalize}(){
return m_${output};
}
</#list>
Module ${tc.fileNameWithoutEnding}::getModule(){
return m_module;
}
${tc.fileNameWithoutEnding}::${tc.fileNameWithoutEnding}(){
${tc.include(tc.architecture.body)}
auto _group = Operator("Group")
.SetInput("data", {${tc.join(tc.architectureOutputs, ",", "m_", "")}});
.CreateSymbol();
m_module = Module(symbol=group),
data_names=[${tc.join(tc.architectureInputs, ",", "'", "'")}],
label_names=[${tc.join(tc.architectureOutputs, ",", "'", "_label'")}],
context=context);
}
#endif
</#if>
\ No newline at end of file
......@@ -10,12 +10,12 @@
<#list tc.architecture.outputs as output>
<#assign shape = output.definition.type.dimensions>
<#if shape?size == 1>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(CNN_${tc.getName(output)}, {${shape[0]?c}});
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(CNN_${tc.getName(output)}, std::vector<int> {${shape[0]?c}});
</#if>
<#if shape?size == 2>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToMat(CNN_${tc.getName(output)}, {${shape[0]?c}, ${shape[1]?c}});
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToMat(CNN_${tc.getName(output)}, std::vector<int> {${shape[0]?c}, ${shape[1]?c}});
</#if>
<#if shape?size == 3>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(CNN_${tc.getName(output)}, {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(CNN_${tc.getName(output)}, std::vector<int> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
</#if>
</#list>
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment