Commit f4ecf597 authored by Christian Fuß's avatar Christian Fuß

Various changes to templates to support unrolls

parent 50f5124b
Pipeline #177577 failed with stages in 31 seconds
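For orientation, the core of this commit is that every include(...) overload in CNNArch2GluonTemplateController now threads a boolean partOfUnroll flag down to the current ArchitectureElementData, so that element templates can render differently when they occur inside an unroll. Below is a minimal, self-contained Java sketch of that pattern only; apart from setPartOfUnroll and the include/mode names visible in the diff, all class and method names here are invented for illustration and are not the actual generator code.

// Sketch of the partOfUnroll threading pattern (assumed names, not the real classes).
import java.io.StringWriter;
import java.io.Writer;

public class TemplateControllerSketch {

    /** Stand-in for ArchitectureElementData. */
    static class ElementData {
        private boolean partOfUnroll;
        void setPartOfUnroll(boolean b) { partOfUnroll = b; }
        boolean isPartOfUnroll()        { return partOfUnroll; }
    }

    private final ElementData current = new ElementData();

    /** String-mode overload called from templates; defaults to "not part of an unroll". */
    public void include(String elementName, String mode) {
        include(elementName, false, mode);
    }

    /** New overload: the flag is recorded on the element data before the template is rendered. */
    public void include(String elementName, boolean partOfUnroll, String mode) {
        current.setPartOfUnroll(partOfUnroll);
        render(elementName, mode, new StringWriter());
    }

    private void render(String elementName, String mode, Writer writer) {
        // The real controller would process elements/<elementName>.ftl into the writer here.
        System.out.println(elementName + " [" + mode + "] partOfUnroll=" + current.isPartOfUnroll());
    }

    public static void main(String[] args) {
        TemplateControllerSketch tc = new TemplateControllerSketch();
        tc.include("FullyConnected", "ARCHITECTURE_DEFINITION");   // regular stream element
        tc.include("FullyConnected", true, "FORWARD_FUNCTION");    // element rendered inside an unroll body
    }
}

In the templates this corresponds to calls such as ${tc.include(body, true, "FORWARD_FUNCTION")}, after which element templates can branch on ${element.partOfUnroll}.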
@@ -48,7 +48,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer);
}
public void include(VariableSymbol element, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(VariableSymbol element, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(element);
@@ -66,13 +66,13 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else {
include(element.getResolvedThis().get(), writer, netDefinitionMode);
include(element.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ConstantSymbol constant, Writer writer, NetDefinitionMode netDefinitionMode) {
public void include(ConstantSymbol constant, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode) {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(constant);
@@ -80,78 +80,85 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), writer, netDefinitionMode);
include(constant.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(LayerSymbol layer, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
getCurrentElement().setPartOfUnroll(partOfUnroll);
if (layer.isAtomic()){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), writer, netDefinitionMode);
include(layer.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(UnrollSymbol unrollElement, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), writer, netDefinitionMode);
public void include(UnrollSymbol unrollElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), partOfUnroll, writer, netDefinitionMode);
String templateName = unrollElement.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(CompositeElementSymbol compositeElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer, netDefinitionMode);
include(element, partOfUnroll, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode);
include((CompositeElementSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
}
else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer, netDefinitionMode);
include((LayerSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
}
else if (architectureElement instanceof ConstantSymbol) {
include((ConstantSymbol) architectureElement, writer, netDefinitionMode);
include((ConstantSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
}
else {
include((VariableSymbol) architectureElement, writer, netDefinitionMode);
include((VariableSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
}
}
public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) {
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
include(architectureElementSymbol, false, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(UnrollSymbol unrollSymbol, String netDefinitionMode) {
include(unrollSymbol, NetDefinitionMode.fromString(netDefinitionMode));
public void include(ArchitectureElementSymbol architectureElementSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(architectureElementSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(UnrollSymbol unrollSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(unrollSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, getWriter(), netDefinitionMode);
include(architectureElement, partOfUnroll, getWriter(), netDefinitionMode);
}
public void include(UnrollSymbol unroll, NetDefinitionMode netDefinitionMode){
public void include(UnrollSymbol unroll, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(unroll, getWriter(), netDefinitionMode);
include(unroll, partOfUnroll, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
@@ -187,14 +194,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public Set<String> getUnrollOutputNames(UnrollSymbol unroll) {
Set<String> outputNames = new LinkedHashSet<>();
int timestep = 0;//unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get()
while (timestep < unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get()) {
for (ArchitectureElementSymbol element : unroll.getBody().getLastAtomicElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
}
for (ArchitectureElementSymbol element : unroll.getBody().getElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
}
timestep++;
}
outputNames.addAll(getStreamLayerVariableMembers(unroll.getBody(), "1", true).keySet());
......
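To make the getUnrollOutputNames hunk above easier to follow: the timestep loop over getLastAtomicElements() is replaced by a single pass over all elements of the unroll body. A simplified, hypothetical Java sketch of what that rewritten loop collects is shown below (the symbol types are stubbed out with a record, the element names are invented, and the additional getStreamLayerVariableMembers names appended at the end of the real method are left out).

// Hypothetical sketch of the revised output-name lookup; not the actual symbol API.
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class OutputNameSketch {
    /** Stand-in for ArchitectureElementSymbol. */
    record Element(String name, boolean isOutput) {}

    static Set<String> getUnrollOutputNames(List<Element> bodyElements) {
        Set<String> outputNames = new LinkedHashSet<>();
        for (Element element : bodyElements) {
            if (element.isOutput()) {
                outputNames.add(element.name());
            }
        }
        return outputNames;
    }

    public static void main(String[] args) {
        // Invented example elements standing in for the unroll body.
        List<Element> body = List.of(
            new Element("source", false),
            new Element("target_1_", true),
            new Element("target_2_", true));
        System.out.println(getUnrollOutputNames(body)); // prints [target_1_, target_2_]
    }
}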
@@ -9,10 +9,12 @@ from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${unroll?index}
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + body?index}
</#if>
</#list>
</#list>
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
@@ -68,12 +70,14 @@ class ${tc.fileNameWithoutEnding}:
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
self.networks[${unroll?index}] = Net_${unroll?index}(data_mean=data_mean, data_std=data_std)
self.networks[${unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${unroll?index}].hybridize()
self.networks[${unroll?index}](<#list tc.getUnrollInputDimensions(unroll) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
self.networks[${tc.architecture.streams?size + body?index}] = Net_${tc.architecture.streams?size + body?index}(data_mean=data_mean, data_std=data_std)
self.networks[${tc.architecture.streams?size + body?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${tc.architecture.streams?size + body?index}].hybridize()
self.networks[${tc.architecture.streams?size + body?index}](<#list tc.getStreamInputDimensions(body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
</#list>
if not os.path.exists(self._model_dir_):
......
@@ -95,21 +95,18 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
class Net_${tc.architecture.streams?size + unroll?index}(gluon.HybridBlock):
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
class Net_${tc.architecture.streams?size + body?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${unroll?index}, self).__init__(**kwargs)
super(Net_${tc.architecture.streams?size + body?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(unroll, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getUnrollInputNames(unroll), ", ")}):
outputs = []
${tc.include(unroll, "FORWARD_FUNCTION")}
<#if tc.getUnrollOutputNames(unroll)?size gt 1>
return tuple(outputs)
<#else>
return outputs[0]
</#if>
${tc.include(body, false, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(body), ", ")}):
${tc.include(body, true, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(body), ", ")}
</#if>
</#list>
</#list>
@@ -189,8 +189,9 @@ class ${tc.fileNameWithoutEnding}:
<#list tc.architectureOutputs as output_name>
mx.nd.argmax(${output_name}, axis=1)<#sep>,
</#list>
]
]
<#include "elements/BeamSearchStart.ftl">
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
@@ -224,6 +225,7 @@ class ${tc.fileNameWithoutEnding}:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
for i, network in self._networks.items():
network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
......
<#-- BeamSearchPredictions = [
<#list tc.architectureOutputs as output_name>
mx.nd.topk(${output_name}, axis=1, k=4)<#sep>,
</#list>
]
logging.info("BeamSearch indices: " + str(BeamSearchPredictions))-->
\ No newline at end of file
<#-- This template is not used if the following architecture element is an output. See Output.ftl -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = Softmax()
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.partOfUnroll>
${element.name} = Net_1.${element.name}(${input})
<#else>
self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c})
<#include "OutputShape.ftl">
</#if>
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if element.partOfUnroll>
${element.name} = Net_1.${element.name}(${input})
<#else>
${element.name} = self.${element.name}(${input})
</#if>
</#if>
\ No newline at end of file
@@ -3,8 +3,16 @@
<#assign use_bias = element.noBias?string("False","True")>
<#assign flatten = element.flatten?string("True","False")>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if (element.partOfUnroll && false)>
${element.name} = Net_1().${element.name}(${input})
<#else>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten})
<#include "OutputShape.ftl">
</#if>
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if (element.partOfUnroll && false)>
${element.name} = Net_1().${element.name}(${input})
<#else>
${element.name} = self.${element.name}(${input})
</#if>
</#if>
\ No newline at end of file
<#-- This template is not used if the followiing architecture element is an output. See Output.ftl -->
<#-- This template is not used if the following architecture element is an output. See Output.ftl -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = Softmax()
......
<#list tc.architecture.inputs as input>
<#if tc.getName(input)??>
vector<float> ${tc.getName(input)} = CNNTranslator::translate(${input.name}<#if input.arrayAccess.isPresent()>[${input.arrayAccess.get().intValue.get()?c}]</#if>);
</#if>
</#list>
<#list tc.getLayerVariableMembers("1")?keys as member>
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
@@ -19,12 +21,14 @@ ${tc.include(stream, "CPP_INLINE")}
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
_predictor_${unroll?index}_.predict(${tc.join(tc.getUnrollInputNames(unroll), ", ")}, ${tc.join(tc.getUnrollOutputNames(unroll), ", ")});
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
_predictor_${tc.architecture.streams?size + body?index}_.predict(${tc.join(tc.getStreamInputNames(body), ", ")}, ${tc.join(tc.getStreamOutputNames(body), ", ")});
<#else>
${tc.include(unroll, "CPP_INLINE")}
${tc.include(unroll, true, "CPP_INLINE")}
</#if>
</#list>
</#list>
<#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
......
@@ -16,9 +16,11 @@ ${tc.include(stream, "PYTHON_INLINE")}
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
${tc.join(tc.getUnrollOutputNames(unroll), ", ")} = self._networks[${unroll?index}](${tc.join(tc.getUnrollInputNames(unroll), ", ")})
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
${tc.join(tc.getStreamOutputNames(body), ", ")} = self._networks[${tc.architecture.streams?size + body?index}](${tc.join(tc.getStreamInputNames(body), ", ")})
<#else>
${tc.include(unroll, "PYTHON_INLINE")}
${tc.include(unroll, true, "PYTHON_INLINE")}
</#if>
</#list>
</#list>
\ No newline at end of file
@@ -180,8 +180,8 @@ class CNNSupervisedTrainer_Alexnet:
predictions = [
mx.nd.argmax(predictions_, axis=1)
]
mx.nd.argmax(predictions_, axis=1)]
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
@@ -210,6 +210,7 @@ class CNNSupervisedTrainer_Alexnet:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
for i, network in self._networks.items():
network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
......
@@ -180,8 +180,8 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
predictions = [
mx.nd.argmax(softmax_, axis=1)
]
mx.nd.argmax(softmax_, axis=1)]
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
@@ -210,6 +210,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
for i, network in self._networks.items():
network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
......
@@ -180,8 +180,8 @@ class CNNSupervisedTrainer_VGG16:
predictions = [
mx.nd.argmax(predictions_, axis=1)
]
mx.nd.argmax(predictions_, axis=1)]
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
@@ -210,6 +210,7 @@ class CNNSupervisedTrainer_VGG16:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
for i, network in self._networks.items():
network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
......
@@ -5,7 +5,7 @@ architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){
source -> Softmax() -> target[0];
timed <t=0> BeamSearchStart(max_length=5) {
target[t] ->
source ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
target[t+1]
......