Skip to content
Snippets Groups Projects
Commit 50f5124b authored by Christian Fuß's avatar Christian Fuß
Browse files

shifted logic to generate additional Unroll layers to CNNArchLang

parent b6332ec9
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #176015 failed
......@@ -102,38 +102,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
public void include(UnrollSymbol unrollElement, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
//setCurrentElement(unrollElement);
if(unrollElement.getBody().getElements().get(0).isInput()) {
include(unrollElement.getBody().getElements().get(0).getResolvedThis().get(), writer, netDefinitionMode);
}
//System.err.println("TIME: " + unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get());
int timestep = 0;//unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get();
while (timestep < unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get()) {
System.err.println("i: " + timestep);
for (ArchitectureElementSymbol element : unrollElement.getBody().getElements()) {
previousElement = getCurrentElement();
setCurrentElement(element);
if (element.isAtomic() && !element.isInput() && !element.isOutput()) {
String templateName = element.getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
} else {
if(element.isOutput()) {
include(element.getResolvedThis().get(), writer, netDefinitionMode);
}
}
}
timestep++;
}
setCurrentElement(previousElement);
include(unrollElement.getBody(), writer, netDefinitionMode);
}
public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){
......@@ -218,10 +187,14 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public Set<String> getUnrollOutputNames(UnrollSymbol unroll) {
Set<String> outputNames = new LinkedHashSet<>();
for (ArchitectureElementSymbol element : unroll.getBody().getLastAtomicElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
int timestep = 0;//unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get()
while (timestep < unroll.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get()) {
for (ArchitectureElementSymbol element : unroll.getBody().getLastAtomicElements()) {
if (element.isOutput()) {
outputNames.add(getName(element));
}
}
timestep++;
}
outputNames.addAll(getStreamLayerVariableMembers(unroll.getBody(), "1", true).keySet());
......
......@@ -5,7 +5,9 @@
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
</#list>
<#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")});
</#if>
</#list>
<#list tc.architecture.streams as stream>
......@@ -25,6 +27,7 @@ ${tc.include(unroll, "CPP_INLINE")}
</#list>
<#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
<#assign shape = output.ioDeclaration.type.dimensions>
<#if shape?size == 1>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}});
......@@ -35,4 +38,5 @@ ${tc.include(unroll, "CPP_INLINE")}
<#if shape?size == 3>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
</#if>
</#if>
</#list>
......@@ -2,7 +2,9 @@
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
</#list>
<#list tc.architecture.outputs as output>
<#if tc.getName(output)??>
${tc.getName(output)} = mx.nd.zeros((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#if>
</#list>
<#list tc.architecture.streams as stream>
......
architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[3]
def output Q(0:1)^{vocabulary_size} target[6]
source -> target[0];
source -> Softmax() -> target[0];
timed <t=2> BeamSearchStart(max_length=5) {
target[t-1] ->
timed <t=0> BeamSearchStart(max_length=5) {
target[t] ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
target[t]
target[t+1]
};
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment