diff --git a/README.md b/README.md index 3e6139fcbb9684662e2bbc9e50e22d14d99245c4..f3f1d77c095a24c008b65eabb0c4d4b248a7b7db 100644 --- a/README.md +++ b/README.md @@ -802,6 +802,23 @@ All predefined methods start with a capital letter and all constructed methods h * **padding** ({"valid", "same", "no_loss"}, optional, default="same3d"): One of "valid3d", "same3d" or "simple3d". "valid" means no padding. "same" results in padding the input such that the output has the same length as the original input divided by the stride (rounded up). "simple3d" results constant padding of size 1 (same as (1,1,1). UpConvolution3D also accepts tuples of form (height, widht, depth) as input. * **no_bias** (boolean, optional, default=false): Whether to disable the bias parameter. + +* **Reparameterize(pdf="normal")** + + Must be used in order to model VAEs, β-VAEs and Conditional VAEs. Applies the Reparameterization Trick and samples a Code from the approximating Distribution. + + * **pdf** ({"normal"}, optional, default="normal"): + * normal: Takes in 2 Datastreams and applies the Reparameterization Trick for a normal distribution. + + +* **VectorQuantize(num_embeddings,beta=0.25)** + + Must be used in order to model VQ-VAEs. Use this layer to quantize the pixels of incoming feature maps with a vector from the codebook. + + * **num_embeddings** (integer > 0, required): Number of vectors within the codebook. + * **beta** (float, optional, default=0.25): Commitment cost factor that weights the commitment term of the VQ-VAE Loss function. "We found the resulting algorithm to be + quite robust to β, as the results did not vary for values of β ranging from 0.1 to 2.0" [van den Oord et al. 
2017] + * **EpisodicMemory(replayMemoryStoreProb=1, maxStoredSamples=-1, memoryReplacementStrategy="replace_oldest", useReplay=true, replayInterval, replayBatchSize=-1, replaySteps, replayGradientSteps=1, useLocalAdaption=true, localAdaptionGradientSteps=1, localAdaptionK=1, queryNetDir=-1, queryNetPrefix=-1, queryNetNumInputs=1)** diff --git a/pom.xml b/pom.xml index 3a7d09b451531d43b9642f06be6a3dd9ff9395ac..0cd5c56e636c7b5a9d5a59d5a509009ef0be8937 100644 --- a/pom.xml +++ b/pom.xml @@ -382,4 +382,4 @@ - \ No newline at end of file + diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java index 2664bd5b889689084c7ca5923a42cc6407725f46..5388913346013cfd8b92aa496e0283cc9ba66a4a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java @@ -11,18 +11,18 @@ package de.monticore.lang.monticar.cnnarch._symboltable; +import de.monticore.lang.monticar.cnnarch._cocos.CheckLayerPathParameter; import de.monticore.lang.monticar.cnnarch.helper.Utils; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables; -import de.monticore.lang.monticar.cnnarch._cocos.CheckLayerPathParameter; import de.monticore.symboltable.CommonScopeSpanningSymbol; import de.monticore.symboltable.Scope; import de.monticore.symboltable.Symbol; -import org.apache.commons.math3.ml.neuralnet.Network; -import java.lang.RuntimeException; -import java.lang.NullPointerException; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; public class ArchitectureSymbol extends CommonScopeSpanningSymbol { @@ -36,6 +36,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol { 
private String dataPath; private String weightsPath; private String componentName; + private ArchitectureSymbol auxiliaryArchitecture; private boolean AdaNet = false; //attribute for the path for custom python files private String customPyFilesPath; @@ -52,6 +53,14 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol { public void setLayerVariableDeclarations(List layerVariableDeclarations) { this.layerVariableDeclarations = layerVariableDeclarations; } + public void setAuxiliaryArchitecture(ArchitectureSymbol auxiliaryArchitecture){ + this.auxiliaryArchitecture = auxiliaryArchitecture; + } + + public ArchitectureSymbol getAuxiliaryArchitecture(){ + return this.auxiliaryArchitecture; + } + public String getAdaNetUtils(){return this.adaNetUtils;} public void setAdaNetUtils(String adaNetUtils){this.adaNetUtils=adaNetUtils;} public boolean containsAdaNet() { diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java index 8cdf99d40d00347677fd612321f465cebbb420cc..c39a67ea1b6db4edee61206f19751bc86cfa302f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java @@ -393,6 +393,25 @@ public enum Constraints { public String msgString() { return "an axis between 0 and 1"; } + }, + REPARAMETERIZE_PDFS { + @Override + public boolean isValid(ArchSimpleExpressionSymbol exp) { + Optional optString= exp.getStringValue(); + if (optString.isPresent()){ + if ( optString.get().equals(AllPredefinedLayers.PDF_NORMAL)){ + //|| optString.get().equals(AllPredefinedLayers.PDF_DIRICHLET)){ + return true; + } + } + return false; + } + @Override + protected String msgString() { + return AllPredefinedLayers.PDF_NORMAL; //+ " or " + //+ AllPredefinedLayers.PDF_DIRICHLET; + + } }; protected abstract boolean isValid(ArchSimpleExpressionSymbol exp); 
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java index cfede4ba8865433ac1c3472b00734294e06a15a3..6f694c0a1f1ba88041c5ab613a86045db149bb7d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java @@ -11,13 +11,11 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; import de.monticore.lang.monticar.ranges._ast.ASTRange; +import de.se_rwth.commons.Joiners; import de.se_rwth.commons.logging.Log; import org.jscience.mathematics.number.Rational; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.BinaryOperator; import java.util.stream.Stream; @@ -174,6 +172,63 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } + protected static enum HandlingSingleInputs { + ALLOWED, IGNORED, RESTRICTED + } + + /** + * Check if Inputs of Layer have the same shape + * @param inputTypes: List of input Types + * @param layer: current Layer + * @param handling: HandlingSingleInputs Enum Value, either ALLOWED, IGNORED or RESTRICTED + * ALLOWED will skip the Check if there is only one Input. + * IGNORED will print a message that the Layer is redundant, but the Input may still be passed + * RESTRICTED will throw an error. 
This indicates that the Layer does not allow + */ + protected static void errorIfMultipleInputShapesAreNotEqual(List inputTypes, LayerSymbol layer, HandlingSingleInputs handling) { + if (inputTypes.size() == 1){ + if (handling == HandlingSingleInputs.IGNORED) { + Log.warn(layer.getName() + " layer has only one input stream. Layer can be removed.", layer.getSourcePosition()); + } else if (handling == HandlingSingleInputs.RESTRICTED){ + Log.error(layer.getName() + " layer has only one input stream.", layer.getSourcePosition()); + } + } + else if (inputTypes.size() > 1){ + List heightList = new ArrayList<>(); + List widthList = new ArrayList<>(); + List channelsList = new ArrayList<>(); + for (ArchTypeSymbol shape : inputTypes){ + heightList.add(shape.getHeight()); + widthList.add(shape.getWidth()); + channelsList.add(shape.getChannels()); + } + int countEqualHeights = (int)heightList.stream().distinct().count(); + int countEqualWidths = (int)widthList.stream().distinct().count(); + int countEqualNumberOfChannels = (int)channelsList.stream().distinct().count(); + if (countEqualHeights != 1 || countEqualWidths != 1 || countEqualNumberOfChannels != 1){ + Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " + + "Shapes of all input streams must be equal. " + + "Input heights: " + Joiners.COMMA.join(heightList) + ". " + + "Input widths: " + Joiners.COMMA.join(widthList) + ". " + + "Number of input channels: " + Joiners.COMMA.join(channelsList) + ". " + , layer.getSourcePosition()); + } + } + } + + protected static void errorIfInputNotFlattened(List inputTypes, LayerSymbol layer) { + if (!inputTypes.isEmpty()) { + for (ArchTypeSymbol inputType : layer.getInputTypes()) { + int height = inputType.getHeight(); + int width = inputType.getWidth(); + if (height != 1 || width != 1) { + Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input." 
+ + " Input layer must be flat, consider using a 'Flatten()' layer.", layer.getSourcePosition()); + } + } + } + } + //check input for convolution and pooling protected static void errorIfInputSmallerThanKernel(List inputTypes, LayerSymbol layer) { if (!inputTypes.isEmpty()) { @@ -203,7 +258,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } - //output type function for convolution and poolingee protected static List computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) { if (method.getIntTupleValue(AllPredefinedLayers.PADDING_NAME).isPresent()){ //If the Padding is given in Tuple diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java index 2586d827b0621a11c0a36a30e1b66738201f8303..3c87ca3aca8549d05942295c9e52ed5fd5e3332e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java @@ -9,12 +9,18 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; -import java.util.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; public class SerialCompositeElementSymbol extends CompositeElementSymbol { protected List> episodicSubNetworks = new ArrayList<>(new ArrayList<>()); protected boolean anyEpisodicLocalAdaptation = false; + protected boolean lossParameterizingElements = false; + protected void setElements(List elements) { ArchitectureElementSymbol previous = null; for (ArchitectureElementSymbol current : elements){ @@ -24,6 +30,10 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol { this.setAdaNet(true); this.setAdaLayer(current); } + 
if(AllPredefinedLayers.getLossParameterizingLayers().contains(current.getName())){ + // check if architecture has loss parametrizing layers + lossParameterizingElements = true; + } if(previous != null){ current.setInputElement(previous); previous.setOutputElement(current); @@ -72,6 +82,8 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol { public boolean getAnyEpisodicLocalAdaptation() { return anyEpisodicLocalAdaptation; } + public boolean hasLossParameterizingElements() { return lossParameterizingElements; } + @Override public void setInputElement(ArchitectureElementSymbol inputElement) { super.setInputElement(inputElement); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java index 6b78a2ba352844bd53812097282c3dca2e813489..6c9837a981a29406e2aa387b2d16356682831f0f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java @@ -8,10 +8,10 @@ /* (c) https://github.com/MontiCore/monticore */ package de.monticore.lang.monticar.cnnarch.predefined; -import de.monticore.lang.monticar.cnnarch._symboltable.*; -import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes; -import de.se_rwth.commons.Joiners; -import de.se_rwth.commons.logging.Log; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration; +import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol; import org.jscience.mathematics.number.Rational; import java.util.ArrayList; @@ -20,9 +20,7 @@ import java.util.List; public class Add extends PredefinedLayerDeclaration { - private Add() { - super(AllPredefinedLayers.ADD_NAME); - } + private Add() {super(AllPredefinedLayers.ADD_NAME);} @Override public List 
computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { @@ -39,30 +37,7 @@ public class Add extends PredefinedLayerDeclaration { @Override public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { errorIfInputIsEmpty(inputTypes, layer); - if (inputTypes.size() == 1){ - Log.warn("Add layer has only one input stream. Layer can be removed." , layer.getSourcePosition()); - } - else if (inputTypes.size() > 1){ - List heightList = new ArrayList<>(); - List widthList = new ArrayList<>(); - List channelsList = new ArrayList<>(); - for (ArchTypeSymbol shape : inputTypes){ - heightList.add(shape.getHeight()); - widthList.add(shape.getWidth()); - channelsList.add(shape.getChannels()); - } - int countEqualHeights = (int)heightList.stream().distinct().count(); - int countEqualWidths = (int)widthList.stream().distinct().count(); - int countEqualNumberOfChannels = (int)channelsList.stream().distinct().count(); - if (countEqualHeights != 1 || countEqualWidths != 1 || countEqualNumberOfChannels != 1){ - Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " + - "Shapes of all input streams must be equal. " + - "Input heights: " + Joiners.COMMA.join(heightList) + ". " + - "Input widths: " + Joiners.COMMA.join(widthList) + ". " + - "Number of input channels: " + Joiners.COMMA.join(channelsList) + ". 
" - , layer.getSourcePosition()); - } - } + errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.IGNORED); } public static Add create(){ diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java index d670086c30317a166e742e69fdaf23bc1c78fc9e..883127d068ca87eb090e1021f6ce195761d4a8cc 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java @@ -10,12 +10,10 @@ package de.monticore.lang.monticar.cnnarch.predefined; import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol; -import jline.internal.Nullable; -import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.ArrayList; public class AllPredefinedLayers { @@ -38,6 +36,7 @@ public class AllPredefinedLayers { public static final String GET_NAME = "Get"; public static final String ADD_NAME = "Add"; public static final String CONCATENATE_NAME = "Concatenate"; + public static final String REPARAMETERIZE_NAME = "Reparameterize"; public static final String FLATTEN_NAME = "Flatten"; public static final String ONE_HOT_NAME = "OneHot"; public static final String BEAMSEARCH_NAME = "BeamSearch"; @@ -62,6 +61,7 @@ public class AllPredefinedLayers { public static final String CUSTOM_LAYER = "CustomLayer"; public static final String CONVOLUTION3D_NAME = "Convolution3D"; public static final String UP_CONVOLUTION3D_NAME = "UpConvolution3D"; + public static final String VECTOR_QUANTIZE_NAME = "VectorQuantize"; public static final String AdaNet_Name = "AdaNet"; //AdaNet layer @@ -114,6 +114,10 @@ public class AllPredefinedLayers { public static final Integer DEFAULT_UNITS = 20; public static 
final String DEFAULT_BLOCK = "default_block"; + //VAE Parameters + public static final String NUM_EMBEDDINGS_NAME = "num_embeddings"; + public static final String EMA_NAME = "ema"; + //parameters LoadNetwork layer public static final String NETWORK_DIR_NAME = "networkDir"; public static final String NETWORK_PREFIX_NAME = "networkPrefix"; @@ -163,6 +167,11 @@ public class AllPredefinedLayers { public static final String RANDOM = "random"; public static final String REPLACE_OLDEST = "replace_oldest"; public static final String NO_REPLACEMENT = "no_replacement"; + + // String values for Reparametrization Layer + public static final String PDF_NAME = "pdf"; + public static final String PDF_NORMAL = "normal"; + public static final String PDF_DIRICHLET = "dirichlet"; //possible activation values for the querry network in the memory layer public static final String MEMORY_ACTIVATION_LINEAR = "linear"; @@ -172,6 +181,15 @@ public class AllPredefinedLayers { public static final String MEMORY_ACTIVATION_SOFTRELU = "softrelu"; public static final String MEMORY_ACTIVATION_SOFTSIGN = "softsign"; + + public static List getLossParameterizingLayers() { + return Arrays.asList( + REPARAMETERIZE_NAME, + VECTOR_QUANTIZE_NAME + ); + } + + //list with all predefined layers public static List createList(){ return Arrays.asList( @@ -215,7 +233,9 @@ public class AllPredefinedLayers { EpisodicMemory.create(), Convolution3D.create(), UpConvolution3D.create(), - AdaNet.create()); + AdaNet.create(), + Reparameterize.create(), + VectorQuantize.create()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java new file mode 100644 index 0000000000000000000000000000000000000000..79fc8e3030a4ed8f2948455f2c5bcebbb60ecbad --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java @@ -0,0 +1,48 @@ +/** + * + * (c) 
https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class Reparameterize extends PredefinedLayerDeclaration { + + private Reparameterize() { super(AllPredefinedLayers.REPARAMETERIZE_NAME); } + + @Override + public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + return Collections.singletonList(new ArchTypeSymbol.Builder() + .channels(layer.getInputTypes().get(0).getChannels()) + .height(layer.getInputTypes().get(0).getHeight()) + .width(layer.getInputTypes().get(0).getWidth()) + .elementType("-oo", "oo") + .build()); + } + + @Override + public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + errorIfInputIsEmpty(inputTypes,layer); + errorIfInputNotFlattened(inputTypes,layer); + errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.RESTRICTED); + } + + public static Reparameterize create() { + Reparameterize layerDeclaration = new Reparameterize(); + layerDeclaration.setParameters(new ArrayList<>(Arrays.asList( + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.PDF_NAME) + .constraints(Constraints.REPARAMETERIZE_PDFS) + .defaultValue(AllPredefinedLayers.PDF_NORMAL) + .build()))); + return layerDeclaration; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java new file mode 100644 index 0000000000000000000000000000000000000000..7903e2fc434a8525cc6c3fddbb08d441fe1c4937 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java @@ -0,0 +1,51 @@ 
+/** + * + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class VectorQuantize extends PredefinedLayerDeclaration { + + private VectorQuantize(){ super(AllPredefinedLayers.VECTOR_QUANTIZE_NAME);} + + @Override + public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + return Collections.singletonList(new ArchTypeSymbol.Builder() + .channels(layer.getInputTypes().get(0).getChannels()) + .height(layer.getInputTypes().get(0).getHeight()) + .width(layer.getInputTypes().get(0).getWidth()) + .elementType("-oo", "oo") + .build()); + } + + @Override + public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + errorIfInputSizeIsNotOne(inputTypes, layer); + } + + public static VectorQuantize create(){ + VectorQuantize declaration = new VectorQuantize(); + List parameters = new ArrayList<>(Arrays.asList( + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.NUM_EMBEDDINGS_NAME) + .constraints(Constraints.INTEGER, Constraints.POSITIVE) + .build(), + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.BETA_NAME) + .constraints(Constraints.NUMBER) + .defaultValue(0.25) + .build())); + declaration.setParameters(parameters); + return declaration; + } +} diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java index 823634e1c57bdfd0a9a4e9f1bcee2a4853178492..9fa7996ef6f057ea156175657945f59df30d177f 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java @@ 
-56,4 +56,17 @@ public class SymtabTest extends AbstractSymtabTest { a.getArchitecture().getStreams().get(0).getOutputTypes(); } + + @Test + public void testVAELayers(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + CNNArchCompilationUnitSymbol a = symTab.resolve( + "VAELayers", + CNNArchCompilationUnitSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + a.getArchitecture().getStreams().get(0).getOutputTypes(); + } + + } diff --git a/src/test/resources/architectures/VAELayers.cnna b/src/test/resources/architectures/VAELayers.cnna new file mode 100644 index 0000000000000000000000000000000000000000..c248909f96cd266aa5772788f98162c53eac9869 --- /dev/null +++ b/src/test/resources/architectures/VAELayers.cnna @@ -0,0 +1,20 @@ +architecture VAELayers{ + def input Q(-oo:oo)^{1,28,28} data + def output Q(-oo:oo)^{1,28,28} res + + data -> + VectorQuantize(num_embeddings=64,beta=0.25) -> + FullyConnected(units=400) -> + FullyConnected(units=4) -> + Split(n=2) -> + ( + [0] + | + [1] + ) -> + Reparameterize(pdf="normal") -> + FullyConnected(units=400) -> + FullyConnected(units=784) -> + Reshape(shape=(1,28,28)) -> + res; +} \ No newline at end of file