Commit 26a7c749 authored by Christian Fuß's avatar Christian Fuß
Browse files

adjusted a few minor things

parent 4c58d527
Pipeline #172965 failed with stages
......@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import java.io.Writer;
import java.util.*;
......@@ -39,6 +40,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){
System.err.println("include called. templateName: " + templateWithoutFileEnding);
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
......@@ -100,14 +102,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(unrollElement);
if(unrollElement.getDeclaration().getBody().getElements().get(0).isInput()) {
include(unrollElement.getDeclaration().getBody().getElements().get(0).getResolvedThis().get(), writer, netDefinitionMode);
if(unrollElement.getBody().getElements().get(0).isInput()) {
include(unrollElement.getBody().getElements().get(0).getResolvedThis().get(), writer, netDefinitionMode);
}
for(int i=0; i < (int)unrollElement.getDeclaration().getParameters().get(0).getExpression().getValue().get(); i++) {
for(int i = 0; i < (int)unrollElement.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH_NAME).get(); i++) {
for (ArchitectureElementSymbol element : unrollElement.getDeclaration().getBody().getElements()) {
System.err.println("i: " + i);
for (ArchitectureElementSymbol element : unrollElement.getBody().getElements()) {
previousElement = getCurrentElement();
setCurrentElement(element);
......@@ -156,9 +159,9 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) {
for(int i=0; i < ((ASTStream)architectureElementSymbol.getAstNode().get()).getElementsList().size(); i++){
System.err.println(((ASTStream)architectureElementSymbol.getAstNode().get()).getElementsList().get(i).getSymbol().getName());
}
//System.err.println("INCLUDE: " + ((SerialCompositeElementSymbol)architectureElementSymbol).getElements().toString());
System.err.println(architectureElementSymbol.getSpannedScope().getSpanningSymbol().get().getClass());
System.err.println("isUnroll? " + (architectureElementSymbol.getSpannedScope().getSpanningSymbol().get() instanceof UnrollSymbol));
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
}
......@@ -177,15 +180,8 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
if(element instanceof UnrollSymbol){
for(ArchitectureElementSymbol sublayer: ((UnrollSymbol) element).getDeclaration().getBody().getFirstAtomicElements()){
names.add(getName(sublayer));
}
}else {
names.add(getName(element));
}
names.add(getName(element));
}
return names;
}
......@@ -193,15 +189,26 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
if(element instanceof UnrollSymbol){
for(ArchitectureElementSymbol sublayer: ((UnrollSymbol) element).getDeclaration().getBody().getLastAtomicElements()){
names.add(getName(sublayer));
}
}else {
names.add(getName(element));
}
names.add(getName(element));
}
return names;
}
public List<String> getUnrollInputNames(UnrollSymbol unroll) {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : unroll.getFirstAtomicElements()) {
names.add(getName(element));
}
return names;
}
public List<String> getUnrollOutputNames(UnrollSymbol unroll) {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : unroll.getLastAtomicElements()) {
names.add(getName(element));
}
return names;
}
}
......@@ -96,4 +96,24 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
return outputs[0]
</#if>
</#if>
</#list>
\ No newline at end of file
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isNetwork()>
class Net_${unroll?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${unroll?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(unroll, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getUnrollInputNames(unroll), ", ")}):
outputs = []
${tc.include(unroll, "FORWARD_FUNCTION")}
<#if tc.getUnrollOutputNames(unroll)?size gt 1>
return tuple(outputs)
<#else>
return outputs[0]
</#if>
</#if>
</#list>
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
unroll BeamSearchStart(max_length=max_length) {
timed <t> BeamSearchStart(max_length=5) {
source ->
FullyConnected(units=17) ->
Softmax() ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
target
};
}
\ No newline at end of file
......@@ -4,4 +4,5 @@ ThreeInputCNN_M14 data/ThreeInputCNN_M14
Alexnet data/Alexnet
ResNeXt50 data/ResNeXt50
MultipleStreams data/MultipleStreams
Invariant data/Invariant
\ No newline at end of file
Invariant data/Invariant
RNNencdec data/RNNencdec
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment