From 72ee3be948a248020d85867957f75b2c0c83d75b Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Tue, 5 Oct 2021 14:43:17 +0200 Subject: [PATCH 01/16] Add Reparametrize() layer --- .../lang/monticar/cnnarch/predefined/Add.java | 10 ++++++++- .../predefined/AllPredefinedLayers.java | 8 +++---- .../cnnarch/predefined/Reparametrize.java | 22 +++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java index 6b78a2b..1aae625 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java @@ -20,8 +20,16 @@ import java.util.List; public class Add extends PredefinedLayerDeclaration { + protected final String name; + private Add() { super(AllPredefinedLayers.ADD_NAME); + this.name = AllPredefinedLayers.ADD_NAME; + } + + protected Add(String name) { + super(name); + this.name = name; } @Override @@ -40,7 +48,7 @@ public class Add extends PredefinedLayerDeclaration { public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { errorIfInputIsEmpty(inputTypes, layer); if (inputTypes.size() == 1){ - Log.warn("Add layer has only one input stream. Layer can be removed." , layer.getSourcePosition()); + Log.warn(this.name + " layer has only one input stream. Layer can be removed." , layer.getSourcePosition()); } else if (inputTypes.size() > 1){ List heightList = new ArrayList<>(); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java index ff0a2e3..ce01e25 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java @@ -10,12 +10,10 @@ package de.monticore.lang.monticar.cnnarch.predefined; import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol; -import jline.internal.Nullable; -import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.ArrayList; public class AllPredefinedLayers { @@ -38,6 +36,7 @@ public class AllPredefinedLayers { public static final String GET_NAME = "Get"; public static final String ADD_NAME = "Add"; public static final String CONCATENATE_NAME = "Concatenate"; + public static final String REPARAMETRIZE_NAME = "Reparametrize"; public static final String FLATTEN_NAME = "Flatten"; public static final String ONE_HOT_NAME = "OneHot"; public static final String BEAMSEARCH_NAME = "BeamSearch"; @@ -198,7 +197,8 @@ public class AllPredefinedLayers { LoadNetwork.create(), DotProductSelfAttention.create(), LargeMemory.create(), - EpisodicMemory.create()); + EpisodicMemory.create(), + Reparametrize.create()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java new file mode 100644 index 0000000..97c8634 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java @@ -0,0 +1,22 @@ +/** + * + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import java.util.ArrayList; + +public class Reparametrize extends Add { + + private Reparametrize() { super(AllPredefinedLayers.REPARAMETRIZE_NAME); } + + public static Reparametrize create() { + Reparametrize layerDeclaration = new Reparametrize(); + layerDeclaration.setParameters(new ArrayList<>()); + return layerDeclaration; + } + +} -- GitLab From abbf7d636fd62a26ab9234884f55b376e1a35b93 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:10:19 +0100 Subject: [PATCH 02/16] Change Reparameterize --- .../lang/monticar/cnnarch/predefined/Add.java | 45 +++-------------- .../cnnarch/predefined/Reparameterize.java | 48 +++++++++++++++++++ 2 files changed, 54 insertions(+), 39 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java index 1aae625..6c9837a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Add.java @@ -8,10 +8,10 @@ /* (c) https://github.com/MontiCore/monticore */ package de.monticore.lang.monticar.cnnarch.predefined; -import de.monticore.lang.monticar.cnnarch._symboltable.*; -import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes; -import de.se_rwth.commons.Joiners; -import de.se_rwth.commons.logging.Log; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration; +import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol; import org.jscience.mathematics.number.Rational; import java.util.ArrayList; @@ -20,17 +20,7 @@ import java.util.List; public class Add extends PredefinedLayerDeclaration { - protected final String name; - - private Add() { - super(AllPredefinedLayers.ADD_NAME); - this.name = AllPredefinedLayers.ADD_NAME; - } - - protected Add(String name) { - super(name); - this.name = name; - } + private Add() {super(AllPredefinedLayers.ADD_NAME);} @Override public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { @@ -47,30 +37,7 @@ public class Add extends PredefinedLayerDeclaration { @Override public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { errorIfInputIsEmpty(inputTypes, layer); - if (inputTypes.size() == 1){ - Log.warn(this.name + " layer has only one input stream. Layer can be removed." , layer.getSourcePosition()); - } - else if (inputTypes.size() > 1){ - List heightList = new ArrayList<>(); - List widthList = new ArrayList<>(); - List channelsList = new ArrayList<>(); - for (ArchTypeSymbol shape : inputTypes){ - heightList.add(shape.getHeight()); - widthList.add(shape.getWidth()); - channelsList.add(shape.getChannels()); - } - int countEqualHeights = (int)heightList.stream().distinct().count(); - int countEqualWidths = (int)widthList.stream().distinct().count(); - int countEqualNumberOfChannels = (int)channelsList.stream().distinct().count(); - if (countEqualHeights != 1 || countEqualWidths != 1 || countEqualNumberOfChannels != 1){ - Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " + - "Shapes of all input streams must be equal. " + - "Input heights: " + Joiners.COMMA.join(heightList) + ". " + - "Input widths: " + Joiners.COMMA.join(widthList) + ". " + - "Number of input channels: " + Joiners.COMMA.join(channelsList) + ". " - , layer.getSourcePosition()); - } - } + errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.IGNORED); } public static Add create(){ diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java new file mode 100644 index 0000000..5308a3a --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java @@ -0,0 +1,48 @@ +/** + * + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class Reparameterize extends PredefinedLayerDeclaration { + + private Reparameterize() { super(AllPredefinedLayers.REPARAMETERIZE_NAME); } + + @Override + public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + if (layer.getInputTypes().isEmpty()){ + return layer.getInputTypes(); + } + else { + return Collections.singletonList(layer.getInputTypes().get(0)); + } + } + + @Override + public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + errorIfInputIsEmpty(inputTypes,layer); + errorIfInputNotFlattened(inputTypes,layer); + errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.ALLOWED); + } + + public static Reparameterize create() { + Reparameterize layerDeclaration = new Reparameterize(); + layerDeclaration.setParameters(new ArrayList<>(Arrays.asList( + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.PDF_NAME) + .constraints(Constraints.REPARAMETERIZE_PDFS) + .defaultValue(AllPredefinedLayers.PDF_NORMAL) + .build()))); + return layerDeclaration; + } +} -- GitLab From 5dad1324d935af9fadcffa1f6a67da65ddab824c Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:49:07 +0100 Subject: [PATCH 03/16] Add Layers VectorQuantize, ConcatLabels --- .../cnnarch/predefined/ConcatLabels.java | 42 ++++++++++++++ .../cnnarch/predefined/VectorQuantize.java | 56 +++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java new file mode 100644 index 0000000..759e5d3 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java @@ -0,0 +1,42 @@ +/** + * + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration; +import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class ConcatLabels extends PredefinedLayerDeclaration { + private ConcatLabels(){ super(AllPredefinedLayers.CONCAT_LABELS_NAME);} + + @Override + public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + if (layer.getInputTypes().isEmpty()){ + return layer.getInputTypes(); + } + else { + return Collections.singletonList(layer.getInputTypes().get(0)); + } + } + + @Override + public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + errorIfInputSizeIsNotOne(inputTypes, layer); + } + + public static ConcatLabels create(){ + ConcatLabels declaration = new ConcatLabels(); + declaration.setParameters(new ArrayList<>()); + return declaration; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java new file mode 100644 index 0000000..6054766 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java @@ -0,0 +1,56 @@ +/** + * + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnnarch.predefined; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class VectorQuantize extends PredefinedLayerDeclaration { + + private VectorQuantize(){ super(AllPredefinedLayers.VECTOR_QUANTIZE_NAME);} + + @Override + public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + if (layer.getInputTypes().isEmpty()){ + return layer.getInputTypes(); + } + else { + return Collections.singletonList(layer.getInputTypes().get(0)); + } + } + + @Override + public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { + errorIfInputSizeIsNotOne(inputTypes, layer); + } + + public static VectorQuantize create(){ + VectorQuantize declaration = new VectorQuantize(); + List parameters = new ArrayList<>(Arrays.asList( + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.NUM_EMBEDDINGS_NAME) + .constraints(Constraints.INTEGER, Constraints.POSITIVE) + .build(), + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.EMA_NAME) + .constraints(Constraints.BOOLEAN) + .defaultValue(true) + .build(), + new ParameterSymbol.Builder() + .name(AllPredefinedLayers.BETA_NAME) + .constraints(Constraints.NUMBER) + .defaultValue(0.25) + .build())); + declaration.setParameters(parameters); + return declaration; + } +} -- GitLab From 1555c9e50a3d05a0d038a8998409230ea7073e94 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:50:58 +0100 Subject: [PATCH 04/16] Add TypeErrors for input not flat and if shapes are not equal --- .../PredefinedLayerDeclaration.java | 56 +++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java index cfede4b..9d9c672 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java @@ -11,13 +11,11 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; import de.monticore.lang.monticar.ranges._ast.ASTRange; +import de.se_rwth.commons.Joiners; import de.se_rwth.commons.logging.Log; import org.jscience.mathematics.number.Rational; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.BinaryOperator; import java.util.stream.Stream; @@ -174,6 +172,43 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } + /** + * Check if Inputs of Layer have the same shape + */ + protected static enum HandlingSingleInputs { + ALLOWED, IGNORED, RESTRICTED + } + protected static void errorIfMultipleInputShapesAreNotEqual(List inputTypes, LayerSymbol layer, HandlingSingleInputs handling) { + if (inputTypes.size() == 1){ + if (handling == HandlingSingleInputs.IGNORED) { + Log.warn(layer.getName() + " layer has only one input stream. Layer can be removed.", layer.getSourcePosition()); + } else if (handling == HandlingSingleInputs.RESTRICTED){ + Log.error(layer.getName() + " layer has only one input stream.", layer.getSourcePosition()); + } + } + else if (inputTypes.size() > 1){ + List heightList = new ArrayList<>(); + List widthList = new ArrayList<>(); + List channelsList = new ArrayList<>(); + for (ArchTypeSymbol shape : inputTypes){ + heightList.add(shape.getHeight()); + widthList.add(shape.getWidth()); + channelsList.add(shape.getChannels()); + } + int countEqualHeights = (int)heightList.stream().distinct().count(); + int countEqualWidths = (int)widthList.stream().distinct().count(); + int countEqualNumberOfChannels = (int)channelsList.stream().distinct().count(); + if (countEqualHeights != 1 || countEqualWidths != 1 || countEqualNumberOfChannels != 1){ + Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " + + "Shapes of all input streams must be equal. " + + "Input heights: " + Joiners.COMMA.join(heightList) + ". " + + "Input widths: " + Joiners.COMMA.join(widthList) + ". " + + "Number of input channels: " + Joiners.COMMA.join(channelsList) + ". " + , layer.getSourcePosition()); + } + } + } + //check input for convolution and pooling protected static void errorIfInputSmallerThanKernel(List inputTypes, LayerSymbol layer) { if (!inputTypes.isEmpty()) { @@ -203,6 +238,19 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } + protected static void errorIfInputNotFlattened(List inputTypes, LayerSymbol layer) { + if (!inputTypes.isEmpty()) { + for (ArchTypeSymbol inputType : layer.getInputTypes()) { + int height = inputType.getHeight(); + int width = inputType.getWidth(); + if (height != 1 || width != 1) { + Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input." + + " Input layer must be flat, consider using a 'Flatten()' layer.", layer.getSourcePosition()); + } + } + } + } + //output type function for convolution and poolingee protected static List computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) { -- GitLab From cfdfdb923bb531b8e28c702ad845983015d56827 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:52:18 +0100 Subject: [PATCH 05/16] Add a list of layers that parameterize the loss --- .../SerialCompositeElementSymbol.java | 14 ++++++++- .../predefined/AllPredefinedLayers.java | 31 ++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java index 2586d82..3c87ca3 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/SerialCompositeElementSymbol.java @@ -9,12 +9,18 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; -import java.util.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; public class SerialCompositeElementSymbol extends CompositeElementSymbol { protected List> episodicSubNetworks = new ArrayList<>(new ArrayList<>()); protected boolean anyEpisodicLocalAdaptation = false; + protected boolean lossParameterizingElements = false; + protected void setElements(List elements) { ArchitectureElementSymbol previous = null; for (ArchitectureElementSymbol current : elements){ @@ -24,6 +30,10 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol { this.setAdaNet(true); this.setAdaLayer(current); } + if(AllPredefinedLayers.getLossParameterizingLayers().contains(current.getName())){ + // check if architecture has loss parametrizing layers + lossParameterizingElements = true; + } if(previous != null){ current.setInputElement(previous); previous.setOutputElement(current); @@ -72,6 +82,8 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol { public boolean getAnyEpisodicLocalAdaptation() { return anyEpisodicLocalAdaptation; } + public boolean hasLossParameterizingElements() { return lossParameterizingElements; } + @Override public void setInputElement(ArchitectureElementSymbol inputElement) { super.setInputElement(inputElement); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java index 4acd8df..6b8d97b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java @@ -10,12 +10,10 @@ package de.monticore.lang.monticar.cnnarch.predefined; import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol; -import jline.internal.Nullable; -import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.ArrayList; public class AllPredefinedLayers { @@ -38,7 +36,7 @@ public class AllPredefinedLayers { public static final String GET_NAME = "Get"; public static final String ADD_NAME = "Add"; public static final String CONCATENATE_NAME = "Concatenate"; - public static final String REPARAMETRIZE_NAME = "Reparametrize"; + public static final String REPARAMETERIZE_NAME = "Reparameterize"; public static final String FLATTEN_NAME = "Flatten"; public static final String ONE_HOT_NAME = "OneHot"; public static final String BEAMSEARCH_NAME = "BeamSearch"; @@ -63,6 +61,8 @@ public class AllPredefinedLayers { public static final String CUSTOM_LAYER = "CustomLayer"; public static final String CONVOLUTION3D_NAME = "Convolution3D"; public static final String UP_CONVOLUTION3D_NAME = "UpConvolution3D"; + public static final String VECTOR_QUANTIZE_NAME = "VectorQuantize"; + public static final String CONCAT_LABELS_NAME = "ConcatLabels"; public static final String AdaNet_Name = "AdaNet"; //AdaNet layer @@ -115,6 +115,10 @@ public class AllPredefinedLayers { public static final Integer DEFAULT_UNITS = 20; public static final String DEFAULT_BLOCK = "default_block"; + //VAE Parameters + public static final String NUM_EMBEDDINGS_NAME = "num_embeddings"; + public static final String EMA_NAME = "ema"; + //parameters LoadNetwork layer public static final String NETWORK_DIR_NAME = "networkDir"; public static final String NETWORK_PREFIX_NAME = "networkPrefix"; @@ -164,6 +168,11 @@ public class AllPredefinedLayers { public static final String RANDOM = "random"; public static final String REPLACE_OLDEST = "replace_oldest"; public static final String NO_REPLACEMENT = "no_replacement"; + + // String values for Reparametrization Layer + public static final String PDF_NAME = "pdf"; + public static final String PDF_NORMAL = "normal"; + public static final String PDF_DIRICHLET = "dirichlet"; //possible activation values for the querry network in the memory layer public static final String MEMORY_ACTIVATION_LINEAR = "linear"; @@ -173,6 +182,15 @@ public class AllPredefinedLayers { public static final String MEMORY_ACTIVATION_SOFTRELU = "softrelu"; public static final String MEMORY_ACTIVATION_SOFTSIGN = "softsign"; + + public static List getLossParameterizingLayers() { + return Arrays.asList( + REPARAMETERIZE_NAME, + VECTOR_QUANTIZE_NAME + ); + } + + //list with all predefined layers public static List createList(){ return Arrays.asList( @@ -216,7 +234,10 @@ public class AllPredefinedLayers { EpisodicMemory.create(), Convolution3D.create(), UpConvolution3D.create(), - AdaNet.create()); + AdaNet.create(), + Reparameterize.create(), + VectorQuantize.create(), + ConcatLabels.create()); } -- GitLab From 95c33b4f20cbb43ac848c4a4a55eb0e7614dd145 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:54:56 +0100 Subject: [PATCH 06/16] Add a reference variable to the auxiliary Architecture of the trainedArchitecture --- .../_symboltable/ArchitectureSymbol.java | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java index 2664bd5..5388913 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java @@ -11,18 +11,18 @@ package de.monticore.lang.monticar.cnnarch._symboltable; +import de.monticore.lang.monticar.cnnarch._cocos.CheckLayerPathParameter; import de.monticore.lang.monticar.cnnarch.helper.Utils; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables; -import de.monticore.lang.monticar.cnnarch._cocos.CheckLayerPathParameter; import de.monticore.symboltable.CommonScopeSpanningSymbol; import de.monticore.symboltable.Scope; import de.monticore.symboltable.Symbol; -import org.apache.commons.math3.ml.neuralnet.Network; -import java.lang.RuntimeException; -import java.lang.NullPointerException; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; public class ArchitectureSymbol extends CommonScopeSpanningSymbol { @@ -36,6 +36,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol { private String dataPath; private String weightsPath; private String componentName; + private ArchitectureSymbol auxiliaryArchitecture; private boolean AdaNet = false; //attribute for the path for custom python files private String customPyFilesPath; @@ -52,6 +53,14 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol { public void setLayerVariableDeclarations(List layerVariableDeclarations) { this.layerVariableDeclarations = layerVariableDeclarations; } + public void setAuxiliaryArchitecture(ArchitectureSymbol auxiliaryArchitecture){ + this.auxiliaryArchitecture = auxiliaryArchitecture; + } + + public ArchitectureSymbol getAuxiliaryArchitecture(){ + return this.auxiliaryArchitecture; + } + public String getAdaNetUtils(){return this.adaNetUtils;} public void setAdaNetUtils(String adaNetUtils){this.adaNetUtils=adaNetUtils;} public boolean containsAdaNet() { -- GitLab From 1704437fc5e8f6c5987a9f9aa97360df3f129edf Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:57:16 +0100 Subject: [PATCH 07/16] Constrain the possible pdfs of the Reparameterize Layer --- .../cnnarch/_symboltable/Constraints.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java index 8cdf99d..c39a67e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/Constraints.java @@ -393,6 +393,25 @@ public enum Constraints { public String msgString() { return "an axis between 0 and 1"; } + }, + REPARAMETERIZE_PDFS { + @Override + public boolean isValid(ArchSimpleExpressionSymbol exp) { + Optional optString= exp.getStringValue(); + if (optString.isPresent()){ + if ( optString.get().equals(AllPredefinedLayers.PDF_NORMAL)){ + //|| optString.get().equals(AllPredefinedLayers.PDF_DIRICHLET)){ + return true; + } + } + return false; + } + @Override + protected String msgString() { + return AllPredefinedLayers.PDF_NORMAL; //+ " or " + //+ AllPredefinedLayers.PDF_DIRICHLET; + + } }; protected abstract boolean isValid(ArchSimpleExpressionSymbol exp); -- GitLab From 8c3132f95d533626634abdca1fc5587e623d2b6f Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 18 Nov 2021 14:57:53 +0100 Subject: [PATCH 08/16] Add Testcases for VAE and VQVAE --- .../lang/monticar/cnnarch/SymtabTest.java | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java index 823634e..db169a3 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java @@ -56,4 +56,37 @@ public class SymtabTest extends AbstractSymtabTest { a.getArchitecture().getStreams().get(0).getOutputTypes(); } + @Test + public void testAdd(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + CNNArchCompilationUnitSymbol a = symTab.resolve( + "Add", + CNNArchCompilationUnitSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + a.getArchitecture().getStreams().get(0).getOutputTypes(); + } + + @Test + public void testVAE_Encoder(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + CNNArchCompilationUnitSymbol a = symTab.resolve( + "VAE_Encoder", + CNNArchCompilationUnitSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + a.getArchitecture().getStreams().get(0).getOutputTypes(); + } + + @Test + public void testVQVAE_Decoder(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + CNNArchCompilationUnitSymbol a = symTab.resolve( + "VQVAE_Decoder", + CNNArchCompilationUnitSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + a.getArchitecture().getStreams().get(0).getOutputTypes(); + } + } -- GitLab From 0260f324d6a55fbb2fcf4f971c0d3006ef14cc9d Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Wed, 1 Dec 2021 22:31:48 +0100 Subject: [PATCH 09/16] Remove EMA parameter, Rewrite Reparameterize, Remove ConcatLabels --- pom.xml | 2 +- .../PredefinedLayerDeclaration.java | 40 ++++++++++-------- .../predefined/AllPredefinedLayers.java | 4 +- .../cnnarch/predefined/ConcatLabels.java | 42 ------------------- .../cnnarch/predefined/Reparametrize.java | 22 ---------- .../cnnarch/predefined/VectorQuantize.java | 5 --- 6 files changed, 25 insertions(+), 90 deletions(-) delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java diff --git a/pom.xml b/pom.xml index 3a7d09b..0cd5c56 100644 --- a/pom.xml +++ b/pom.xml @@ -382,4 +382,4 @@ - \ No newline at end of file + diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java index 9d9c672..6f694c0 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/PredefinedLayerDeclaration.java @@ -172,12 +172,19 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } - /** - * Check if Inputs of Layer have the same shape - */ protected static enum HandlingSingleInputs { ALLOWED, IGNORED, RESTRICTED } + + /** + * Check if Inputs of Layer have the same shape + * @param inputTypes: List of input Types + * @param layer: curremt Layer + * @param handling: HandlingSingleInputs Enum Value, either ALLOWED, IGNORED or RESTRICTED + * ALLOWED will skip the Check if there are only one Input. + * IGNORED will print a message that the Layer is redundant, but the Input may still be passed + * RESTRICTED will throw an error. This indicates that the Layer does not allow + */ protected static void errorIfMultipleInputShapesAreNotEqual(List inputTypes, LayerSymbol layer, HandlingSingleInputs handling) { if (inputTypes.size() == 1){ if (handling == HandlingSingleInputs.IGNORED) { @@ -209,6 +216,19 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } + protected static void errorIfInputNotFlattened(List inputTypes, LayerSymbol layer) { + if (!inputTypes.isEmpty()) { + for (ArchTypeSymbol inputType : layer.getInputTypes()) { + int height = inputType.getHeight(); + int width = inputType.getWidth(); + if (height != 1 || width != 1) { + Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input." + + " Input layer must be flat, consider using a 'Flatten()' layer.", layer.getSourcePosition()); + } + } + } + } + //check input for convolution and pooling protected static void errorIfInputSmallerThanKernel(List inputTypes, LayerSymbol layer) { if (!inputTypes.isEmpty()) { @@ -238,20 +258,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol } } - protected static void errorIfInputNotFlattened(List inputTypes, LayerSymbol layer) { - if (!inputTypes.isEmpty()) { - for (ArchTypeSymbol inputType : layer.getInputTypes()) { - int height = inputType.getHeight(); - int width = inputType.getWidth(); - if (height != 1 || width != 1) { - Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input." + - " Input layer must be flat, consider using a 'Flatten()' layer.", layer.getSourcePosition()); - } - } - } - } - - //output type function for convolution and poolingee protected static List computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) { if (method.getIntTupleValue(AllPredefinedLayers.PADDING_NAME).isPresent()){ //If the Padding is given in Tuple diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java index 6b8d97b..883127d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/AllPredefinedLayers.java @@ -62,7 +62,6 @@ public class AllPredefinedLayers { public static final String CONVOLUTION3D_NAME = "Convolution3D"; public static final String UP_CONVOLUTION3D_NAME = "UpConvolution3D"; public static final String VECTOR_QUANTIZE_NAME = "VectorQuantize"; - public static final String CONCAT_LABELS_NAME = "ConcatLabels"; public static final String AdaNet_Name = "AdaNet"; //AdaNet layer @@ -236,8 +235,7 @@ public class AllPredefinedLayers { UpConvolution3D.create(), AdaNet.create(), Reparameterize.create(), - VectorQuantize.create(), - ConcatLabels.create()); + VectorQuantize.create()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java deleted file mode 100644 index 759e5d3..0000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/ConcatLabels.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * - * (c) https://github.com/MontiCore/monticore - * - * The license generally applicable for this project - * can be found under https://github.com/MontiCore/monticore. - */ -package de.monticore.lang.monticar.cnnarch.predefined; - -import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration; -import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public class ConcatLabels extends PredefinedLayerDeclaration { - private ConcatLabels(){ super(AllPredefinedLayers.CONCAT_LABELS_NAME);} - - @Override - public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - if (layer.getInputTypes().isEmpty()){ - return layer.getInputTypes(); - } - else { - return Collections.singletonList(layer.getInputTypes().get(0)); - } - } - - @Override - public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - errorIfInputSizeIsNotOne(inputTypes, layer); - } - - public static ConcatLabels create(){ - ConcatLabels declaration = new ConcatLabels(); - declaration.setParameters(new ArrayList<>()); - return declaration; - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java deleted file mode 100644 index 97c8634..0000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparametrize.java +++ /dev/null @@ -1,22 +0,0 @@ -/** - * - * (c) https://github.com/MontiCore/monticore - * - * The license generally applicable for this project - * can be found under https://github.com/MontiCore/monticore. - */ -package de.monticore.lang.monticar.cnnarch.predefined; - -import java.util.ArrayList; - -public class Reparametrize extends Add { - - private Reparametrize() { super(AllPredefinedLayers.REPARAMETRIZE_NAME); } - - public static Reparametrize create() { - Reparametrize layerDeclaration = new Reparametrize(); - layerDeclaration.setParameters(new ArrayList<>()); - return layerDeclaration; - } - -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java index 6054766..09c038f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java @@ -40,11 +40,6 @@ public class VectorQuantize extends PredefinedLayerDeclaration { .name(AllPredefinedLayers.NUM_EMBEDDINGS_NAME) .constraints(Constraints.INTEGER, Constraints.POSITIVE) .build(), - new ParameterSymbol.Builder() - .name(AllPredefinedLayers.EMA_NAME) - .constraints(Constraints.BOOLEAN) - .defaultValue(true) - .build(), new ParameterSymbol.Builder() .name(AllPredefinedLayers.BETA_NAME) .constraints(Constraints.NUMBER) -- GitLab From ae0a34c114fd62c9df93f82917bee6faf438b187 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Wed, 15 Dec 2021 16:36:06 +0100 Subject: [PATCH 10/16] Add types for Reparameterize and VectorQuantize --- .../cnnarch/predefined/Reparameterize.java | 18 +++++++++++------- .../cnnarch/predefined/VectorQuantize.java | 16 ++++++++++------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java index 5308a3a..6adef11 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java @@ -20,19 +20,23 @@ public class Reparameterize extends PredefinedLayerDeclaration { @Override public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - if (layer.getInputTypes().isEmpty()){ - return layer.getInputTypes(); - } - else { - return Collections.singletonList(layer.getInputTypes().get(0)); - } + int channels = layer.getInputTypes().get(0).getChannels(); + int height = layer.getInputTypes().get(0).getHeight(); + int width = layer.getInputTypes().get(0).getWidth(); + + return Collections.singletonList(new ArchTypeSymbol.Builder() + .channels(channels) + .height(height) + .width(width) + .elementType("-oo","oo") + .build()); } @Override public void checkInput(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { errorIfInputIsEmpty(inputTypes,layer); errorIfInputNotFlattened(inputTypes,layer); - errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.ALLOWED); + errorIfMultipleInputShapesAreNotEqual(inputTypes, layer, HandlingSingleInputs.RESTRICTED); } public static Reparameterize create() { diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java index 09c038f..8e2fdaf 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java @@ -20,12 +20,16 @@ public class VectorQuantize extends PredefinedLayerDeclaration { @Override public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - if (layer.getInputTypes().isEmpty()){ - return layer.getInputTypes(); - } - else { - return Collections.singletonList(layer.getInputTypes().get(0)); - } + int channels = layer.getInputTypes().get(0).getChannels(); + int height = layer.getInputTypes().get(0).getHeight(); + int width = layer.getInputTypes().get(0).getWidth(); + + return Collections.singletonList(new ArchTypeSymbol.Builder() + .channels(channels) + .height(height) + .width(width) + .elementType("-oo","oo") + .build()); } @Override -- GitLab From 4e3c896bfa4c44e48f85c50f05bf44647da22ed9 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 20 Jan 2022 13:53:55 +0100 Subject: [PATCH 11/16] Update Types --- .../monticar/cnnarch/predefined/Reparameterize.java | 12 ++++-------- .../monticar/cnnarch/predefined/VectorQuantize.java | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java index 6adef11..79fc8e3 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/Reparameterize.java @@ -20,15 +20,11 @@ public class Reparameterize extends PredefinedLayerDeclaration { @Override public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - int channels = layer.getInputTypes().get(0).getChannels(); - int height = layer.getInputTypes().get(0).getHeight(); - int width = layer.getInputTypes().get(0).getWidth(); - return Collections.singletonList(new ArchTypeSymbol.Builder() - .channels(channels) - .height(height) - .width(width) - .elementType("-oo","oo") + .channels(layer.getInputTypes().get(0).getChannels()) + .height(layer.getInputTypes().get(0).getHeight()) + .width(layer.getInputTypes().get(0).getWidth()) + .elementType("-oo", "oo") .build()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java index 8e2fdaf..7903e2f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/predefined/VectorQuantize.java @@ -20,15 +20,11 @@ public class VectorQuantize extends PredefinedLayerDeclaration { @Override public List computeOutputTypes(List inputTypes, LayerSymbol layer, VariableSymbol.Member member) { - int channels = layer.getInputTypes().get(0).getChannels(); - int height = layer.getInputTypes().get(0).getHeight(); - int width = layer.getInputTypes().get(0).getWidth(); - return Collections.singletonList(new ArchTypeSymbol.Builder() - .channels(channels) - .height(height) - .width(width) - .elementType("-oo","oo") + .channels(layer.getInputTypes().get(0).getChannels()) + .height(layer.getInputTypes().get(0).getHeight()) + .width(layer.getInputTypes().get(0).getWidth()) + .elementType("-oo", "oo") .build()); } -- GitLab From b8ef1259cfc0a0213d1d02fb4101905a33a31229 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 20 Jan 2022 13:54:24 +0100 Subject: [PATCH 12/16] Modify VAE Test --- .../lang/monticar/cnnarch/SymtabTest.java | 24 ++----------------- .../resources/architectures/VAELayers.cnna | 20 ++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 src/test/resources/architectures/VAELayers.cnna diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java index db169a3..9fa7996 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java @@ -56,37 +56,17 @@ public class SymtabTest extends AbstractSymtabTest { a.getArchitecture().getStreams().get(0).getOutputTypes(); } - @Test - public void testAdd(){ - Scope symTab = createSymTab("src/test/resources/architectures"); - CNNArchCompilationUnitSymbol a = symTab.resolve( - "Add", - CNNArchCompilationUnitSymbol.KIND).orElse(null); - assertNotNull(a); - a.resolve(); - a.getArchitecture().getStreams().get(0).getOutputTypes(); - } @Test - public void testVAE_Encoder(){ + public void testVAELayers(){ Scope symTab = createSymTab("src/test/resources/architectures"); CNNArchCompilationUnitSymbol a = symTab.resolve( - "VAE_Encoder", + "VAELayers", CNNArchCompilationUnitSymbol.KIND).orElse(null); assertNotNull(a); a.resolve(); a.getArchitecture().getStreams().get(0).getOutputTypes(); } - @Test - public void testVQVAE_Decoder(){ - Scope symTab = createSymTab("src/test/resources/architectures"); - CNNArchCompilationUnitSymbol a = symTab.resolve( - "VQVAE_Decoder", - CNNArchCompilationUnitSymbol.KIND).orElse(null); - assertNotNull(a); - a.resolve(); - a.getArchitecture().getStreams().get(0).getOutputTypes(); - } } diff --git a/src/test/resources/architectures/VAELayers.cnna b/src/test/resources/architectures/VAELayers.cnna new file mode 100644 index 0000000..bc525d9 --- /dev/null +++ b/src/test/resources/architectures/VAELayers.cnna @@ -0,0 +1,20 @@ +architecture VAELayers{ + def input Q(-oo:oo)^{1,28,28} data + def output Q(-oo:oo)^{1,28,28} res + + data -> + VectorQuantize(num_embeddings=64) + FullyConnected(units=400) -> + FullyConnected(units=4) -> + Split(n=2) -> + ( + [0] + | + [1] + ) -> + Reparameterize() -> + FullyConnected(units=400) -> + FullyConnected(units=784) -> + Reshape(shape=(1,28,28)) -> + res; +} \ No newline at end of file -- GitLab From 210f576e4d824d8284367faf87266a133b993ddc Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 27 Jan 2022 14:37:47 +0100 Subject: [PATCH 13/16] Fix VAE Test --- src/test/resources/architectures/VAELayers.cnna | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/resources/architectures/VAELayers.cnna b/src/test/resources/architectures/VAELayers.cnna index bc525d9..be9a889 100644 --- a/src/test/resources/architectures/VAELayers.cnna +++ b/src/test/resources/architectures/VAELayers.cnna @@ -3,7 +3,7 @@ architecture VAELayers{ def output Q(-oo:oo)^{1,28,28} res data -> - VectorQuantize(num_embeddings=64) + VectorQuantize(num_embeddings=64,beta=0.25) -> FullyConnected(units=400) -> FullyConnected(units=4) -> Split(n=2) -> @@ -12,7 +12,7 @@ architecture VAELayers{ | [1] ) -> - Reparameterize() -> + Reparameterize(pdf=normal) -> FullyConnected(units=400) -> FullyConnected(units=784) -> Reshape(shape=(1,28,28)) -> -- GitLab From 353d9d8045b6f10e3ef5548057cf76009ccd9d91 Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 27 Jan 2022 14:38:34 +0100 Subject: [PATCH 14/16] Fix VAE Test --- src/test/resources/architectures/VAELayers.cnna | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/resources/architectures/VAELayers.cnna b/src/test/resources/architectures/VAELayers.cnna index be9a889..c248909 100644 --- a/src/test/resources/architectures/VAELayers.cnna +++ b/src/test/resources/architectures/VAELayers.cnna @@ -12,7 +12,7 @@ architecture VAELayers{ | [1] ) -> - Reparameterize(pdf=normal) -> + Reparameterize(pdf="normal") -> FullyConnected(units=400) -> FullyConnected(units=784) -> Reshape(shape=(1,28,28)) -> -- GitLab From a699519ffd35cbb9fb0d8cede842909f840ab15d Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 27 Jan 2022 15:12:58 +0100 Subject: [PATCH 15/16] Add layer descriptions to README.md --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index 3e6139f..1562e0d 100644 --- a/README.md +++ b/README.md @@ -802,6 +802,23 @@ All predefined methods start with a capital letter and all constructed methods h * **padding** ({"valid", "same", "no_loss"}, optional, default="same3d"): One of "valid3d", "same3d" or "simple3d". "valid" means no padding. "same" results in padding the input such that the output has the same length as the original input divided by the stride (rounded up). "simple3d" results constant padding of size 1 (same as (1,1,1). UpConvolution3D also accepts tuples of form (height, widht, depth) as input. * **no_bias** (boolean, optional, default=false): Whether to disable the bias parameter. + +* **Reparameterize(pdf="normal")** + + Must be used inorder to model VAEs, $\beta$-VAEs and Conditional VAEs. Applies the Reparameterization Trick and samples a Code from the approximating Distribution. + + * **pdf** ({"normal"}, optional, default="normal"): + * normal: Takes in 2 Datastreams and applies the Reparameterization Trick for a normal distribution. + + +* **VectorQuantize(num_embeddings,beta=0.25)** + + Must be used inorder to model VQ-VAEs. Use this layer to quantize the pixels of incoming feature maps with a vector from the codebook. + + * **num_embeddings** (integer > 0, required): Number of vectors within the codebook. + * **beta**: (float, optional, default=0.25): Commitment cost factor that weights the commitment term of the VQ-VAE Loss function. "We found the resulting algorithm to be + quite robust to β, as the results did not vary for values of β ranging from 0.1 to 2.0" [van den Oord et al. 2017] + * **EpisodicMemory(replayMemoryStoreProb=1, maxStoredSamples=-1, memoryReplacementStrategy="replace_oldest", useReplay=true, replayInterval, replayBatchSize=-1, replaySteps, replayGradientSteps=1, useLocalAdaption=true, localAdaptionGradientSteps=1, localAdaptionK=1, queryNetDir=-1, queryNetPrefix=-1, queryNetNumInputs=1)** -- GitLab From 6dfea1de983ffd9280e24f1ae570045ee03690fd Mon Sep 17 00:00:00 2001 From: "@celik-furkan" Date: Thu, 27 Jan 2022 15:15:00 +0100 Subject: [PATCH 16/16] Add layer descriptions to README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1562e0d..f3f1d77 100644 --- a/README.md +++ b/README.md @@ -805,7 +805,7 @@ All predefined methods start with a capital letter and all constructed methods h * **Reparameterize(pdf="normal")** - Must be used inorder to model VAEs, $\beta$-VAEs and Conditional VAEs. Applies the Reparameterization Trick and samples a Code from the approximating Distribution. + Must be used inorder to model VAEs, β-VAEs and Conditional VAEs. Applies the Reparameterization Trick and samples a Code from the approximating Distribution. * **pdf** ({"normal"}, optional, default="normal"): * normal: Takes in 2 Datastreams and applies the Reparameterization Trick for a normal distribution. -- GitLab