Commit 50f5124b authored by Christian Fuß
Browse files

Moved the logic that generates additional Unroll layers into CNNArchLang

parent b6332ec9
Pipeline #176015 failed with stages
in 3 minutes and 36 seconds
...@@ -102,38 +102,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -102,38 +102,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
} }
public void include(UnrollSymbol unrollElement, Writer writer, NetDefinitionMode netDefinitionMode){ public void include(UnrollSymbol unrollElement, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement(); include(unrollElement.getBody(), writer, netDefinitionMode);
//setCurrentElement(unrollElement);
if(unrollElement.getBody().getElements().get(0).isInput()) {
include(unrollElement.getBody().getElements().get(0).getResolvedThis().get(), writer, netDefinitionMode);
}
//System.err.println("TIME: " + unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get());
int timestep = 0;//unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get();
while (timestep < unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get()) {
System.err.println("i: " + timestep);
for (ArchitectureElementSymbol element : unrollElement.getBody().getElements()) {
previousElement = getCurrentElement();
setCurrentElement(element);
if (element.isAtomic() && !element.isInput() && !element.isOutput()) {
String templateName = element.getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
} else {
if(element.isOutput()) {
include(element.getResolvedThis().get(), writer, netDefinitionMode);
}
}
}
timestep++;
}
setCurrentElement(previousElement);
} }
public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){ public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){
...@@ -218,10 +187,14 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -218,10 +187,14 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public Set<String> getUnrollOutputNames(UnrollSymbol unroll) { public Set<String> getUnrollOutputNames(UnrollSymbol unroll) {
Set<String> outputNames = new LinkedHashSet<>(); Set<String> outputNames = new LinkedHashSet<>();
for (ArchitectureElementSymbol element : unroll.getBody().getLastAtomicElements()) { int timestep = 0;//unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get()
if (element.isOutput()) { while (timestep < unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get()) {
outputNames.add(getName(element)); for (ArchitectureElementSymbol element : unroll.getBody().getLastAtomicElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
}
} }
timestep++;
} }
outputNames.addAll(getStreamLayerVariableMembers(unroll.getBody(), "1", true).keySet()); outputNames.addAll(getStreamLayerVariableMembers(unroll.getBody(), "1", true).keySet());
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")}) vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
</#list> </#list>
<#list tc.architecture.outputs as output> <#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")}); vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")});
</#if>
</#list> </#list>
<#list tc.architecture.streams as stream> <#list tc.architecture.streams as stream>
...@@ -25,6 +27,7 @@ ${tc.include(unroll, "CPP_INLINE")} ...@@ -25,6 +27,7 @@ ${tc.include(unroll, "CPP_INLINE")}
</#list> </#list>
<#list tc.architecture.outputs as output> <#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
<#assign shape = output.ioDeclaration.type.dimensions> <#assign shape = output.ioDeclaration.type.dimensions>
<#if shape?size == 1> <#if shape?size == 1>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}}); ${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}});
...@@ -35,4 +38,5 @@ ${tc.include(unroll, "CPP_INLINE")} ...@@ -35,4 +38,5 @@ ${tc.include(unroll, "CPP_INLINE")}
<#if shape?size == 3> <#if shape?size == 3>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}}); ${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
</#if> </#if>
</#if>
</#list> </#list>
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context) ${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
</#list> </#list>
<#list tc.architecture.outputs as output> <#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
${tc.getName(output)} = mx.nd.zeros((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context) ${tc.getName(output)} = mx.nd.zeros((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#if>
</#list> </#list>
<#list tc.architecture.streams as stream> <#list tc.architecture.streams as stream>
......
architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){ architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[3] def output Q(0:1)^{vocabulary_size} target[6]
source -> target[0]; source -> Softmax() -> target[0];
timed <t=2> BeamSearchStart(max_length=5) { timed <t=0> BeamSearchStart(max_length=5) {
target[t-1] -> target[t] ->
FullyConnected(units=vocabulary_size) -> FullyConnected(units=vocabulary_size) ->
Softmax() -> Softmax() ->
target[t] target[t+1]
}; };
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment