diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3b4055745a70669d3cc3ae582dba115c1fc12a95..c8121cf150898070aff59c0e580496edc7f7ec51 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -19,7 +19,7 @@ git masterJobLinux:
integrationMXNetJobLinux:
stage: linux
- image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest
+ image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.5
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationMXNetTest
@@ -33,7 +33,7 @@ integrationCaffe2JobLinux:
integrationGluonJobLinux:
stage: linux
- image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest
+ image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.5
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationGluonTest
diff --git a/README.md b/README.md
index 5e5f7d4655a6ee7937c3335abdecdeab6cd0ca98..4524b85c793d6e69e6f58e983ec59a0100494e88 100644
--- a/README.md
+++ b/README.md
@@ -17,11 +17,11 @@ See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/E
* Deep learning backend:
* MXNet
* training - generated is Python code. Required is Python 2.7 or higher, Python packages `h5py`, `mxnet` (for training on CPU) or e.g. `mxnet-cu75` for CUDA 7.5 (for training on GPU with CUDA, concrete package should be selected according to CUDA version). Follow [official instructions on MXNet site](https://mxnet.incubator.apache.org/install/index.html?platform=Linux&language=Python&processor=CPU)
- * prediction - generated code is C++. Install MXNet using [official instructions on MXNet site](https://mxnet.incubator.apache.org) for C++.
+ * prediction - generated code is C++.
* Caffe2
* training - generated is Python code. Follow [ official instructions on Caffe2 site ](https://caffe2.ai/docs/getting-started.html?platform=ubuntu&configuration=prebuilt)
- * See the scripts under Installation for better instructions, as an old caffe vversion is used that needs special considerations.
+ * See the scripts under Installation for better instructions, as an old caffe version is used that needs special considerations.
* Gluon
@@ -30,16 +30,14 @@ See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/E
* prediction - generated code is C++.
## Installation
-The two bash scripts found under [installation scripts](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/tree/tensorflow_group/src/main/resources/installation_scripts)
-should build and install all prerequisits for all backends as of 26.09.2019.
-Note that the installation may take some time (hours) and you will need some disk space (> 60GB) for all backends. Also enough RAM or a big
-enough swapspace is advisable (>10GB) for the installation of the cpp part of tensorflow. This scripts were tested with a completly clean Ubuntu 16.04,
+A new bash script for the MXNet/Gluon backend can be found under [installation scripts](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/-/tree/master/src/main/resources/installation_scripts).
+It changes the MXNet C++ installation: the full C++ API is now installed instead of the reduced C API. The script installs all dependencies for both Python and C++ as of 26.08.2020 (see the invocation sketch below).
+Additionally, a similar Docker script used for the GitLab CI pipeline can be found in the gluon subfolder at [Docker images](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/-/tree/master/src/test/resources/docker).
+The other two bash scripts in the installation_scripts folder are outdated but may be consulted for installation guidelines for the other backends.
+Note that the installation may take some time (hours); enough RAM or a big enough swap space (>10GB) is advisable. These scripts were tested with a completely clean Ubuntu 16.04,
without system updates installed. Using another Ubuntu version or installing other stuff, system updates included might/ have caused problems.
If you want to install the backends with CUDA GPU support(only MXNet/Gluon and Tensorflow, the used caffe2 version does not work with GPU support anymore),
-you have to install CUDA 10.0(!!), CUDNN and NCCL (Obtainable from the nvidai webpage. You can follow their instructions.) inbetween the two scripts.
-Furthermore you will have to change the pip commands for mxnet and tensorflow to the respective commented out parts.
-Also docker images for the cpu version of each backend are provided at [Docker images](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/tree/tensorflow_group/src/test/resources/docker),
-though some of these might be outdated.
+you have to install CUDA 10.0 (MXNet/Gluon also works with newer and possibly older versions), cuDNN and NCCL (obtainable from the NVIDIA webpage).
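+
+A minimal invocation sketch for the new script (an assumption about how you call it, not an official procedure; the script expects `armadillo-9.600.6.zip` in its working directory):
+
+```bash
+# run the mxnet/gluon installation script shipped in this repository
+cd src/main/resources/installation_scripts
+bash mxnet_gluon_installation_script.sh
+```
+Steps that write to `/usr/include` or install system-wide libraries may require root privileges.
+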
### HowTo
1. Define a EMADL component containing architecture of a neural network and save it in a `.emadl` file. For more information on architecture language please refer to [CNNArchLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNArchLang). An example of NN architecture:
diff --git a/pom.xml b/pom.xml
index 06b0cb4852aee9d48bb2af511d9781c4f302ddce..3011a371998cef7a22e5fdc9dd54216fb8f4fa28 100644
--- a/pom.xml
+++ b/pom.xml
@@ -9,19 +9,19 @@
de.monticore.lang.monticar
embedded-montiarc-emadl-generator
- 0.4.0
+ 0.4.1
- 0.2.11-SNAPSHOT
- 0.3.10-SNAPSHOT
- 0.0.6-SNAPSHOT
+ 0.2.12-SNAPSHOT
+ 0.3.12-SNAPSHOT
+ 0.0.7-SNAPSHOT
0.2.17-SNAPSHOT
0.2.14-SNAPSHOT
- 0.2.11-SNAPSHOT
+ 0.2.12-SNAPSHOT
0.1.0-SNAPSHOT
0.0.19-SNAPSHOT
0.1.6
diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java
index dbc5cf31cd36bd05fb59f05c73d329ef4310a350..5a35d39f647c4565a8504c1c7141748effd2704c 100644
--- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java
+++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java
@@ -7,6 +7,7 @@ import de.monticore.lang.embeddedmontiarc.LogConfig;
import de.monticore.lang.embeddedmontiarc.helper.ConstantPortHelper;
import de.monticore.lang.monticar.emadl._symboltable.EMADLLanguage;
import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathTagSchema;
+import de.monticore.lang.monticar.emadl.tagging.dltag.LayerPathParameterTagSchema;
import de.monticore.lang.monticar.enumlang._symboltable.EnumLangLanguage;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.optimization.ThreadingOptimizer;
@@ -41,6 +42,7 @@ public class EMADLAbstractSymtab {
TagThresholdTagSchema.registerTagTypes(tagging);
TagDelayTagSchema.registerTagTypes(tagging);
DataPathTagSchema.registerTagTypes(tagging);
+ LayerPathParameterTagSchema.registerTagTypes(tagging);
return tagging;
}
diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java
index 8fecb4471919df7bfc57d6daa61f40701cbcf97b..c3a465ba0cff1b9d3e6a6979326e235675c9c0e0 100644
--- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java
+++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java
@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnntrain._symboltable.PreprocessingComponentSy
import de.monticore.lang.monticar.emadl._cocos.DataPathCocos;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathSymbol;
+import de.monticore.lang.monticar.emadl.tagging.dltag.LayerPathParameterSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
@@ -394,7 +395,7 @@ public class EMADLGenerator {
// TODO: Replace warinings with errors, until then use this method
stopGeneratorIfWarning();
- Log.warn("Tagging info for symbol was found, ignoring data_paths.txt: " + dataPath);
+ Log.warn("Tagging info for DataPath symbol was found, ignoring data_paths.txt: " + dataPath);
}
else {
Path dataPathDefinition = Paths.get(getModelsPath(), "data_paths.txt");
@@ -426,6 +427,37 @@ public class EMADLGenerator {
return weightsPath;
}
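+
+ // Collects LayerPathParameter tags (tag id -> path) for the given component instance.
+ // Tags on the corresponding instantiation in the enclosing component take precedence over tags on the component type itself.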
+ protected HashMap<String, String> getLayerPathParameterTags(TaggingResolver taggingResolver, EMAComponentSymbol component, EMAComponentInstanceSymbol instance){
+ List<TagSymbol> instanceTags = new LinkedList<>();
+
+ boolean isChildComponent = instance.getEnclosingComponent().isPresent();
+
+ if (isChildComponent) {
+ // get all instantiated components of parent
+ List<EMAComponentInstantiationSymbol> instantiationSymbols = (List<EMAComponentInstantiationSymbol>) instance
+ .getEnclosingComponent().get().getComponentType().getReferencedSymbol().getSubComponents();
+
+ // filter corresponding instantiation of instance and add tags
+ instantiationSymbols.stream().filter(e -> e.getName().equals(instance.getName())).findFirst()
+ .ifPresent(symbol -> instanceTags.addAll(taggingResolver.getTags(symbol, LayerPathParameterSymbol.KIND)));
+ }
+
+ List<TagSymbol> tags = !instanceTags.isEmpty() ? instanceTags
+ : (List<TagSymbol>) taggingResolver.getTags(component, LayerPathParameterSymbol.KIND);
+
+ HashMap<String, String> layerPathParameterTags = new HashMap<>();
+ if (!tags.isEmpty()) {
+ for(TagSymbol tag: tags) {
+ LayerPathParameterSymbol layerPathParameterSymbol = (LayerPathParameterSymbol) tag;
+ layerPathParameterTags.put(layerPathParameterSymbol.getId(), layerPathParameterSymbol.getPath());
+ }
+ // TODO: Replace warnings with errors, until then use this method
+ stopGeneratorIfWarning();
+ Log.warn("Tagging info for LayerPathParameter symbols was found.");
+ }
+ return layerPathParameterTags;
+ }
+
protected void generateComponent(List<FileContent> fileContents,
Set<EMAComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
@@ -448,8 +480,10 @@ public class EMADLGenerator {
cnnArchGenerator.check(architecture.get());
String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
String wPath = getWeightsPath(EMAComponentSymbol, componentInstanceSymbol);
+ HashMap<String, String> layerPathParameterTags = getLayerPathParameterTags(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
architecture.get().setDataPath(dPath);
architecture.get().setWeightsPath(wPath);
+ architecture.get().processLayerPathParameterTags(layerPathParameterTags);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
diff --git a/src/main/resources/installation_scripts/install_after_cuda b/src/main/resources/installation_scripts/legacy_script_all_backends/install_after_cuda
similarity index 100%
rename from src/main/resources/installation_scripts/install_after_cuda
rename to src/main/resources/installation_scripts/legacy_script_all_backends/install_after_cuda
diff --git a/src/main/resources/installation_scripts/install_before_cuda b/src/main/resources/installation_scripts/legacy_script_all_backends/install_before_cuda
similarity index 100%
rename from src/main/resources/installation_scripts/install_before_cuda
rename to src/main/resources/installation_scripts/legacy_script_all_backends/install_before_cuda
diff --git a/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh b/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a13eb17e8baf3a3d4850efc94096267882329931
--- /dev/null
+++ b/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh
@@ -0,0 +1,33 @@
+sudo apt-get update -y
+sudo apt-get install -y build-essential git openjdk-8-jdk maven ninja-build ccache libopenblas-dev libblas-dev \
+ liblapack-dev libopencv-dev libarmadillo-dev cmake python2.7 python-dev \
+ python-numpy python3-pip swig unzip libboost-all-dev
+
+sudo update-alternatives --config java
+
+pip3 install --user --upgrade "cmake>=3.13.2"
+
+wget https://bootstrap.pypa.io/get-pip.py
+python get-pip.py
+pip install --user h5py matplotlib numpy==1.16.5 mxnet==1.5.1.post0
+ #mxnet is pinned: the newest version (currently v1.7.0) cannot be run with python2, the current standard of the EMADL2CPP generator,
+ #as the numpy version it needs is no longer supported for python2 (python2 itself will no longer be supported).
+ #Furthermore, not all tests work with mxnet v1.6.0. And with plain v1.5.1 you can't compile against the libmxnet.so needed for
+ #the cpp prediction part; the same holds for 1.6.0, but not for 1.5.1.post0 or versions later than 1.6.0 (1.7.0), as some compression
+ #was applied to this library and later dropped again by the mxnet developers.
+ #You could alternatively use python 3.6 instead of 2.7; then you can also use the newest numpy version.
+ #Note that you then have to set the PYTHON_PATH accordingly, or specify the python path for all applications in their build scripts;
+ #the tests currently only run with the python specified in the PYTHON_PATH.
+ #If you want to use mxnet with CUDA, install e.g. mxnet-cu100 for CUDA 10.0 (if v1.5.1 is the newest version, specifying post0 is not
+ #necessary; otherwise there is no guarantee for the numpy dependency, see above). Of course you then have to install CUDA and cuDNN beforehand.
+
+git clone --recursive https://github.com/apache/incubator-mxnet.git mxnet
+cd mxnet && git checkout tags/1.5.0 && git submodule update --recursive --init
+mkdir build && cd build && cmake -DUSE_CPP_PACKAGE=1 -DUSE_CUDA=0 -GNinja .. && ninja -v
+cd .. && cp -r include/mxnet /usr/include/mxnet && cp -r cpp-package/include/mxnet-cpp /usr/include/ && cp -r 3rdparty/tvm/nnvm/include/nnvm /usr/include/ && cp -r 3rdparty/dmlc-core/include/dmlc /usr/include/
+cd ..
+
+#you have to have armadillo-9.600.6.zip in your current folder
+unzip armadillo-9.600.6.zip -d .
+cd armadillo-9.600.6 && cmake . && make && make install
+
+mkdir -p /root/.config/matplotlib
+echo "backend : Agg" > /root/.config/matplotlib/matplotlibrc
\ No newline at end of file
diff --git a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java
index 6c988f46bf6bd1d9aa26316b1fbea73de47629f2..e14e7ed20ed5fd93bf855618ada9f884b2b8ec80 100644
--- a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java
+++ b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java
@@ -42,7 +42,7 @@ public class GenerationTest extends AbstractSymtabTest {
"cifar10_cifar10Classifier.cpp",
"cifar10_cifar10Classifier.h",
"CNNCreator_cifar10_cifar10Classifier_net.py",
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"CNNPredictor_cifar10_cifar10Classifier_net.h",
"cifar10_cifar10Classifier_net.h",
"CNNTranslator.h",
@@ -106,6 +106,13 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().isEmpty());
}
+ @Test
+ public void testEpisodicMemorySimpleGeneration() throws IOException, TemplateException {
+ Log.getFindings().clear();
+ String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n"};
+ EMADLGeneratorCli.main(args);
+ }
+
@Test
public void testMultipleInstances() throws IOException, TemplateException {
try {
@@ -175,7 +182,7 @@ public class GenerationTest extends AbstractSymtabTest {
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon"),
Arrays.asList(
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"CNNNet_mnist_mnistClassifier_net.py",
"mnist_mnistClassifier.cpp",
"mnist_mnistClassifier.h",
@@ -183,7 +190,6 @@ public class GenerationTest extends AbstractSymtabTest {
"CNNPredictor_mnist_mnistClassifier_net.h",
"CNNDataLoader_mnist_mnistClassifier_net.py",
"CNNSupervisedTrainer_mnist_mnistClassifier_net.py",
- "mnist_mnistClassifier_net.h",
"HelperA.h",
"CNNTranslator.h",
"mnist_mnistClassifier_calculateClass.h",
@@ -215,7 +221,7 @@ public class GenerationTest extends AbstractSymtabTest {
"cartpole_master_dqn.h",
"cartpole_master_policy.h",
"CMakeLists.txt",
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"CNNCreator_cartpole_master_dqn.py",
"CNNNet_cartpole_master_dqn.py",
"CNNPredictor_cartpole_master_dqn.h",
@@ -260,7 +266,7 @@ public class GenerationTest extends AbstractSymtabTest {
"mountaincar_master.h",
"mountaincar_master_actor.h",
"CMakeLists.txt",
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"CNNCreator_mountaincar_master_actor.py",
"CNNNet_mountaincar_master_actor.py",
"CNNPredictor_mountaincar_master_actor.h",
@@ -300,9 +306,6 @@ public class GenerationTest extends AbstractSymtabTest {
"CNNTrainer_defaultGAN_defaultGANConnector_predictor.py",
"defaultGAN_defaultGANConnector.cpp",
"defaultGAN_defaultGANConnector.h",
- "defaultGAN_defaultGANConnector_predictor.h",
- "defaultGAN_defaultGANConnector.cpp",
- "defaultGAN_defaultGANConnector.h",
"defaultGAN_defaultGANConnector_predictor.h"
)
);
@@ -361,7 +364,7 @@ public class GenerationTest extends AbstractSymtabTest {
EMADLGeneratorCli.main(args);
assertEquals(Log.getFindings().size(), 1);
assertEquals(Log.getFindings().get(0).toString(),
- "Tagging info for symbol was found, ignoring data_paths.txt: src/test/resources/models");
+ "Tagging info for DataPath symbol was found, ignoring data_paths.txt: src/test/resources/models");
assertTrue(Log.getErrorCount() == 0);
}
diff --git a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java
index 1a64af83fbff4ee85b56691f6274218c6f94595e..91afa7469ebb5a9c8e94dd0e473d52cbd61dfede 100644
--- a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java
+++ b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java
@@ -70,6 +70,16 @@ public class IntegrationGluonTest extends IntegrationTest {
assertTrue(Log.getFindings().isEmpty());
}
+ @Test
+ public void testEpisodicMemorySimple() {
+ Log.getFindings().clear();
+
+ deleteHashFile(Paths.get("./target/generated-sources-emadl/episodicMemorySimple/episodicMemorySimple.training_hash"));
+
+ String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON"};
+ EMADLGeneratorCli.main(args);
+ }
+
@Test
public void testGluonPreprocessingWithSupervised() {
Log.getFindings().clear();
diff --git a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java
index 4334b7ea3f16c1a883d4177b6a160369a8d7c2a6..6e18f1455e6edc06ca637ac53c3d40a67f0e4d4d 100644
--- a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java
+++ b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java
@@ -27,7 +27,7 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest {
Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs"),
Arrays.asList(
"CMakeLists.txt",
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"torcs_agent_torcsAgent.cpp",
"torcs_agent_torcsAgent.h",
"torcs_agent_torcsAgent_dqn.h",
@@ -82,7 +82,7 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest {
Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs_td3"),
Arrays.asList(
"CMakeLists.txt",
- "CNNBufferFile.h",
+ "CNNModelLoader.h",
"torcs_agent_torcsAgent.cpp",
"torcs_agent_torcsAgent.h",
"torcs_agent_torcsAgent_actor.h",
diff --git a/src/test/resources/docker/mxnet/Dockerfile b/src/test/resources/docker/mxnet/Dockerfile
index e374b94a22f412f2c85a73c88d2f681c91e190e2..7c5945181003075b454ec478a3367be2a97aedc4 100644
--- a/src/test/resources/docker/mxnet/Dockerfile
+++ b/src/test/resources/docker/mxnet/Dockerfile
@@ -1,19 +1,22 @@
FROM maven:3-jdk-8
-RUN apt-get update && \
- apt-get install -y --no-install-recommends \
- git \
- libgtk2.0-dev \
- python-subprocess32 \
- python-tk \
- wget python gcc \
- build-essential cmake \
- liblapack-dev libblas-dev libboost-dev libarmadillo-dev && \
- rm -rf /var/lib/apt/lists/*
-RUN git clone https://github.com/apache/incubator-mxnet.git mxnet-source && \
- cd mxnet-source && git checkout tags/1.4.0 && cd .. && \
- cp -r mxnet-source/include/mxnet /usr/include/mxnet && \
- rm -r mxnet-source
+RUN apt-get update
+RUN apt-get install -y build-essential git ninja-build ccache libopenblas-dev libblas-dev liblapack-dev libopencv-dev libarmadillo-dev cmake python2.7 python-dev python-numpy python3-pip python3-pip swig unzip libboost-all-dev
+
+RUN pip3 install --user --upgrade "cmake>=3.13.2"
+
RUN wget https://bootstrap.pypa.io/get-pip.py
RUN python get-pip.py
-RUN pip install mxnet h5py opencv-python matplotlib
+RUN pip install --user h5py matplotlib numpy==1.16.5 mxnet==1.5.1.post0
+
+RUN git clone --recursive https://github.com/apache/incubator-mxnet.git mxnet
+RUN cd mxnet && git checkout tags/1.5.0 && git submodule update --recursive --init
+RUN cd mxnet && mkdir build && cd build && cmake -DUSE_CPP_PACKAGE=1 -DUSE_CUDA=0 -GNinja .. && ninja -v
+RUN cd mxnet && cp -r include/mxnet /usr/include/mxnet && cp -r cpp-package/include/mxnet-cpp /usr/include/ && cp -r 3rdparty/tvm/nnvm/include/nnvm /usr/include/ && cp -r 3rdparty/dmlc-core/include/dmlc /usr/include/
+
+ADD armadillo-9.600.6.zip /root/armadillo.zip
+RUN unzip /root/armadillo.zip -d /root/armadillo
+RUN cd /root/armadillo/armadillo-9.600.6 && cmake . && make && make install
+
+RUN mkdir -p /root/.config/matplotlib
+RUN echo "backend : Agg" > /root/.config/matplotlib/matplotlibrc
diff --git a/src/test/resources/docker/mxnet/armadillo-9.600.6.zip b/src/test/resources/docker/mxnet/armadillo-9.600.6.zip
new file mode 100644
index 0000000000000000000000000000000000000000..e053f3bc1670da20fb0a40eea7ee27a8ede111eb
Binary files /dev/null and b/src/test/resources/docker/mxnet/armadillo-9.600.6.zip differ
diff --git a/src/test/resources/models/episodicMemorySimple/Network.cnnt b/src/test/resources/models/episodicMemorySimple/Network.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..af61f45850c4c3529e2f3f0aac40bce4ed862389
--- /dev/null
+++ b/src/test/resources/models/episodicMemorySimple/Network.cnnt
@@ -0,0 +1,13 @@
+/* (c) https://github.com/MontiCore/monticore */
+configuration Network{
+ num_epoch:1
+ batch_size:5
+ normalize:false
+ context:cpu
+ load_checkpoint:false
+ loss:cross_entropy
+ optimizer:adam{
+ learning_rate:0.00003
+ weight_decay:0.01
+ }
+}
diff --git a/src/test/resources/models/episodicMemorySimple/Network.emadl b/src/test/resources/models/episodicMemorySimple/Network.emadl
new file mode 100644
index 0000000000000000000000000000000000000000..17341b277eaebbd1e4714fe2ec222f8764a66281
--- /dev/null
+++ b/src/test/resources/models/episodicMemorySimple/Network.emadl
@@ -0,0 +1,16 @@
+/* (c) https://github.com/MontiCore/monticore */
+package episodicMemorySimple;
+
+component Network{
+ ports in Z(0:oo)^{10} data,
+ out Q(0:1)^{33} softmax;
+
+ implementation CNN {
+ data ->
+ EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, replayMemoryStoreProb=0.5, localAdaptionGradientSteps=30, maxStoredSamples=-1, localAdaptionK=32, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) ->
+ LoadNetwork(networkDir="tag:simple", networkPrefix="simple_embedding-", numInputs=1, outputShape=(1,768)) ->
+ FullyConnected(units=33) ->
+ Softmax() ->
+ softmax;
+ }
+}
diff --git a/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag b/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag
new file mode 100644
index 0000000000000000000000000000000000000000..0b5b8796a5b2257d0264e17928a0aa55b2ec6828
--- /dev/null
+++ b/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag
@@ -0,0 +1,8 @@
+/* (c) https://github.com/MontiCore/monticore */
+package episodicMemorySimple;
+conforms to dltag.DataPathTagSchema, dltag.LayerPathParameterTagSchema;
+
+tags episodic {
+tag Network with DataPath = {path = src/test/resources/training_data/episodicMemorySimple, type = HDF5};
+tag Network with LayerPathParameter = {path = src/test/resources/pretrained/episodicMemorySimple, id = simple};
+}
diff --git a/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params
new file mode 100644
index 0000000000000000000000000000000000000000..3288444a7f261cd83beb3be9c9655c2e5a376e3f
Binary files /dev/null and b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params differ
diff --git a/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json
new file mode 100644
index 0000000000000000000000000000000000000000..31c823adc6c18b5861cc3a3bd99ae6ad078b69ef
--- /dev/null
+++ b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json
@@ -0,0 +1,18 @@
+{
+ "nodes": [
+ {
+ "op": "null",
+ "name": "data",
+ "inputs": []
+ },
+ {
+ "op": "_copy",
+ "name": "simpleembedding0_identity0",
+ "inputs": [[0, 0, 0]]
+ }
+ ],
+ "arg_nodes": [0],
+ "node_row_ptr": [0, 1, 2],
+ "heads": [[1, 0, 0]],
+ "attrs": {"mxnet_version": ["int", 10501]}
+}
\ No newline at end of file
diff --git a/src/test/resources/target_code/CNNBufferFile.h b/src/test/resources/target_code/CNNBufferFile.h
deleted file mode 100644
index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000
--- a/src/test/resources/target_code/CNNBufferFile.h
+++ /dev/null
@@ -1,51 +0,0 @@
-#ifndef CNNBUFFERFILE_H
-#define CNNBUFFERFILE_H
-
-#include
-#include
-#include
-
-// Read file to buffer
-class BufferFile {
- public :
- std::string file_path_;
- int length_;
- char* buffer_;
-
- explicit BufferFile(std::string file_path)
- :file_path_(file_path) {
-
- std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
- if (!ifs) {
- std::cerr << "Can't open the file. Please check " << file_path << ". \n";
- length_ = 0;
- buffer_ = NULL;
- return;
- }
-
- ifs.seekg(0, std::ios::end);
- length_ = ifs.tellg();
- ifs.seekg(0, std::ios::beg);
- std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
-
- buffer_ = new char[sizeof(char) * length_];
- ifs.read(buffer_, length_);
- ifs.close();
- }
-
- int GetLength() {
- return length_;
- }
- char* GetBuffer() {
- return buffer_;
- }
-
- ~BufferFile() {
- if (buffer_) {
- delete[] buffer_;
- buffer_ = NULL;
- }
- }
-};
-
-#endif // CNNBUFFERFILE_H
diff --git a/src/test/resources/target_code/CNNModelLoader.h b/src/test/resources/target_code/CNNModelLoader.h
new file mode 100644
index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4
--- /dev/null
+++ b/src/test/resources/target_code/CNNModelLoader.h
@@ -0,0 +1,141 @@
+#ifndef CNNMODELLOADER
+#define CNNMODELLOADER
+
+#include
+
+#include
+#include
+#include
+
+using namespace mxnet::cpp;
+
+// Read files to load model symbol and parameters
+class ModelLoader {
+private:
+ Context ctx = Context::cpu();
+ std::vector<Symbol> network_symbol_list;
+ std::vector<std::map<std::string, NDArray>> network_param_map_list;
+
+ std::vector<Symbol> query_symbol_list;
+ std::vector<std::map<std::string, NDArray>> query_param_map_list;
+
+ std::vector<std::map<std::string, NDArray>> replay_memory;
+
+ std::vector<Symbol> loss_symbol;
+ std::vector<std::map<std::string, NDArray>> loss_param_map;
+
+
+ void checkFile(std::string file_path){
+ std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
+ if (!ifs) {
+ std::cerr << "Can't open the file. Please check " << file_path << ". \n";
+ return;
+ }
+
+ int length_;
+ ifs.seekg(0, std::ios::end);
+ length_ = ifs.tellg();
+ ifs.seekg(0, std::ios::beg);
+ std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
+ ifs.close();
+ }
+
+ void loadComponent(std::string json_path,
+ std::string param_path,
+ std::vector<Symbol> &symbols_list,
+ std::vector<std::map<std::string, NDArray>> &param_map_list){
+ checkFile(json_path);
+ symbols_list.push_back(Symbol::Load(json_path));
+ checkFile(param_path);
+ std::map<std::string, NDArray> params;
+ NDArray::Load(param_path, 0, &params);
+ param_map_list.push_back(processParamMap(params));
+ }
+
+ std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){
+ std::map<std::string, NDArray> processed_param_map;
+ if(!param_map.empty()){
+ for (const auto &pair : param_map) {
+ std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? <- need to make sure)
+ processed_param_map[name] = pair.second.Copy(ctx);
+ }
+ }
+ return processed_param_map;
+ }
+
+public:
+ explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){
+
+ ctx = ctx_param;
+ std::string network_json_path;
+ std::string network_param_path;
+ std::string query_json_path;
+ std::string query_param_path;
+ std::string memory_path;
+ std::string loss_json_path;
+ std::string loss_param_path;
+
+ //Load network
+ if(!num_subnets){
+ network_json_path = file_prefix + "-symbol.json";
+ network_param_path = file_prefix + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ }else{
+ for(int i=0; i < num_subnets; i++){
+ network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json";
+ network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ if(i >= 1){
+ query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json";
+ query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list);
+
+ memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000";
+ checkFile(memory_path);
+
+ std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path);
+ for(auto &mem : mem_map){
+ mem.second = mem.second.Copy(ctx);
+ }
+ replay_memory.push_back(mem_map);
+ }
+ }
+ }
+
+ //Load Loss
+ loss_json_path = file_prefix + "_loss-symbol.json";
+ loss_param_path = file_prefix + "_loss-0000.params";
+ loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
+
+ NDArray::WaitAll();
+ }
+
+ std::vector<Symbol> GetNetworkSymbols() {
+ return network_symbol_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() {
+ return network_param_map_list;
+ }
+
+ Symbol GetLoss() {
+ return loss_symbol[0];
+ }
+
+ std::map<std::string, NDArray> GetLossParamMap() {
+ return loss_param_map[0];
+ }
+
+ std::vector<Symbol> GetQuerySymbols() {
+ return query_symbol_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() {
+ return query_param_map_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetReplayMemory(){
+ return replay_memory;
+ }
+};
+#endif // CNNMODELLOADER
diff --git a/src/test/resources/target_code/gluon/CNNBufferFile.h b/src/test/resources/target_code/gluon/CNNBufferFile.h
deleted file mode 100644
index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000
--- a/src/test/resources/target_code/gluon/CNNBufferFile.h
+++ /dev/null
@@ -1,51 +0,0 @@
-#ifndef CNNBUFFERFILE_H
-#define CNNBUFFERFILE_H
-
-#include
-#include
-#include
-
-// Read file to buffer
-class BufferFile {
- public :
- std::string file_path_;
- int length_;
- char* buffer_;
-
- explicit BufferFile(std::string file_path)
- :file_path_(file_path) {
-
- std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
- if (!ifs) {
- std::cerr << "Can't open the file. Please check " << file_path << ". \n";
- length_ = 0;
- buffer_ = NULL;
- return;
- }
-
- ifs.seekg(0, std::ios::end);
- length_ = ifs.tellg();
- ifs.seekg(0, std::ios::beg);
- std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
-
- buffer_ = new char[sizeof(char) * length_];
- ifs.read(buffer_, length_);
- ifs.close();
- }
-
- int GetLength() {
- return length_;
- }
- char* GetBuffer() {
- return buffer_;
- }
-
- ~BufferFile() {
- if (buffer_) {
- delete[] buffer_;
- buffer_ = NULL;
- }
- }
-};
-
-#endif // CNNBUFFERFILE_H
diff --git a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py
index 21ce4267c47a7ef296f301abf76932775e07bbfa..814f8ebaa5b866cf328ff459946f0af3a4571be2 100644
--- a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py
+++ b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_mnist_mnistClassifier_net import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
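+ # episodic-memory networks additionally checkpoint a replay memory per sub-net; track their latest epochs and files separately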
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_mnist_mnistClassifier_net:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/CNNModelLoader.h b/src/test/resources/target_code/gluon/CNNModelLoader.h
new file mode 100644
index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4
--- /dev/null
+++ b/src/test/resources/target_code/gluon/CNNModelLoader.h
@@ -0,0 +1,141 @@
+#ifndef CNNMODELLOADER
+#define CNNMODELLOADER
+
+#include
+
+#include
+#include
+#include
+
+using namespace mxnet::cpp;
+
+// Read files to load model symbol and parameters
+class ModelLoader {
+private:
+ Context ctx = Context::cpu();
+ std::vector<Symbol> network_symbol_list;
+ std::vector<std::map<std::string, NDArray>> network_param_map_list;
+
+ std::vector<Symbol> query_symbol_list;
+ std::vector<std::map<std::string, NDArray>> query_param_map_list;
+
+ std::vector<std::map<std::string, NDArray>> replay_memory;
+
+ std::vector<Symbol> loss_symbol;
+ std::vector<std::map<std::string, NDArray>> loss_param_map;
+
+
+ void checkFile(std::string file_path){
+ std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
+ if (!ifs) {
+ std::cerr << "Can't open the file. Please check " << file_path << ". \n";
+ return;
+ }
+
+ int length_;
+ ifs.seekg(0, std::ios::end);
+ length_ = ifs.tellg();
+ ifs.seekg(0, std::ios::beg);
+ std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
+ ifs.close();
+ }
+
+ void loadComponent(std::string json_path,
+ std::string param_path,
+ std::vector<Symbol> &symbols_list,
+ std::vector<std::map<std::string, NDArray>> &param_map_list){
+ checkFile(json_path);
+ symbols_list.push_back(Symbol::Load(json_path));
+ checkFile(param_path);
+ std::map<std::string, NDArray> params;
+ NDArray::Load(param_path, 0, &params);
+ param_map_list.push_back(processParamMap(params));
+ }
+
+ std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){
+ std::map<std::string, NDArray> processed_param_map;
+ if(!param_map.empty()){
+ for (const auto &pair : param_map) {
+ std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? <- need to make sure)
+ processed_param_map[name] = pair.second.Copy(ctx);
+ }
+ }
+ return processed_param_map;
+ }
+
+public:
+ explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){
+
+ ctx = ctx_param;
+ std::string network_json_path;
+ std::string network_param_path;
+ std::string query_json_path;
+ std::string query_param_path;
+ std::string memory_path;
+ std::string loss_json_path;
+ std::string loss_param_path;
+
+ //Load network
+ if(!num_subnets){
+ network_json_path = file_prefix + "-symbol.json";
+ network_param_path = file_prefix + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ }else{
+ for(int i=0; i < num_subnets; i++){
+ network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json";
+ network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ if(i >= 1){
+ query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json";
+ query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list);
+
+ memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000";
+ checkFile(memory_path);
+
+ std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path);
+ for(auto &mem : mem_map){
+ mem.second = mem.second.Copy(ctx);
+ }
+ replay_memory.push_back(mem_map);
+ }
+ }
+ }
+
+ //Load Loss
+ loss_json_path = file_prefix + "_loss-symbol.json";
+ loss_param_path = file_prefix + "_loss-0000.params";
+ loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
+
+ NDArray::WaitAll();
+ }
+
+ std::vector<Symbol> GetNetworkSymbols() {
+ return network_symbol_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() {
+ return network_param_map_list;
+ }
+
+ Symbol GetLoss() {
+ return loss_symbol[0];
+ }
+
+ std::map<std::string, NDArray> GetLossParamMap() {
+ return loss_param_map[0];
+ }
+
+ std::vector<Symbol> GetQuerySymbols() {
+ return query_symbol_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() {
+ return query_param_map_list;
+ }
+
+ std::vector<std::map<std::string, NDArray>> GetReplayMemory(){
+ return replay_memory;
+ }
+};
+#endif // CNNMODELLOADER
diff --git a/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py
index e376a095669594ec94a489b736580cf0aebb6c26..415132dc1f96139e94a977f7ab3ac87790f66e78 100644
--- a/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py
+++ b/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.Reshape(queries, shape=(0, 0,-1))
+ keys = F.Reshape(queries, shape=(0, 0, -1))
+ values = F.Reshape(queries, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
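+ # store a random subset of the current batch (selected per sample with probability store_prob) in the replay memory, keyed by the query network's embedding of the inputs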
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
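+ #Draw replay_steps batches of indices uniformly from the stored keys and return the
+ #corresponding (values, labels) mini-batches for replay training.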
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
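+ #Load the newest exported query network (symbol and params with the configured prefix) and hybridize it.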
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -146,5 +562,5 @@ class Net_0(gluon.HybridBlock):
softmax3_ = F.softmax(fc3_, axis=-1)
predictions_ = F.identity(softmax3_)
- return predictions_
+ return [[predictions_]]
diff --git a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h
index 75d7c61ecdf6ad99e934c09dbcee574020c2d693..55a9d6320cde26d83702a786447c4bca4e84234c 100644
--- a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h
+++ b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h
@@ -1,107 +1,149 @@
#ifndef CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET
#define CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET
-#include
+#include
#include
#include
#include
+
+#include
+#include
-#include
-
+using namespace mxnet::cpp;
+
class CNNPredictor_mnist_mnistClassifier_net_0{
public:
- const std::string json_file = "model/mnist.LeNetNetwork/model_0_newest-symbol.json";
- const std::string param_file = "model/mnist.LeNetNetwork/model_0_newest-0000.params";
- const std::vector<std::string> input_keys = {
+ const std::string file_prefix = "model/mnist.LeNetNetwork/model_0_newest";
+
+ //network
+ const std::vector<std::string> network_input_keys = {
"data"
};
- const std::vector<std::vector<mx_uint>> input_shapes = {{1, 1, 28, 28}};
- const bool use_gpu = false;
-
- PredictorHandle handle;
-
+ const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 1, 28, 28}};
+ std::vector<mx_uint> network_input_sizes;
+ std::vector<std::vector<std::string>> network_arg_names;
+ std::vector<Executor *> network_handles;
+
+
+ //misc
+ Context ctx = Context::cpu(); //Updated later in init() according to the optimizer's context (cpu/gpu)
+ int dtype = 0; //data type used (float32=0, float64=1, ...)
+
+
explicit CNNPredictor_mnist_mnistClassifier_net_0(){
- init(json_file, param_file, input_keys, input_shapes, use_gpu);
+ init(file_prefix, network_input_keys, network_input_shapes);
}
~CNNPredictor_mnist_mnistClassifier_net_0(){
- if(handle) MXPredFree(handle);
+ for(Executor * handle : network_handles){
+ delete handle;
+ }
+ MXNotifyShutdown();
}
void predict(const std::vector<float> &in_image_,
std::vector<float> &out_predictions_){
- MXPredSetInput(handle, input_keys[0].c_str(), in_image_.data(), static_cast<mx_uint>(in_image_.size()));
-
- MXPredForward(handle);
- mx_uint output_index;
- mx_uint *shape = 0;
- mx_uint shape_len;
- size_t size;
-
- output_index = 0;
- MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
- size = 1;
- for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
- assert(size == out_predictions_.size());
- MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size());
+ NDArray input_temp;
+ input_temp = NDArray(network_input_shapes[0], ctx, false, dtype);
+ input_temp.SyncCopyFromCPU(in_image_.data(), network_input_sizes[0]);
+ input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]]));
+ NDArray::WaitAll();
+
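+ //run a forward pass on the single executor; gradients are not required for prediction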
+ network_handles[0]->Forward(false);
+ CheckMXNetError("Forward, predict, handle ind. 0");
+
+
+ std::vector<NDArray> output = network_handles.back()->outputs;
+ std::vector<mx_uint> curr_output_shape;
+ size_t curr_output_size;
+ curr_output_shape = output[0].GetShape();
+ curr_output_size = 1;
+ for (mx_uint i : curr_output_shape) curr_output_size *= i;
+ //Fix for a bug in how the output arrays are initialized when there are multiple outputs
+ assert((curr_output_size == out_predictions_.size()) || (curr_output_size == out_predictions_[0]));
+ output[0].SyncCopyToCPU(&out_predictions_);
+
}
+
+
+
+ Executor* initExecutor(Symbol &sym,
+ std::map<std::string, NDArray> &param_map,
+ const std::vector<std::string> &exec_input_keys,
+ const std::vector<std::vector<mx_uint>> &exec_input_shapes){
+
+ const mx_uint num_exec_input_nodes = exec_input_keys.size();
+ for(mx_uint i = 0; i < num_exec_input_nodes; i++){
+ param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype);
+ }
- void init(const std::string &json_file,
- const std::string &param_file,
- const std::vector<std::string> &input_keys,
- const std::vector<std::vector<mx_uint>> &input_shapes,
- const bool &use_gpu){
+ std::vector<NDArray> param_arrays;
+ std::vector<NDArray> grad_array;
+ std::vector<OpReqType> grad_reqs;
+ std::vector<NDArray> aux_arrays;
+ std::map< std::string, NDArray> aux_map;
- BufferFile json_data(json_file);
- BufferFile param_data(param_file);
+ sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs,
+ &aux_arrays, param_map, std::map<std::string, NDArray>(),
+ std::map<std::string, OpReqType>(), aux_map);
- int dev_type = use_gpu ? 2 : 1;
- int dev_id = 0;
+ Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays);
+ assert(handle);
+ return handle;
+ }
- if (json_data.GetLength() == 0 ||
- param_data.GetLength() == 0) {
- std::exit(-1);
+ std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){
+ std::vector<mx_uint> sizes;
+ for(std::vector<mx_uint> shape : shapes){
+ mx_uint val = 1;
+ for(mx_uint i: shape){
+ val *= i;
+ }
+ sizes.push_back(val);
}
+ return sizes;
+ }
- const mx_uint num_input_nodes = input_keys.size();
-
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
+ void CheckMXNetError(std::string loc){
+ const char* err = MXGetLastError();
+ if (err && err[0] != 0) {
+ std::cout << "MXNet error at " << loc << err << std::endl;
+ exit(-1);
}
-
- mx_uint shape_data_size = 0;
- mx_uint input_shape_indptr[input_shapes.size() + 1];
- input_shape_indptr[0] = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- shape_data_size += input_shapes[i].size();
- input_shape_indptr[i+1] = shape_data_size;
+ }
+
+ void init(const std::string &file_prefix,
+ const std::vector<std::string> &network_input_keys,
+ const std::vector<std::vector<mx_uint>> &network_input_shapes){
+
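+ //the generated optimizer configuration decides whether prediction runs on the CPU or the GPU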
+ CNNLAOptimizer_mnist_mnistClassifier_net optimizer_creator = CNNLAOptimizer_mnist_mnistClassifier_net();
+
+ if(optimizer_creator.getContextName() == "gpu"){
+ ctx = Context::gpu();
}
-
- mx_uint input_shape_data[shape_data_size];
- mx_uint index = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- for(mx_uint j = 0; j < input_shapes[i].size(); j++){
- input_shape_data[index] = input_shapes[i][j];
- index++;
- }
+
+ network_input_sizes = getSizesOfShapes(network_input_shapes);
+
+ ModelLoader model_loader(file_prefix, 0, ctx);
+
+ std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols();
+ std::vector<std::map<std::string, NDArray>> network_param_maps;
+ network_param_maps = model_loader.GetNetworkParamMaps();
+
+ //Init handles
+ std::map<std::string, std::vector<mx_uint>> in_shape_map;
+ for(mx_uint i=0; i < network_input_keys.size(); i++){
+ in_shape_map[network_input_keys[i]] = network_input_shapes[i];
}
-
- MXPredCreate(static_cast(json_data.GetBuffer()),
- static_cast(param_data.GetBuffer()),
- static_cast(param_data.GetLength()),
- dev_type,
- dev_id,
- num_input_nodes,
- input_keys_ptr,
- input_shape_indptr,
- input_shape_data,
- &handle);
- assert(handle);
+ std::vector<std::vector<mx_uint>> in_shapes;
+ std::vector<std::vector<mx_uint>> aux_shapes;
+ std::vector<std::vector<mx_uint>> out_shapes;
+ network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes);
+ network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes));
+
}
};
-
#endif // CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET
diff --git a/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py
index b6f1a59372febc71f4311d93c99c8165570c861c..7403768f7343f1cfe11767ed31de2966d3e85924 100644
--- a/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py
+++ b/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py
@@ -7,7 +7,13 @@ import shutil
import pickle
import math
import sys
+import inspect
from mxnet import gluon, autograd, nd
+try:
+ import AdamW
+except ImportError:
+ pass
+
class CrossEntropyLoss(gluon.loss.Loss):
def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
@@ -54,7 +60,7 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. tokens in NLP applications
for i in self._ignore_indices:
- loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
+ loss = F.broadcast_mul(loss, F.logical_not(F.broadcast_equal(F.argmax(pred, axis=1), F.ones_like(F.argmax(pred, axis=1))*i) * F.broadcast_equal(F.argmax(pred, axis=1), label)))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
@@ -277,12 +283,21 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
shuffle_data=False,
clip_global_grad_norm=None,
preprocessing = False):
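+ #determine the processing units to train on: all available GPUs if requested, otherwise the CPU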
+ num_pus = 1
if context == 'gpu':
- mx_context = mx.gpu()
+ num_pus = mx.context.num_gpus()
+ if num_pus >= 1:
+ if num_pus == 1:
+ mx_context = [mx.gpu(0)]
+ else:
+ mx_context = [mx.gpu(i) for i in range(num_pus)]
+ else:
+ logging.error("Context argument is '" + context + "'. But no gpu is present in the system.")
elif context == 'cpu':
- mx_context = mx.cpu()
+ mx_context = [mx.cpu()]
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
+ single_pu_batch_size = int(batch_size/num_pus)
if preprocessing:
preproc_lib = "CNNPreprocessor_mnist_mnistClassifier_net_executor"
@@ -327,7 +342,10 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
if not os.path.isdir(self._net_creator._model_dir_):
raise
- trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
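+ #AdamW is not a built-in mxnet optimizer; it is provided by the separate AdamW module imported above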
+ if optimizer == "adamw":
+ trainers = [mx.gluon.Trainer(network.collect_params(), AdamW.AdamW(**optimizer_params)) for network in self._networks.values() if len(network.collect_params().values()) != 0]
+ else:
+ trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
@@ -372,9 +390,16 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss_function = LogCoshLoss()
else:
logging.error("Invalid loss parameter.")
+
+ loss_function.hybridize()
+
+
tic = None
+ avg_speed = 0
+ n = 0
+
for epoch in range(begin_epoch, begin_epoch + num_epoch):
if shuffle_data:
if preprocessing:
@@ -389,31 +414,36 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss_total = 0
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
+
+
with autograd.record():
- labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
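+ #split each batch along axis 0 and load one slice onto every available device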
+ labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False) for i in range(1)]
- image_ = batch.data[0].as_in_context(mx_context)
+ image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)
- predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)
+ predictions_ = [mx.nd.zeros((single_pu_batch_size, 10,), ctx=context) for context in mx_context]
nd.waitall()
-
lossList = []
+ for i in range(num_pus):
+ lossList.append([])
- predictions_ = self._networks[0](image_)
+ net_ret = [self._networks[0](image_[i]) for i in range(num_pus)]
+ predictions_ = [net_ret[i][0][0] for i in range(num_pus)]
+ [lossList[i].append(loss_function(predictions_[i], labels[0][i])) for i in range(num_pus)]
- lossList.append(loss_function(predictions_, labels[0]))
-
- loss = 0
- for element in lossList:
- loss = loss + element
- loss.backward()
+ losses = [0]*num_pus
+ for i in range(num_pus):
+ for element in lossList[i]:
+ losses[i] = losses[i] + element
- loss_total += loss.sum().asscalar()
+ for loss in losses:
+ loss.backward()
+ loss_total += loss.sum().asscalar()
+ global_loss_train += loss.sum().asscalar()
- global_loss_train += loss.sum().asscalar()
train_batches += 1
if clip_global_grad_norm:
@@ -426,7 +456,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
for trainer in trainers:
trainer.step(batch_size)
-
+
if tic is None:
tic = time.time()
else:
@@ -440,36 +470,39 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss_total = 0
logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))
-
+
+ avg_speed += speed
+ n += 1
+
tic = time.time()
global_loss_train /= (train_batches * batch_size)
tic = None
-
if eval_train:
train_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_iter):
- labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
-
- image_ = batch.data[0].as_in_context(mx_context)
+ labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)]
+ image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0]
- predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)
+ predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0])
nd.waitall()
- outputs = []
lossList = []
+ outputs = []
attentionList = []
- predictions_ = self._networks[0](image_)
+ net_ret = self._networks[0](image_)
+ predictions_ = net_ret[0][0]
outputs.append(predictions_)
lossList.append(loss_function(predictions_, labels[0]))
+
if save_attention_image == "True":
import matplotlib
matplotlib.use('Agg')
@@ -510,7 +543,6 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
os.makedirs(target_dir)
plt.savefig(target_dir + '/attention_train.png')
plt.close()
-
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
@@ -518,7 +550,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
else:
predictions.append(output_name)
- metric.update(preds=predictions, labels=labels)
+ metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))])
+
train_metric_score = metric.get()[1]
else:
train_metric_score = 0
@@ -529,25 +562,26 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
- if True:
- labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
-
- image_ = batch.data[0].as_in_context(mx_context)
+ if True:
+ labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)]
+ image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0]
- predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)
+ predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0])
nd.waitall()
- outputs = []
lossList = []
+ outputs = []
attentionList = []
- predictions_ = self._networks[0](image_)
+ net_ret = self._networks[0](image_)
+ predictions_ = net_ret[0][0]
outputs.append(predictions_)
lossList.append(loss_function(predictions_, labels[0]))
+
if save_attention_image == "True":
if not eval_train:
import matplotlib
@@ -594,26 +628,40 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss = loss + element
global_loss_test += loss.sum().asscalar()
+
test_batches += 1
predictions = []
for output_name in outputs:
predictions.append(output_name)
- metric.update(preds=predictions, labels=labels)
+ metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))])
+
test_metric_score = metric.get()[1]
- global_loss_test /= (test_batches * batch_size)
+ global_loss_test /= (test_batches * single_pu_batch_size)
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
- if (epoch - begin_epoch) % checkpoint_period == 0:
+ if (epoch+1) % checkpoint_period == 0:
for i, network in self._networks.items():
network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, net in enumerate(network.episodic_sub_nets):
+ episodic_layers[i][j].save_memory(self.parameter_path(i) + "_episodic_memory_sub_net_" + str(j + 1) + "-" + str(epoch).zfill(4))
for i, network in self._networks.items():
- network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params')
+ network.save_parameters(self.parameter_path(i) + '-' + str((num_epoch-1) + begin_epoch).zfill(4) + '.params')
network.export(self.parameter_path(i) + '_newest', epoch=0)
+
+ if hasattr(network, 'episodic_sub_nets'):
+ network.episodicsubnet0_.export(self.parameter_path(i) + '_newest_episodic_sub_net_' + str(0), epoch=0)
+ for j, net in enumerate(network.episodic_sub_nets):
+ net.export(self.parameter_path(i) + '_newest_episodic_sub_net_' + str(j+1), epoch=0)
+ episodic_query_networks[i][j].export(self.parameter_path(i) + '_newest_episodic_query_net_' + str(j+1), epoch=0)
+ episodic_layers[i][j].save_memory(self.parameter_path(i) + "_episodic_memory_sub_net_" + str(j + 1) + "-" + str((num_epoch - 1) + begin_epoch).zfill(4))
+ episodic_layers[i][j].save_memory(self.parameter_path(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ loss_function.export(self.parameter_path(i) + '_newest_loss', epoch=0)
def parameter_path(self, index):
return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py
index 9e5e33e65f52240401b74b96bd273e106bab9efb..fafc619d8d2a31e4f2145a10ff9b824cb35bcc20 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_defaultGAN_defaultGANConnector_predictor import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
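+ #memory checkpoint names end in _episodic_memory_sub_net_J-EPOCH; keep only the newest file per sub-net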
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch)
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 100,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 100,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py
index 42fa27f0786c314fa7ac0df5c545cc92027fc71a..9cfe8224c5bbd0a2f9db82adbd4fbd8bfdd4a921 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py
@@ -184,16 +184,16 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
del discriminator_optimizer_params['learning_rate_decay']
if normalize:
- self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
+ self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std)
else:
- self._net_creator_dis.construct(mx_context)
+ self._net_creator_dis.construct([mx_context])
- self._net_creator_gen.construct(mx_context)
+ self._net_creator_gen.construct([mx_context])
if self.use_qnet:
- self._net_creator_qnet.construct(mx_context)
+ self._net_creator_qnet.construct([mx_context])
if load_checkpoint:
- self._net_creator_qnet.load(mx_context)
+ self._net_creator_qnet.load([mx_context])
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
@@ -206,8 +206,8 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
begin_epoch = 0
if load_checkpoint:
- begin_epoch = self._net_creator_dis.load(mx_context)
- self._net_creator_gen.load(mx_context)
+ begin_epoch = self._net_creator_dis.load([mx_context])
+ self._net_creator_gen.load([mx_context])
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
@@ -351,9 +351,9 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
gen_input, exp_qnet_output = create_generator_input(batch)
with autograd.record():
- fake_data = gen_net(*gen_input)
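+ #generator and discriminator now return nested output lists, so the actual output is taken via [0][0]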
+ fake_data = gen_net(*gen_input)[0][0]
fake_data.detach()
- discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
+ discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
@@ -361,7 +361,7 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context)
loss_resultF = dis_loss(discriminated_fake_dis, fake_labels)
- discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
+ discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = dis_loss(discriminated_real_dis, real_labels)
@@ -372,8 +372,8 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
if batch_i % k_value == 0:
with autograd.record():
- fake_data = gen_net(*gen_input)
- discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
+ fake_data = gen_net(*gen_input)[0][0]
+ discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = dis_loss(discriminated_fake_gen, real_labels)
@@ -381,7 +381,7 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor:
condition = batch.data[traindata_to_index[generator_target_name + "_"]]
loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition)
if self.use_qnet:
- qnet_discriminated = [q_net(features)]
+ qnet_discriminated = [q_net(features)[0][0]]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py
index 32d8a55ff8131a54a11955dfd652338f43d97037..4333be26146ebdadfa61306efc8bf610778201cb 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
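+ #Multi-head scaled dot-product self-attention: inputs are projected to per-head queries, keys and
+ #values, attention weights are softmax(Q*K^T * scale_factor), and the heads are merged by proj_o.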
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.reshape(queries, shape=(0, 0, -1))
+ keys = F.reshape(keys, shape=(0, 0, -1))
+ values = F.reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = args[0] #assumes the optional mask/length vector is passed as the first extra argument
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+ else:
+ weights = F.softmax(score)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
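+ #Abstract interface for layers that maintain an episodic replay memory (store, sample, save and load samples).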
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
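+ #Product-key memory: the query is split into two halves, each half is matched against a small
+ #sub-key table, and the top-k cross product of sub-keys addresses rows of the large value table.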
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -177,5 +593,5 @@ class Net_0(gluon.HybridBlock):
tanh5_ = self.tanh5_(upconvolution5_)
data_ = F.identity(tanh5_)
- return data_
+ return [[data_]]
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h
index a99f9c1b8c799648903e501c686773b03aea1dbf..d7ad1aab8082d22192f169560e8ff50016ab9704 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h
@@ -1,108 +1,149 @@
#ifndef CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR
#define CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR
-#include
+#include
#include
#include
#include
+
+#include
+#include
-#include
-
+using namespace mxnet::cpp;
+
class CNNPredictor_defaultGAN_defaultGANConnector_predictor_0{
public:
- const std::string json_file = "model/defaultGAN.DefaultGANGenerator/model_0_newest-symbol.json";
- const std::string param_file = "model/defaultGAN.DefaultGANGenerator/model_0_newest-0000.params";
- const std::vector<std::string> input_keys = {
+ const std::string file_prefix = "model/defaultGAN.DefaultGANGenerator/model_0_newest";
+
+ //network
+ const std::vector<std::string> network_input_keys = {
"data"
};
- const std::vector<std::vector<mx_uint>> input_shapes = {{1, 100}};
- const bool use_gpu = false;
-
- PredictorHandle handle;
-
+ const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 100}};
+ std::vector<mx_uint> network_input_sizes;
+ std::vector<std::vector<std::string>> network_arg_names;
+ std::vector<Executor *> network_handles;
+
+
+ //misc
+ Context ctx = Context::cpu(); //Updated later in init() according to the optimizer's context (cpu/gpu)
+ int dtype = 0; //data type used (float32=0, float64=1, ...)
+
+
explicit CNNPredictor_defaultGAN_defaultGANConnector_predictor_0(){
- init(json_file, param_file, input_keys, input_shapes, use_gpu);
+ init(file_prefix, network_input_keys, network_input_shapes);
}
~CNNPredictor_defaultGAN_defaultGANConnector_predictor_0(){
- if(handle) MXPredFree(handle);
+ for(Executor * handle : network_handles){
+ delete handle;
+ }
+ MXNotifyShutdown();
}
void predict(const std::vector<float> &in_noise_,
std::vector<float> &out_data_){
- MXPredSetInput(handle, input_keys[0].c_str(), in_noise_.data(), static_cast<mx_uint>(in_noise_.size()));
-
- MXPredForward(handle);
- mx_uint output_index;
- mx_uint *shape = 0;
- mx_uint shape_len;
- size_t size;
-
- output_index = 0;
- MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
- size = 1;
- for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-
- assert(size == out_data_.size());
- MXPredGetOutput(handle, output_index, &(out_data_[0]), out_data_.size());
+ NDArray input_temp;
+ input_temp = NDArray(network_input_shapes[0], ctx, false, dtype);
+ input_temp.SyncCopyFromCPU(in_noise_.data(), network_input_sizes[0]);
+ input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]]));
+ NDArray::WaitAll();
+
+ network_handles[0]->Forward(false);
+ CheckMXNetError("Forward, predict, handle ind. 0");
+
+
+ std::vector<NDArray> output = network_handles.back()->outputs;
+ std::vector<mx_uint> curr_output_shape;
+ size_t curr_output_size;
+ curr_output_shape = output[0].GetShape();
+ curr_output_size = 1;
+ for (mx_uint i : curr_output_shape) curr_output_size *= i;
+ //Fix for a bug in how the output arrays are initialized when there are multiple outputs
+ assert((curr_output_size == out_data_.size()) || (curr_output_size == out_data_[0]));
+ output[0].SyncCopyToCPU(&out_data_);
+
}
+
+
+
+ Executor* initExecutor(Symbol &sym,
+ std::map<std::string, NDArray> &param_map,
+ const std::vector<std::string> &exec_input_keys,
+ const std::vector<std::vector<mx_uint>> &exec_input_shapes){
+
+ const mx_uint num_exec_input_nodes = exec_input_keys.size();
+ for(mx_uint i = 0; i < num_exec_input_nodes; i++){
+ param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype);
+ }
- void init(const std::string &json_file,
- const std::string &param_file,
- const std::vector<std::string> &input_keys,
- const std::vector<std::vector<mx_uint>> &input_shapes,
- const bool &use_gpu){
+ std::vector<NDArray> param_arrays;
+ std::vector<NDArray> grad_array;
+ std::vector<OpReqType> grad_reqs;
+ std::vector<NDArray> aux_arrays;
+ std::map< std::string, NDArray> aux_map;
- BufferFile json_data(json_file);
- BufferFile param_data(param_file);
+ sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs,
+ &aux_arrays, param_map, std::map<std::string, NDArray>(),
+ std::map<std::string, OpReqType>(), aux_map);
- int dev_type = use_gpu ? 2 : 1;
- int dev_id = 0;
+ Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays);
+ assert(handle);
+ return handle;
+ }
- if (json_data.GetLength() == 0 ||
- param_data.GetLength() == 0) {
- std::exit(-1);
+ std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){
+ std::vector<mx_uint> sizes;
+ for(std::vector<mx_uint> shape : shapes){
+ mx_uint val = 1;
+ for(mx_uint i: shape){
+ val *= i;
+ }
+ sizes.push_back(val);
}
+ return sizes;
+ }
- const mx_uint num_input_nodes = input_keys.size();
-
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
+ void CheckMXNetError(std::string loc){
+ const char* err = MXGetLastError();
+ if (err && err[0] != 0) {
+ std::cout << "MXNet error at " << loc << err << std::endl;
+ exit(-1);
}
-
- mx_uint shape_data_size = 0;
- mx_uint input_shape_indptr[input_shapes.size() + 1];
- input_shape_indptr[0] = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- shape_data_size += input_shapes[i].size();
- input_shape_indptr[i+1] = shape_data_size;
+ }
+
+ void init(const std::string &file_prefix,
+ const std::vector<std::string> &network_input_keys,
+ const std::vector<std::vector<mx_uint>> &network_input_shapes){
+
+ CNNLAOptimizer_defaultGAN_defaultGANConnector_predictor optimizer_creator = CNNLAOptimizer_defaultGAN_defaultGANConnector_predictor();
+
+ if(optimizer_creator.getContextName() == "gpu"){
+ ctx = Context::gpu();
}
-
- mx_uint input_shape_data[shape_data_size];
- mx_uint index = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- for(mx_uint j = 0; j < input_shapes[i].size(); j++){
- input_shape_data[index] = input_shapes[i][j];
- index++;
- }
+
+ network_input_sizes = getSizesOfShapes(network_input_shapes);
+
+ ModelLoader model_loader(file_prefix, 0, ctx);
+
+ std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols();
+ std::vector<std::map<std::string, NDArray>> network_param_maps;
+ network_param_maps = model_loader.GetNetworkParamMaps();
+
+ //Init handles
+ std::map<std::string, std::vector<mx_uint>> in_shape_map;
+ for(mx_uint i=0; i < network_input_keys.size(); i++){
+ in_shape_map[network_input_keys[i]] = network_input_shapes[i];
}
-
- MXPredCreate(static_cast(json_data.GetBuffer()),
- static_cast(param_data.GetBuffer()),
- static_cast(param_data.GetLength()),
- dev_type,
- dev_id,
- num_input_nodes,
- input_keys_ptr,
- input_shape_indptr,
- input_shape_data,
- &handle);
- assert(handle);
+ std::vector<std::vector<mx_uint>> in_shapes;
+ std::vector<std::vector<mx_uint>> aux_shapes;
+ std::vector<std::vector<mx_uint>> out_shapes;
+ network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes);
+ network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes));
+
}
};
-
#endif // CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py
index 82954430835cdbe977415d4c694240a29e7e92c0..7119c688509d054a73c38bfad34796854df9bd61 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py
@@ -55,5 +55,3 @@ if __name__ == "__main__":
log_period=10,
print_images=True,
)
-
-
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py
index 0e7501600ef4a5ae20cb0af65aa78f5090a93745..1920ecee6a5e558930fb6e2e3259ceed2ed98a2b 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_defaultGAN_defaultGANDiscriminator import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+                        relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 1,64,64,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 1,64,64,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py
index 027ed194713072192cd56ece3f55216888ea8c94..6c220b51e34c969cf20defd9fee40498055e854a 100644
--- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py
+++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.Reshape(queries, shape=(0, 0,-1))
+        keys = F.Reshape(keys, shape=(0, 0, -1))
+        values = F.Reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -172,5 +588,5 @@ class Net_0(gluon.HybridBlock):
sigmoid5_ = self.sigmoid5_(conv5_)
dis_ = F.identity(sigmoid5_)
- return dis_
+ return [[dis_]]
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py
index ba115b47978ef21143be826f7ab0b6b639075682..c84c36bf315d49197e604ef09981a6d78073a932 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_infoGAN_infoGANConnector_predictor import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANConnector_predictor:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+                        relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 62,), ctx=context), mx.nd.zeros((1, 10,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 62,), ctx=context[0]), mx.nd.zeros((1, 10,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py
index b49f11dc6de390d416e705a14d481baf90e49a51..426cc112a4ed84ab3f4158d47bbcbfdfc276e82e 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py
@@ -184,16 +184,16 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
del discriminator_optimizer_params['learning_rate_decay']
if normalize:
- self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
+ self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std)
else:
- self._net_creator_dis.construct(mx_context)
+ self._net_creator_dis.construct([mx_context])
- self._net_creator_gen.construct(mx_context)
+ self._net_creator_gen.construct([mx_context])
if self.use_qnet:
- self._net_creator_qnet.construct(mx_context)
+ self._net_creator_qnet.construct([mx_context])
if load_checkpoint:
- self._net_creator_qnet.load(mx_context)
+ self._net_creator_qnet.load([mx_context])
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
@@ -206,8 +206,8 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
begin_epoch = 0
if load_checkpoint:
- begin_epoch = self._net_creator_dis.load(mx_context)
- self._net_creator_gen.load(mx_context)
+ begin_epoch = self._net_creator_dis.load([mx_context])
+ self._net_creator_gen.load([mx_context])
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
@@ -351,9 +351,9 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
gen_input, exp_qnet_output = create_generator_input(batch)
with autograd.record():
- fake_data = gen_net(*gen_input)
+ fake_data = gen_net(*gen_input)[0][0]
fake_data.detach()
- discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
+ discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
@@ -361,7 +361,7 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context)
loss_resultF = dis_loss(discriminated_fake_dis, fake_labels)
- discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
+ discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = dis_loss(discriminated_real_dis, real_labels)
@@ -372,8 +372,8 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
if batch_i % k_value == 0:
with autograd.record():
- fake_data = gen_net(*gen_input)
- discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
+ fake_data = gen_net(*gen_input)[0][0]
+ discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = dis_loss(discriminated_fake_gen, real_labels)
@@ -381,7 +381,7 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor:
condition = batch.data[traindata_to_index[generator_target_name + "_"]]
loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition)
if self.use_qnet:
- qnet_discriminated = [q_net(features)]
+ qnet_discriminated = [q_net(features)[0][0]]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py
index 4aeb33b350c6d3e898cabddec1499b3ad07a7f10..ae8d9f10d1a696e3f65391c6b0100078753ab836 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.Reshape(queries, shape=(0, 0,-1))
+        keys = F.Reshape(keys, shape=(0, 0, -1))
+        values = F.Reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -186,5 +602,5 @@ class Net_0(gluon.HybridBlock):
tanh7_ = self.tanh7_(upconvolution7_)
data_ = F.identity(tanh7_)
- return data_
+ return [[data_]]
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h
index 76240ebe330a66765750cec5cbfa1623022c779d..e0968814cfe44e76d9bbc2669f1ba1827a975443 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h
@@ -1,109 +1,152 @@
#ifndef CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR
#define CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR
-#include <mxnet/c_predict_api.h>
+#include <mxnet-cpp/MxNetCpp.h>
 #include <cassert>
 #include <string>
 #include <vector>
+
+#include <CNNModelLoader.h>
+#include <CNNLAOptimizer_infoGAN_infoGANConnector_predictor.h>
-#include <CNNBufferFile.h>
-
+using namespace mxnet::cpp;
+
class CNNPredictor_infoGAN_infoGANConnector_predictor_0{
public:
- const std::string json_file = "model/infoGAN.InfoGANGenerator/model_0_newest-symbol.json";
- const std::string param_file = "model/infoGAN.InfoGANGenerator/model_0_newest-0000.params";
-    const std::vector<std::string> input_keys = {
+ const std::string file_prefix = "model/infoGAN.InfoGANGenerator/model_0_newest";
+
+ //network
+    const std::vector<std::string> network_input_keys = {
"data0", "data1"
};
-    const std::vector<std::vector<mx_uint>> input_shapes = {{1, 62}, {1, 10}};
- const bool use_gpu = false;
-
- PredictorHandle handle;
-
+    const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 62}, {1, 10}};
+    std::vector<mx_uint> network_input_sizes;
+    std::vector<std::vector<std::string>> network_arg_names;
+    std::vector<Executor *> network_handles;
+
+
+ //misc
+ Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu
+ int dtype = 0; //use data type (float32=0 float64=1 ...)
+
+
explicit CNNPredictor_infoGAN_infoGANConnector_predictor_0(){
- init(json_file, param_file, input_keys, input_shapes, use_gpu);
+ init(file_prefix, network_input_keys, network_input_shapes);
}
~CNNPredictor_infoGAN_infoGANConnector_predictor_0(){
- if(handle) MXPredFree(handle);
+ for(Executor * handle : network_handles){
+ delete handle;
+ }
+ MXNotifyShutdown();
}
void predict(const std::vector &in_noise_, const std::vector &in_c1_,
std::vector &out_data_){
-        MXPredSetInput(handle, input_keys[0].c_str(), in_noise_.data(), static_cast<mx_uint>(in_noise_.size()));
-        MXPredSetInput(handle, input_keys[1].c_str(), in_c1_.data(), static_cast<mx_uint>(in_c1_.size()));
-
- MXPredForward(handle);
- mx_uint output_index;
- mx_uint *shape = 0;
- mx_uint shape_len;
- size_t size;
-
- output_index = 0;
- MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
- size = 1;
- for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
-
- assert(size == out_data_.size());
- MXPredGetOutput(handle, output_index, &(out_data_[0]), out_data_.size());
+ NDArray input_temp;
+ input_temp = NDArray(network_input_shapes[0], ctx, false, dtype);
+ input_temp.SyncCopyFromCPU(in_noise_.data(), network_input_sizes[0]);
+ input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]]));
+ input_temp = NDArray(network_input_shapes[1], ctx, false, dtype);
+ input_temp.SyncCopyFromCPU(in_c1_.data(), network_input_sizes[1]);
+ input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[1]]));
+ NDArray::WaitAll();
+
+ network_handles[0]->Forward(false);
+ CheckMXNetError("Forward, predict, handle ind. 0");
+
+
+        std::vector<NDArray> output = network_handles.back()->outputs;
+        std::vector<mx_uint> curr_output_shape;
+ size_t curr_output_size;
+ curr_output_shape = output[0].GetShape();
+ curr_output_size = 1;
+ for (mx_uint i : curr_output_shape) curr_output_size *= i;
+        //Fix due to a bug in how the output arrays are initialized when there are multiple outputs
+ assert((curr_output_size == out_data_.size()) || (curr_output_size == out_data_[0]));
+ output[0].SyncCopyToCPU(&out_data_);
+
}
+
+
+
+ Executor* initExecutor(Symbol &sym,
+                           std::map<std::string, NDArray> &param_map,
+                           const std::vector<std::string> &exec_input_keys,
+                           const std::vector<std::vector<mx_uint>> &exec_input_shapes){
+
+ const mx_uint num_exec_input_nodes = exec_input_keys.size();
+ for(mx_uint i = 0; i < num_exec_input_nodes; i++){
+ param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype);
+ }
- void init(const std::string &json_file,
-              const std::string &param_file,
-              const std::vector<std::string> &input_keys,
-              const std::vector<std::vector<mx_uint>> &input_shapes,
- const bool &use_gpu){
+        std::vector<NDArray> param_arrays;
+        std::vector<NDArray> grad_array;
+        std::vector<OpReqType> grad_reqs;
+        std::vector<NDArray> aux_arrays;
+ std::map< std::string, NDArray> aux_map;
- BufferFile json_data(json_file);
- BufferFile param_data(param_file);
+        sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs,
+                                &aux_arrays, param_map, std::map<std::string, NDArray>(),
+                                std::map<std::string, OpReqType>(), aux_map);
- int dev_type = use_gpu ? 2 : 1;
- int dev_id = 0;
+ Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays);
+ assert(handle);
+ return handle;
+ }
- if (json_data.GetLength() == 0 ||
- param_data.GetLength() == 0) {
- std::exit(-1);
+    std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){
+        std::vector<mx_uint> sizes;
+        for(std::vector<mx_uint> shape : shapes){
+ mx_uint val = 1;
+ for(mx_uint i: shape){
+ val *= i;
+ }
+ sizes.push_back(val);
}
+ return sizes;
+ }
- const mx_uint num_input_nodes = input_keys.size();
-
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
+ void CheckMXNetError(std::string loc){
+ const char* err = MXGetLastError();
+ if (err && err[0] != 0) {
+ std::cout << "MXNet error at " << loc << err << std::endl;
+ exit(-1);
}
-
- mx_uint shape_data_size = 0;
- mx_uint input_shape_indptr[input_shapes.size() + 1];
- input_shape_indptr[0] = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- shape_data_size += input_shapes[i].size();
- input_shape_indptr[i+1] = shape_data_size;
+ }
+
+ void init(const std::string &file_prefix,
+              const std::vector<std::string> &network_input_keys,
+              const std::vector<std::vector<mx_uint>> &network_input_shapes){
+
+ CNNLAOptimizer_infoGAN_infoGANConnector_predictor optimizer_creator = CNNLAOptimizer_infoGAN_infoGANConnector_predictor();
+
+ if(optimizer_creator.getContextName() == "gpu"){
+ ctx = Context::gpu();
}
-
- mx_uint input_shape_data[shape_data_size];
- mx_uint index = 0;
- for(mx_uint i = 0; i < input_shapes.size(); i++){
- for(mx_uint j = 0; j < input_shapes[i].size(); j++){
- input_shape_data[index] = input_shapes[i][j];
- index++;
- }
+
+ network_input_sizes = getSizesOfShapes(network_input_shapes);
+
+ ModelLoader model_loader(file_prefix, 0, ctx);
+
+        std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols();
+        std::vector<std::map<std::string, NDArray>> network_param_maps;
+ network_param_maps = model_loader.GetNetworkParamMaps();
+
+ //Init handles
+        std::map<std::string, std::vector<mx_uint>> in_shape_map;
+ for(mx_uint i=0; i < network_input_keys.size(); i++){
+ in_shape_map[network_input_keys[i]] = network_input_shapes[i];
}
-
-        MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
-                     static_cast<const char*>(param_data.GetBuffer()),
-                     static_cast<int>(param_data.GetLength()),
- dev_type,
- dev_id,
- num_input_nodes,
- input_keys_ptr,
- input_shape_indptr,
- input_shape_data,
- &handle);
- assert(handle);
+        std::vector<std::vector<mx_uint>> in_shapes;
+        std::vector<std::vector<mx_uint>> aux_shapes;
+        std::vector<std::vector<mx_uint>> out_shapes;
+ network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes);
+ network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes));
+
}
};
-
#endif // CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py
index dda1486b1cd310b030c00eec90cdecd1833928cc..97e91e467658d9a2033f55e920f2269eee6b2971 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py
@@ -58,5 +58,3 @@ if __name__ == "__main__":
log_period=10,
print_images=True,
)
-
-
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py
index 1254ae9a51cef6d70f8d6f9bc09161cd6ab00ce5..41c442dd044f4b6758382ab1572c39a1b5eb96d7 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_infoGAN_infoGANDiscriminator import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
except OSError:
pass
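+ # networks with episodic replay keep additional "newest" checkpoints (sub-nets, query nets, loss, memory); remove them before scanning for the latest epoch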
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py
index 7316199fc4c5080033b8e146763c5fe797b5d782..de6c36acc4ed22f7f8b5f93c3f21833c19669b6c 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_infoGAN_infoGANQNetwork import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANQNetwork:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANQNetwork:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANQNetwork:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 512,4,4,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 512,4,4,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py
index d6000f4999d620a77e4361db82f3dadfe97b9d9b..0147901ba984e585ccab730ef067fe184f7b84fa 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.Reshape(queries, shape=(0, 0,-1))
+ keys = F.Reshape(keys, shape=(0, 0, -1))
+ values = F.Reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
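+ # sample per-example store decisions and append the selected keys, values and labels to the CPU-side memory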
+ def store_samples(self, data, y, query_network, store_prob, context):
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
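+ # draw random indices into the stored memory and assemble replay batches of values and labels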
+ def sample_memory(self, batch_size):
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
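+ # import the newest exported query network (symbol + params) found in query_net_dir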
+ def get_query_network(self, context):
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -173,5 +589,5 @@ class Net_0(gluon.HybridBlock):
dis_ = F.identity(sigmoid5_1_)
features_ = F.identity(leakyrelu4_)
- return dis_, features_
+ return [[dis_, features_]]
diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py
index a39b202096c47a7ad603f371922f4f34cb66c5cf..df2d89c866211410aed9bc3cf84fe87d439b6b09 100644
--- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py
+++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+ queries = F.Reshape(queries, shape=(0, 0,-1))
+ keys = F.Reshape(keys, shape=(0, 0, -1))
+ values = F.Reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+ if self.use_mask:
+ mask = F.tile(mask, self.num_heads)
+ mask = F.repeat(mask, self.dim_model)
+ mask = F.reshape(mask, shape=(-1, self.dim_model))
+ weights = F.softmax(score, mask, use_length=self.use_mask)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+ q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+ q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -120,5 +536,5 @@ class Net_0(gluon.HybridBlock):
softmax2_ = F.softmax(fc2_, axis=-1)
c1_ = F.identity(softmax2_)
- return c1_
+ return [[c1_]]
diff --git a/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h b/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h
index 8b21dbcdab7bd81d53191332941887949309ffe1..fdb1bb174d04ac75a4db001bab69a27555142fac 100644
--- a/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h
+++ b/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h
@@ -20,8 +20,10 @@ predictions=colvec(classes);
}
void execute(){
vector<float> image_ = CNNTranslator::translate(image);
+
vector<float> predictions_(10);
+
_predictor_0_.predict(image_, predictions_);
predictions = CNNTranslator::translateToCol(predictions_, std::vector<size_t> {10});
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNBufferFile.h b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNBufferFile.h
deleted file mode 100644
index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000
--- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNBufferFile.h
+++ /dev/null
@@ -1,51 +0,0 @@
-#ifndef CNNBUFFERFILE_H
-#define CNNBUFFERFILE_H
-
-#include <stdio.h>
-#include <iostream>
-#include <fstream>
-
-// Read file to buffer
-class BufferFile {
- public :
- std::string file_path_;
- int length_;
- char* buffer_;
-
- explicit BufferFile(std::string file_path)
- :file_path_(file_path) {
-
- std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
- if (!ifs) {
- std::cerr << "Can't open the file. Please check " << file_path << ". \n";
- length_ = 0;
- buffer_ = NULL;
- return;
- }
-
- ifs.seekg(0, std::ios::end);
- length_ = ifs.tellg();
- ifs.seekg(0, std::ios::beg);
- std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
-
- buffer_ = new char[sizeof(char) * length_];
- ifs.read(buffer_, length_);
- ifs.close();
- }
-
- int GetLength() {
- return length_;
- }
- char* GetBuffer() {
- return buffer_;
- }
-
- ~BufferFile() {
- if (buffer_) {
- delete[] buffer_;
- buffer_ = NULL;
- }
- }
-};
-
-#endif // CNNBUFFERFILE_H
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py
index 0f6f1d26d08fc8bbe3ca19df9bc975ab4a16fa0e..000d6bb92d64ba83f6e4606797a8bd4adaf57be6 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py
+++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py
@@ -2,6 +2,8 @@ import mxnet as mx
import logging
import os
import shutil
+import warnings
+import inspect
from CNNNet_cartpole_master_dqn import Net_0
@@ -20,6 +22,10 @@ class CNNCreator_cartpole_master_dqn:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0]*num_episodic_sub_nets
+ mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
@@ -30,22 +36,77 @@ class CNNCreator_cartpole_master_dqn:
except OSError:
pass
+ if hasattr(network, 'episodic_sub_nets'):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
+ except OSError:
+ pass
+
+ for j in range(len(network.episodic_sub_nets)):
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
+ except OSError:
+ pass
+ try:
+ os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
+ except OSError:
+ pass
+
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
- epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
+ epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
+ if hasattr(network, 'episodic_sub_nets'):
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading Replay Memory: " + mem_files[j])
+ mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
- if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
- earliestLastEpoch = lastEpoch
+ if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
+ earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
@@ -56,27 +117,52 @@ class CNNCreator_cartpole_master_dqn:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
+ if hasattr(network, 'episodic_sub_nets'):
+ num_episodic_sub_nets = len(network.episodic_sub_nets)
+ lastMemEpoch = [0] * num_episodic_sub_nets
+ mem_files = [None] * num_episodic_sub_nets
+
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
- if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
+ if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
- if epoch > lastEpoch:
+ if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
+ elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+ relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+ memSubNet = int(relMemPathInfo[0])
+ memEpochStr = relMemPathInfo[1]
+ memEpoch = int(memEpochStr)
+ if memEpoch >= lastMemEpoch[memSubNet-1]:
+ lastMemEpoch[memSubNet-1] = memEpoch
+ mem_files[memSubNet-1] = file
+
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+ if hasattr(network, 'episodic_sub_nets'):
+ assert lastEpoch == lastMemEpoch
+ for j, sub_net in enumerate(network.episodic_sub_nets):
+ if mem_files[j] != None:
+ logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+ mem_layer = \
+ [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+ param[0].startswith("memory")][0][1]
+ mem_layer.load_memory(self._model_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
- self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
- self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+ self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[0].hybridize()
- self.networks[0](mx.nd.zeros((1, 4,), ctx=context))
+ self.networks[0](mx.nd.zeros((1, 4,), ctx=context[0]))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h
new file mode 100644
index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4
--- /dev/null
+++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h
@@ -0,0 +1,141 @@
+#ifndef CNNMODELLOADER
+#define CNNMODELLOADER
+
+#include <mxnet-cpp/MxNetCpp.h>
+
+#include <stdio.h>
+#include <iostream>
+#include <fstream>
+
+using namespace mxnet::cpp;
+
+// Read files to load model symbol and parameters
+class ModelLoader {
+private:
+ Context ctx = Context::cpu();
+ std::vector<Symbol> network_symbol_list;
+ std::vector<std::map<std::string, NDArray>> network_param_map_list;
+
+ std::vector<Symbol> query_symbol_list;
+ std::vector<std::map<std::string, NDArray>> query_param_map_list;
+
+ std::vector<std::map<std::string, NDArray>> replay_memory;
+
+ std::vector<Symbol> loss_symbol;
+ std::vector<std::map<std::string, NDArray>> loss_param_map;
+
+
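+    // Verify that a file exists and log its size in bytes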
+ void checkFile(std::string file_path){
+ std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
+ if (!ifs) {
+ std::cerr << "Can't open the file. Please check " << file_path << ". \n";
+ return;
+ }
+
+ int length_;
+ ifs.seekg(0, std::ios::end);
+ length_ = ifs.tellg();
+ ifs.seekg(0, std::ios::beg);
+ std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
+ ifs.close();
+ }
+
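+    // Load a symbol (.json) and its parameter map (.params) and append them to the given lists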
+    void loadComponent(std::string json_path,
+                       std::string param_path,
+                       std::vector<Symbol> &symbols_list,
+                       std::vector<std::map<std::string, NDArray>> &param_map_list){
+        checkFile(json_path);
+        symbols_list.push_back(Symbol::Load(json_path));
+        checkFile(param_path);
+        std::map<std::string, NDArray> params;
+        NDArray::Load(param_path, 0, &params);
+        param_map_list.push_back(processParamMap(params));
+    }
+
+    std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){
+        std::map<std::string, NDArray> processed_param_map;
+ if(!param_map.empty()){
+ for (const auto &pair : param_map) {
+            std::string name = pair.first.substr(4); //strip the "arg:" or "aux:" type prefix; aux parameters are assumed not to occur here
+ processed_param_map[name] = pair.second.Copy(ctx);
+ }
+ }
+ return processed_param_map;
+ }
+
+public:
+ explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){
+
+ ctx = ctx_param;
+ std::string network_json_path;
+ std::string network_param_path;
+ std::string query_json_path;
+ std::string query_param_path;
+ std::string memory_path;
+ std::string loss_json_path;
+ std::string loss_param_path;
+
+ //Load network
+ if(!num_subnets){
+ network_json_path = file_prefix + "-symbol.json";
+ network_param_path = file_prefix + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ }else{
+ for(int i=0; i < num_subnets; i++){
+ network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json";
+ network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
+ if(i >= 1){
+ query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json";
+ query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params";
+ loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list);
+
+ memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000";
+ checkFile(memory_path);
+
+                std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path);
+ for(auto &mem : mem_map){
+ mem.second = mem.second.Copy(ctx);
+ }
+ replay_memory.push_back(mem_map);
+ }
+ }
+ }
+
+ //Load Loss
+ loss_json_path = file_prefix + "_loss-symbol.json";
+ loss_param_path = file_prefix + "_loss-0000.params";
+ loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
+
+ NDArray::WaitAll();
+ }
+
+    std::vector<Symbol> GetNetworkSymbols() {
+ return network_symbol_list;
+ }
+
+    std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() {
+ return network_param_map_list;
+ }
+
+ Symbol GetLoss() {
+ return loss_symbol[0];
+ }
+
+    std::map<std::string, NDArray> GetLossParamMap() {
+ return loss_param_map[0];
+ }
+
+    std::vector<Symbol> GetQuerySymbols() {
+ return query_symbol_list;
+ }
+
+    std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() {
+ return query_param_map_list;
+ }
+
+    std::vector<std::map<std::string, NDArray>> GetReplayMemory(){
+ return replay_memory;
+ }
+};
+#endif // CNNMODELLOADER
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py
index 98fc737d351baf0c83dc46c3a29cc9955faaef60..acc3c897d0c4784b2120f2b7267c4a19f32dcddc 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py
+++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py
@@ -1,7 +1,10 @@
import mxnet as mx
import numpy as np
import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd
class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
return output, F.swapaxes(state0, 0, 1)
+
+class DotProductSelfAttention(gluon.HybridBlock):
+ def __init__(self,
+ scale_factor,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ use_proj_bias,
+ use_mask,
+ **kwargs):
+ super(DotProductSelfAttention, self).__init__(**kwargs)
+ with self.name_scope():
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.use_proj_bias = use_proj_bias
+ self.use_mask = use_mask
+
+ if dim_keys == -1:
+ self.dim_keys = int(dim_model / self.num_heads)
+ else:
+ self.dim_keys = dim_keys
+ if dim_values == -1:
+ self.dim_values = int(dim_model / self.num_heads)
+ else:
+ self.dim_values = dim_values
+
+ if scale_factor == -1:
+ self.scale_factor = math.sqrt(self.dim_keys)
+ else:
+ self.scale_factor = scale_factor
+
+ self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+ self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+ def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
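+        # Multi-head scaled dot-product attention: project the inputs per head, score the
+        # queries against the keys and use the softmax weights to aggregate the values.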
+        queries = F.Reshape(queries, shape=(0, 0, -1))
+        keys = F.Reshape(keys, shape=(0, 0, -1))
+        values = F.Reshape(values, shape=(0, 0, -1))
+
+ head_queries = self.proj_q(queries)
+ head_keys = self.proj_k(keys)
+ head_values = self.proj_v(values)
+
+ head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+ head_queries = F.transpose(head_queries, axes=(0,2,1,3))
+ head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+ head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+ head_keys = F.transpose(head_keys, axes=(0,2,1,3))
+ head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+ score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+ score = score * self.scale_factor
+        if self.use_mask:
+            mask = F.tile(mask, self.num_heads)
+            mask = F.repeat(mask, self.dim_model)
+            mask = F.reshape(mask, shape=(-1, self.dim_model))
+            weights = F.softmax(score, mask, use_length=self.use_mask)
+        else:
+            weights = F.softmax(score)
+
+ head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+ head_values = F.transpose(head_values, axes=(0,2,1,3))
+ head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+ ret = F.batch_dot(weights, head_values)
+ ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+ ret = F.transpose(ret, axes=(0, 2, 1, 3))
+ ret = F.reshape(ret, shape=(0, 0, -1))
+
+ ret = self.proj_o(ret)
+
+ return ret
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+ super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+ self.use_replay = use_replay
+ self.replay_interval = replay_interval
+ self.replay_batch_size = replay_batch_size
+ self.replay_steps = replay_steps
+ self.replay_gradient_steps = replay_gradient_steps
+ self.num_heads = num_heads
+
+ @abc.abstractmethod
+ def store_samples(self, data, y, query_network, store_prob, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def sample_memory(self, batch_size, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def get_query_network(self, mx_context):
+ pass
+
+ @abc.abstractmethod
+ def save_memory(self, path):
+ pass
+
+ @abc.abstractmethod
+ def load_memory(self, path):
+ pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+ def __init__(self,
+ sub_key_size,
+ query_size,
+ query_act,
+ dist_measure,
+ k,
+ num_heads,
+ values_dim,
+ **kwargs):
+ super(LargeMemory, self).__init__(**kwargs)
+ with self.name_scope():
+ #Memory parameters
+ self.dist_measure = dist_measure
+ self.k = k
+ self.num_heads = num_heads
+ self.query_act = query_act
+ self.query_size = query_size
+ self.num_heads = num_heads
+
+ #Batch norm sub-layer
+ self.batch_norm = gluon.nn.BatchNorm()
+
+ #Memory sub-layer
+ self.sub_key_size = sub_key_size
+ sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+ if values_dim == -1:
+ values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+ else:
+ values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+ self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+ self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+ self.values = self.params.get("values", shape=values_shape, differentiable=True)
+ self.label_memory = nd.array([])
+
+ self.get_query_network()
+
+ def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+ x = self.batch_norm(x)
+
+ x = F.reshape(x, shape=(0, -1))
+
+ q = self.query_network(x)
+
+ q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+ q_split = F.split(q, num_outputs=2, axis=-1)
+
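+        # Product-key lookup: each query half is scored against its own set of sub-keys,
+        # either by L2 distance or by dot product.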
+ if self.dist_measure == "l2":
+ q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
+ sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
+ q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+ q1_dist = F.norm(q1_diff, axis=-1)
+ q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
+ sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
+ q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+ q2_dist = F.norm(q2_diff, axis=-1)
+ else:
+ q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+ q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+ sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ q1 = [q1]
+ q2 = [q2]
+ sub_keys1_resh = [sub_keys1_resh ]
+ sub_keys2_resh = [sub_keys2_resh ]
+
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+ q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+ for h in range(1, self.num_heads):
+                q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+                q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+ i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+ i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+ # Calculate cross product for keys at indices I1 and I2
+
+ # def head_take(data, state):
+ # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+ #
+ # i1 = F.transpose(i1, axes=(1,0,2))
+ # i2 = F.transpose(i2, axes=(1, 0, 2))
+ # st = F.zeros(1)
+ # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+ # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+ # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+ i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+ i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+ sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+ if self.num_heads == 1:
+ i1 = [i1]
+ i2 = [i2]
+ sub_keys1 = [sub_keys1]
+ sub_keys2 = [sub_keys2]
+
+ k1 = F.take(sub_keys1[0], i1[0])
+ k2 = F.take(sub_keys2[0], i2[0])
+ for h in range(1, self.num_heads):
+ k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+ k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+ k1 = F.tile(k1, (1, 1, self.k, 1))
+ k2 = F.repeat(k2, self.k, 2)
+ c_cart = F.concat(k1, k2, dim=3)
+
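+        # Score the full query against the k*k candidate keys and keep the overall top-k values.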
+ q = F.reshape(q, shape=(-1,0), reverse=True)
+ q = F.reshape(q, shape=(0, 1, -1))
+ c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
+ if self.dist_measure == "l2":
+ k_diff = F.broadcast_sub(q, c_cart)
+ k_dist = F.norm(k_diff, axis=-1)
+ else:
+ k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
+ k_dist = F.reshape(k_dist, shape=(0, -1))
+
+ i = F.topk(k_dist, k=self.k, ret_typ="both")
+
+ w = F.softmax(i[0])
+ w = F.reshape(w, shape=(0,1,-1))
+ vi = F.take(values, i[1])
+ aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist)
+
+ ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True)
+ one_vec = F.ones((1, 1, self.num_heads))
+ one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0)
+ ret = F.batch_dot(one_vec, ret)
+ ret = F.reshape(ret, shape=(-1, 0), reverse=True)
+
+ return ret
+
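+    # Lazily build the query network: one Dense layer per entry in query_size,
+    # each with num_heads * size units and the configured activation.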
+ def get_query_network(self):
+ if hasattr(self, 'query_network'):
+ return self.query_network
+ else:
+ self.query_network = gluon.nn.HybridSequential()
+ for size in self.query_size:
+ if self.query_act == "linear":
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False))
+ else:
+ self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
+ return self.query_network
+
+
+#EpisodicMemory layer
+class EpisodicMemory(EpisodicReplayMemoryInterface):
+ def __init__(self,
+ replay_interval,
+ replay_batch_size,
+ replay_steps,
+ replay_gradient_steps,
+ store_prob,
+ max_stored_samples,
+ memory_replacement_strategy,
+ use_replay,
+ query_net_dir,
+ query_net_prefix,
+ query_net_num_inputs,
+ **kwargs):
+ super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
+ with self.name_scope():
+ #Replay parameters
+ self.store_prob = store_prob
+ self.max_stored_samples = max_stored_samples
+ self.memory_replacement_strategy = memory_replacement_strategy
+
+ self.query_net_dir = query_net_dir
+ self.query_net_prefix = query_net_prefix
+ self.query_net_num_inputs = query_net_num_inputs
+
+ #Memory
+ self.key_memory = nd.array([])
+ self.value_memory = nd.array([])
+ self.label_memory = nd.array([])
+
+ def hybrid_forward(self, F, *args):
+ #propagate the input as the rest is only used for replay
+ return [args, []]
+
+ def store_samples(self, data, y, query_network, store_prob, context):
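+        # Store a subset of the current batch: samples are selected with probability store_prob,
+        # their keys are computed by the query network and the memory is kept on the CPU.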
+ if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+ num_pus = len(data)
+ sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+ num_inputs = len(data[0][0])
+ num_outputs = len(y)
+ mx_context = context[0]
+
+ if len(self.key_memory) == 0:
+ self.key_memory = nd.empty(0, ctx=mx.cpu())
+ self.value_memory = []
+ self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+ ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+ max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+ if any(max_inds):
+ to_store_values = []
+ for i in range(num_inputs):
+ tmp_values = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_values, list):
+ tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+ else:
+ tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_values.append(tmp_values)
+
+ to_store_labels = []
+ for i in range(num_outputs):
+ tmp_labels = []
+ for j in range(0, num_pus):
+ if max_inds[j]:
+ if isinstance(tmp_labels, list):
+ tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+ else:
+ tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+ to_store_labels.append(tmp_labels)
+
+ to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+ if self.key_memory.shape[0] == 0:
+ self.key_memory = to_store_keys.as_in_context(mx.cpu())
+ for i in range(num_inputs):
+ self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+ for i in range(num_outputs):
+ self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+ elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+ num_to_store = to_store_keys.shape[0]
+ self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+ else:
+ self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+ for i in range(num_inputs):
+ self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+ for i in range(num_outputs):
+ self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+
+ def sample_memory(self, batch_size):
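+        # Draw replay_steps random mini-batches from the stored memory
+        # (of size replay_batch_size, or batch_size if replay_batch_size is -1).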
+ num_stored_samples = self.key_memory.shape[0]
+ if self.replay_batch_size == -1:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+ else:
+ sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+ num_outputs = len(self.label_memory)
+
+ sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+ sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+ return sample_batches
+
+ def get_query_network(self, context):
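+        # Load the latest exported query network (symbol plus the params file with the
+        # highest epoch number) from query_net_dir as a SymbolBlock.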
+ lastEpoch = 0
+ for file in os.listdir(self.query_net_dir):
+ if self.query_net_prefix in file and ".json" in file:
+ symbolFile = file
+
+ if self.query_net_prefix in file and ".param" in file:
+ epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+ epoch = int(epochStr)
+ if epoch >= lastEpoch:
+ lastEpoch = epoch
+ weightFile = file
+
+ inputNames = []
+ if self.query_net_num_inputs == 1:
+ inputNames.append("data")
+ else:
+ for i in range(self.query_net_num_inputs):
+ inputNames.append("data" + str(i))
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+ net.hybridize()
+ return net
+
+ def save_memory(self, path):
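+        # Persist keys, values and labels as a single NDArray dictionary.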
+ mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
+ mem_dict = {entry[0]:entry[1] for entry in mem_arr}
+ nd.save(path, mem_dict)
+
+ def load_memory(self, path):
+ mem_dict = nd.load(path)
+ self.value_memory = []
+ self.label_memory = []
+ for key in sorted(mem_dict.keys()):
+ if key == "keys":
+ self.key_memory = mem_dict[key]
+ elif key.startswith("values_"):
+ self.value_memory.append(mem_dict[key])
+ elif key.startswith("labels_"):
+ self.label_memory.append(mem_dict[key])
+
+
+#Stream 0
class Net_0(gluon.HybridBlock):
- def __init__(self, data_mean=None, data_std=None, **kwargs):
+ def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_0, self).__init__(**kwargs)
with self.name_scope():
if data_mean:
@@ -121,5 +537,5 @@ class Net_0(gluon.HybridBlock):
fc3_ = self.fc3_(tanh2_)
qvalues_ = F.identity(fc3_)
- return qvalues_
+ return [[qvalues_]]
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h
index 9c02b06612cc9982ba9f91a7b1e29b51b206dcfb..b1dd0cf661c18568e8cff572e5bd64d6fa3f54e6 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h
+++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h
@@ -1,107 +1,149 @@
#ifndef CNNPREDICTOR_CARTPOLE_MASTER_DQN
#define CNNPREDICTOR_CARTPOLE_MASTER_DQN
-#include
+#include
#include
#include
#include
+
+#include
+#include
-#include
-
+using namespace mxnet::cpp;
+
class CNNPredictor_cartpole_master_dqn_0{
public:
- const std::string json_file = "model/cartpole.agent.CartPoleDQN/model_0_newest-symbol.json";
- const std::string param_file = "model/cartpole.agent.CartPoleDQN/model_0_newest-0000.params";
- const std::vector input_keys = {
+ const std::string file_prefix = "model/cartpole.agent.CartPoleDQN/model_0_newest";
+
+ //network
+    const std::vector<std::string> network_input_keys = {
"data"
};
- const std::vector> input_shapes = {{1, 4}};
- const bool use_gpu = false;
-
- PredictorHandle handle;
-
+    const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 4}};
+    std::vector<mx_uint> network_input_sizes;
+    std::vector<std::vector<std::string>> network_arg_names;
+    std::vector<Executor *> network_handles;
+
+
+ //misc
+ Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu
+ int dtype = 0; //use data type (float32=0 float64=1 ...)
+
+
explicit CNNPredictor_cartpole_master_dqn_0(){
- init(json_file, param_file, input_keys, input_shapes, use_gpu);
+ init(file_prefix, network_input_keys, network_input_shapes);
}
~CNNPredictor_cartpole_master_dqn_0(){
- if(handle) MXPredFree(handle);
+ for(Executor * handle : network_handles){
+ delete handle;
+ }
+ MXNotifyShutdown();
}
    void predict(const std::vector<float> &in_state_,
                 std::vector<float> &out_qvalues_){
- MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast(in_state_.size()));
-
- MXPredForward(handle);
- mx_uint output_index;
- mx_uint *shape = 0;
- mx_uint shape_len;
- size_t size;
-
- output_index = 0;
- MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
- size = 1;
- for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
- assert(size == out_qvalues_.size());
- MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size());
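+        // Copy the input into the executor's argument map, run a forward pass and read the q-values back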
+ NDArray input_temp;
+ input_temp = NDArray(network_input_shapes[0], ctx, false, dtype);
+ input_temp.SyncCopyFromCPU(in_state_.data(), network_input_sizes[0]);
+ input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]]));
+ NDArray::WaitAll();
+
+ network_handles[0]->Forward(false);
+ CheckMXNetError("Forward, predict, handle ind. 0");
+
+
+        std::vector<NDArray> output = network_handles.back()->outputs;
+        std::vector<mx_uint> curr_output_shape;
+ size_t curr_output_size;
+ curr_output_shape = output[0].GetShape();
+ curr_output_size = 1;
+ for (mx_uint i : curr_output_shape) curr_output_size *= i;
+        //Fix due to a bug in how the output arrays are initialized when there are multiple outputs
+ assert((curr_output_size == out_qvalues_.size()) || (curr_output_size == out_qvalues_[0]));
+ output[0].SyncCopyToCPU(&out_qvalues_);
+
}
+
+
+
+    Executor* initExecutor(Symbol &sym,
+                           std::map<std::string, NDArray> &param_map,
+ const std::vector