diff --git a/src/main/resources/templates/gluon/elements/Convolution.ftl b/src/main/resources/templates/gluon/elements/Convolution.ftl index e0098b280f752394b8dd211ad47c906a9502213f..73e43bc3c915a21a8b869055859533ee48b081b1 100644 --- a/src/main/resources/templates/gluon/elements/Convolution.ftl +++ b/src/main/resources/templates/gluon/elements/Convolution.ftl @@ -3,10 +3,18 @@ <#if element.padding??> self.${element.name}padding = Padding(padding=(${tc.join(element.padding, ",")})) </#if> +<#if element.partOfUnroll> + self.${element.name} = gluon.nn.Conv2D(channels=${element.channels?c}, + kernel_size=(${tc.join(element.kernel, ",")}), + strides=(${tc.join(element.stride, ",")}), + use_bias=${element.noBias?string("False", "True")}, + params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params()) +<#else> self.${element.name} = gluon.nn.Conv2D(channels=${element.channels?c}, kernel_size=(${tc.join(element.kernel, ",")}), strides=(${tc.join(element.stride, ",")}), use_bias=${element.noBias?string("False", "True")}) +</#if> <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> <#if element.padding??> diff --git a/src/main/resources/templates/gluon/elements/Embedding.ftl b/src/main/resources/templates/gluon/elements/Embedding.ftl index 692d9c9993ca38e220fb3671637b09b0d0ce7823..51313968c1c6c8a6f1e2f7b28d3442de859461e4 100644 --- a/src/main/resources/templates/gluon/elements/Embedding.ftl +++ b/src/main/resources/templates/gluon/elements/Embedding.ftl @@ -1,15 +1,12 @@ <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> <#if element.partOfUnroll> - ${element.name} = Net_1.${element.name}(${input}) + self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c}, + params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params()) <#else> self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c}) </#if> <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - <#if element.partOfUnroll> - ${element.name} = Net_1.${element.name}(${input}) - <#else> ${element.name} = self.${element.name}(${input}) - </#if> </#if> \ No newline at end of file diff --git a/src/main/resources/templates/gluon/elements/FullyConnected.ftl b/src/main/resources/templates/gluon/elements/FullyConnected.ftl index 6b85375e46f752a8e554a3351cd71354273cdef5..90ba739cc7a81d21dfe6728bfa133150bd04282d 100644 --- a/src/main/resources/templates/gluon/elements/FullyConnected.ftl +++ b/src/main/resources/templates/gluon/elements/FullyConnected.ftl @@ -4,8 +4,8 @@ <#assign flatten = element.flatten?string("True","False")> <#if mode == "ARCHITECTURE_DEFINITION"> <#if element.partOfUnroll> - <#assign unrollIndex = element.unrollIndex> - self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten}, params=Net_${unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params()) + self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten}, + params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params()) <#else> self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten}) </#if> diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java index e00120d9cad518c3e16a1951a0a8165cb273b615..9c6c68ca7dd97fc0ddf3216320cdfacc2f85019d 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java @@ -146,7 +146,7 @@ public class GenerationTest extends AbstractSymtabTest { @Test public void testRNNtest() throws IOException, TemplateException { Log.getFindings().clear(); - String[] args = {"-m", "src/test/resources/architectures", "-r", "RNNtest", "-o", "./target/generated-sources-cnnarch/"}; + String[] args = {"-m", "src/test/resources/valid_tests", "-r", "RNNtest", "-o", "./target/generated-sources-cnnarch/"}; CNNArch2GluonCli.main(args); assertTrue(Log.getFindings().isEmpty()); } diff --git a/src/test/resources/architectures/RNNtest.cnna b/src/test/resources/architectures/RNNtest.cnna deleted file mode 100644 index 75aa1ff7184fb08f05323447e6d9848c8ad0fbf3..0000000000000000000000000000000000000000 --- a/src/test/resources/architectures/RNNtest.cnna +++ /dev/null @@ -1,20 +0,0 @@ -architecture RNNtest(max_length=50, vocabulary_size=30001, hidden_size=500) { - def input Q(-oo:oo)^{max_length, vocabulary_size} source[2] - def output Q(-oo:oo)^{max_length, vocabulary_size} target[2] - - layer RNN(units=hidden_size, layers=2) encoder; - layer RNN(units=hidden_size, layers=2) decoder; - - source[0] -> - encoder; - - encoder.output -> - target[0]; - - encoder.state -> - decoder.state; - - source[1] -> - decoder -> - target[1]; -} diff --git a/src/test/resources/architectures/data_paths.txt b/src/test/resources/architectures/data_paths.txt index 947758e6ea06a409fa1ef064e031fab11ba3db2c..25531c98c1ec86faa46e5e2c9a846908e0073f09 100644 --- a/src/test/resources/architectures/data_paths.txt +++ b/src/test/resources/architectures/data_paths.txt @@ -3,5 +3,4 @@ CifarClassifierNetwork data/CifarClassifierNetwork ThreeInputCNN_M14 data/ThreeInputCNN_M14 Alexnet data/Alexnet MultipleOutputs data/MultipleOutputs -ResNeXt50 data/ResNeXt50 -RNNtest data/RNNtest \ No newline at end of file +ResNeXt50 data/ResNeXt50 \ No newline at end of file diff --git a/src/test/resources/valid_tests/RNNencdec.cnna b/src/test/resources/valid_tests/RNNencdec.cnna index cc105135d187419d0159435e06ab0039f7895118..fa9306e941976b4b6d71044c96169fce7683d8f0 100644 --- a/src/test/resources/valid_tests/RNNencdec.cnna +++ b/src/test/resources/valid_tests/RNNencdec.cnna @@ -1,13 +1,26 @@ -architecture RNNencdec(max_length=5, vocabulary_size=30000, hidden_size=1000){ - def input Q(0:1)^{vocabulary_size} source - def output Q(0:1)^{vocabulary_size} target[6] - - source -> Softmax() -> target[0]; - - timed <t=0> BeamSearchStart(max_length=5) { - source -> - FullyConnected(units=vocabulary_size) -> - Softmax() -> - target[t+1] - }; - } \ No newline at end of file +architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){ + def input N(0:29999)^{50} source + def output Q(0:29999)^{1} target[50] + + layer LSTM(units=hidden_size) encoder; + + source -> + Embedding(output_dim=hidden_size) -> + encoder; + + layer LSTM(units=hidden_size) decoder; + + 1 -> target[0]; + + encoder.state -> decoder.state; + + timed<t=1> BeamSearchStart(max_length=50) { + target[t-1] -> + Embedding(output_dim=hidden_size) -> + decoder -> + FullyConnected(units=vocabulary_size) -> + Softmax() -> + target[t] + }; + +} \ No newline at end of file diff --git a/src/test/resources/valid_tests/RNNtest.cnna b/src/test/resources/valid_tests/RNNtest.cnna new file mode 100644 index 0000000000000000000000000000000000000000..d2f07b1f1d2c31c2deeed9284b1b100ab94b6ad9 --- /dev/null +++ b/src/test/resources/valid_tests/RNNtest.cnna @@ -0,0 +1,14 @@ +architecture RNNtest(max_length=50, vocabulary_size=30000, hidden_size=1000){ + def input Q(0:1)^{vocabulary_size} source + def output Q(0:1)^{vocabulary_size} target[5] + + source -> Softmax() -> target[0]; + + timed <t=1> BeamSearchStart(max_length=5){ + (target[0] | target[t-1]) -> + Concatenate() -> + FullyConnected(units=30000) -> + Softmax() -> + target[t] + }; + } \ No newline at end of file diff --git a/src/test/resources/valid_tests/data_paths.txt b/src/test/resources/valid_tests/data_paths.txt index ed33971056f21e2940ff8ec11af69729d9cb7242..8f5f6e2aaa98092dcf567c7922e7a4f0c9a756af 100644 --- a/src/test/resources/valid_tests/data_paths.txt +++ b/src/test/resources/valid_tests/data_paths.txt @@ -5,4 +5,5 @@ Alexnet data/Alexnet ResNeXt50 data/ResNeXt50 MultipleStreams data/MultipleStreams Invariant data/Invariant -RNNencdec data/RNNencdec \ No newline at end of file +RNNencdec data/RNNencdec +RNNtest data/RNNtest \ No newline at end of file