diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3b4055745a70669d3cc3ae582dba115c1fc12a95..c8121cf150898070aff59c0e580496edc7f7ec51 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -19,7 +19,7 @@ git masterJobLinux: integrationMXNetJobLinux: stage: linux - image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest + image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.5 script: - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationMXNetTest @@ -33,7 +33,7 @@ integrationCaffe2JobLinux: integrationGluonJobLinux: stage: linux - image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest + image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.5 script: - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationGluonTest diff --git a/README.md b/README.md index 5e5f7d4655a6ee7937c3335abdecdeab6cd0ca98..4524b85c793d6e69e6f58e983ec59a0100494e88 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,11 @@ See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/E * Deep learning backend: * MXNet * training - generated is Python code. Required is Python 2.7 or higher, Python packages `h5py`, `mxnet` (for training on CPU) or e.g. `mxnet-cu75` for CUDA 7.5 (for training on GPU with CUDA, concrete package should be selected according to CUDA version). Follow [official instructions on MXNet site](https://mxnet.incubator.apache.org/install/index.html?platform=Linux&language=Python&processor=CPU) - * prediction - generated code is C++. Install MXNet using [official instructions on MXNet site](https://mxnet.incubator.apache.org) for C++. + * prediction - generated code is C++. * Caffe2 * training - generated is Python code. Follow [ official instructions on Caffe2 site ](https://caffe2.ai/docs/getting-started.html?platform=ubuntu&configuration=prebuilt) - * See the scripts under Installation for better instructions, as an old caffe vversion is used that needs special considerations. + * See the scripts under Installation for better instructions, as an old caffe version is used that needs special considerations. * Gluon @@ -30,16 +30,14 @@ See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/E * prediction - generated code is C++. ## Installation -The two bash scripts found under [installation scripts](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/tree/tensorflow_group/src/main/resources/installation_scripts) -should build and install all prerequisits for all backends as of 26.09.2019. -Note that the installation may take some time (hours) and you will need some disk space (> 60GB) for all backends. Also enough RAM or a big -enough swapspace is advisable (>10GB) for the installation of the cpp part of tensorflow. This scripts were tested with a completly clean Ubuntu 16.04, +A new bash script for mxnet/gluon can be found under [installation scripts](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/-/tree/master/src/main/resources/installation_scripts); +it changes the installation process for the mxnet cpp part and will now install the full cpp api instead of the reduced c api. 
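After running the script, the Python side of the installation can be sanity-checked with a few lines (a minimal, hedged sketch, not part of the script itself; the version pins mirror the ones in mxnet_gluon_installation_script.sh added by this patch):

```python
# Check the packages pinned by mxnet_gluon_installation_script.sh.
import mxnet as mx
import numpy as np
import h5py  # required by the generated training code

print(mx.__version__)                      # expected: 1.5.1 (from mxnet==1.5.1.post0)
print(np.__version__)                      # expected: 1.16.5
print(mx.nd.ones((2, 2)).asnumpy().sum())  # 4.0 -> basic NDArray ops work
```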
The script will install all dependencies both for python and cpp as of 26.08.2020. +Additionally, a similar Docker script used for the git CI pipeline can be found in the gluon subfolder at [Docker images](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/-/tree/master/src/test/resources/docker). +The other two bash scripts found in the installation_scripts folder are outdated, but may be consulted for installation guidelines for other backends. +Note that the installation may take some time (hours); enough RAM or a big enough swap space (>10GB) is advisable. These scripts were tested with a completely clean Ubuntu 16.04, without system updates installed. Using another Ubuntu version or installing other software (system updates included) might cause problems. If you want to install the backends with CUDA GPU support (only MXNet/Gluon and Tensorflow; the used caffe2 version does not work with GPU support anymore), -you have to install CUDA 10.0(!!), CUDNN and NCCL (Obtainable from the nvidai webpage. You can follow their instructions.) inbetween the two scripts. -Furthermore you will have to change the pip commands for mxnet and tensorflow to the respective commented out parts. -Also docker images for the cpu version of each backend are provided at [Docker images](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/tree/tensorflow_group/src/test/resources/docker), -though some of these might be outdated. +you have to install CUDA 10.0 (mxnet/gluon may also work with newer and possibly older versions), CUDNN and NCCL (obtainable from the NVIDIA webpage). ### HowTo 1. Define an EMADL component containing the architecture of a neural network and save it in a `.emadl` file. For more information on the architecture language please refer to [CNNArchLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNArchLang). 
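The episodicMemorySimple model added by this patch additionally tags its network with a `LayerPathParameter`, which points layers such as `LoadNetwork` and `EpisodicMemory` at a pretrained network on disk (see `Network.emadl` and `episodicMemorySimple.tag` further down). As a hedged sketch of what the generated code does with such a path, the pretrained query network shipped under `src/test/resources/pretrained/episodicMemorySimple` can be loaded directly with Gluon (file names are taken from this patch; the input shape is only illustrative):

```python
# Sketch: load the pretrained query network from this patch's test resources,
# roughly the way the generated get_query_network() code does it.
import mxnet as mx

prefix = "src/test/resources/pretrained/episodicMemorySimple/simple_embedding-"
net = mx.gluon.nn.SymbolBlock.imports(
    prefix + "symbol.json",      # network structure
    ["data"],                    # one input, matching queryNetNumInputs=1
    prefix + "0000.params",      # pretrained weights
    ctx=mx.cpu())
net.hybridize()
out = net(mx.nd.zeros((1, 10)))  # the shipped JSON is a plain copy op
```

The generated `get_query_network()` shown later in this diff does essentially this, after resolving the `tag:simple` directory and picking the `.params` file with the newest epoch.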
An example of NN architecture: diff --git a/pom.xml b/pom.xml index 06b0cb4852aee9d48bb2af511d9781c4f302ddce..3011a371998cef7a22e5fdc9dd54216fb8f4fa28 100644 --- a/pom.xml +++ b/pom.xml @@ -9,19 +9,19 @@ de.monticore.lang.monticar embedded-montiarc-emadl-generator - 0.4.0 + 0.4.1 - 0.2.11-SNAPSHOT - 0.3.10-SNAPSHOT - 0.0.6-SNAPSHOT + 0.2.12-SNAPSHOT + 0.3.12-SNAPSHOT + 0.0.7-SNAPSHOT 0.2.17-SNAPSHOT 0.2.14-SNAPSHOT - 0.2.11-SNAPSHOT + 0.2.12-SNAPSHOT 0.1.0-SNAPSHOT 0.0.19-SNAPSHOT 0.1.6 diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java index dbc5cf31cd36bd05fb59f05c73d329ef4310a350..5a35d39f647c4565a8504c1c7141748effd2704c 100644 --- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java +++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLAbstractSymtab.java @@ -7,6 +7,7 @@ import de.monticore.lang.embeddedmontiarc.LogConfig; import de.monticore.lang.embeddedmontiarc.helper.ConstantPortHelper; import de.monticore.lang.monticar.emadl._symboltable.EMADLLanguage; import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathTagSchema; +import de.monticore.lang.monticar.emadl.tagging.dltag.LayerPathParameterTagSchema; import de.monticore.lang.monticar.enumlang._symboltable.EnumLangLanguage; import de.monticore.lang.monticar.generator.cpp.converter.MathConverter; import de.monticore.lang.monticar.generator.optimization.ThreadingOptimizer; @@ -41,6 +42,7 @@ public class EMADLAbstractSymtab { TagThresholdTagSchema.registerTagTypes(tagging); TagDelayTagSchema.registerTagTypes(tagging); DataPathTagSchema.registerTagTypes(tagging); + LayerPathParameterTagSchema.registerTagTypes(tagging); return tagging; } diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java index 8fecb4471919df7bfc57d6daa61f40701cbcf97b..c3a465ba0cff1b9d3e6a6979326e235675c9c0e0 100644 --- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java +++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java @@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnntrain._symboltable.PreprocessingComponentSy import de.monticore.lang.monticar.emadl._cocos.DataPathCocos; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathSymbol; +import de.monticore.lang.monticar.emadl.tagging.dltag.LayerPathParameterSymbol; import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper; import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP; @@ -394,7 +395,7 @@ public class EMADLGenerator { // TODO: Replace warnings with errors, until then use this method stopGeneratorIfWarning(); - Log.warn("Tagging info for symbol was found, ignoring data_paths.txt: " + dataPath); + Log.warn("Tagging info for DataPath symbol was found, ignoring data_paths.txt: " + dataPath); } else { Path dataPathDefinition = Paths.get(getModelsPath(), "data_paths.txt"); @@ -426,6 +427,37 @@ public class EMADLGenerator { return weightsPath; } + protected HashMap<String, String> getLayerPathParameterTags(TaggingResolver taggingResolver, EMAComponentSymbol component, EMAComponentInstanceSymbol instance){ + List<TagSymbol> instanceTags = new LinkedList<>(); + + boolean isChildComponent = instance.getEnclosingComponent().isPresent(); + + if (isChildComponent) { + // get all instantiated components of parent + List<EMAComponentInstantiationSymbol> instantiationSymbols = (List<EMAComponentInstantiationSymbol>) instance + .getEnclosingComponent().get().getComponentType().getReferencedSymbol().getSubComponents(); + + // filter corresponding instantiation of instance and add tags + instantiationSymbols.stream().filter(e -> e.getName().equals(instance.getName())).findFirst() + .ifPresent(symbol -> instanceTags.addAll(taggingResolver.getTags(symbol, LayerPathParameterSymbol.KIND))); + } + + List<TagSymbol> tags = !instanceTags.isEmpty() ? instanceTags + : (List<TagSymbol>) taggingResolver.getTags(component, LayerPathParameterSymbol.KIND); + + HashMap<String, String> layerPathParameterTags = new HashMap<>(); + if (!tags.isEmpty()) { + for(TagSymbol tag: tags) { + LayerPathParameterSymbol layerPathParameterSymbol = (LayerPathParameterSymbol) tag; + layerPathParameterTags.put(layerPathParameterSymbol.getId(), layerPathParameterSymbol.getPath()); + } + // TODO: Replace warnings with errors, until then use this method + stopGeneratorIfWarning(); + Log.warn("Tagging info for LayerPathParameter symbols was found."); + } + return layerPathParameterTags; + } + protected void generateComponent(List<FileContent> fileContents, Set<EMAComponentInstanceSymbol> allInstances, TaggingResolver taggingResolver, @@ -448,8 +480,10 @@ public class EMADLGenerator { cnnArchGenerator.check(architecture.get()); String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol); String wPath = getWeightsPath(EMAComponentSymbol, componentInstanceSymbol); + HashMap<String, String> layerPathParameterTags = getLayerPathParameterTags(taggingResolver, EMAComponentSymbol, componentInstanceSymbol); architecture.get().setDataPath(dPath); architecture.get().setWeightsPath(wPath); + architecture.get().processLayerPathParameterTags(layerPathParameterTags); architecture.get().setComponentName(EMAComponentSymbol.getFullName()); generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get()); if (processedArchitecture != null) { diff --git a/src/main/resources/installation_scripts/install_after_cuda b/src/main/resources/installation_scripts/legacy_script_all_backends/install_after_cuda similarity index 100% rename from src/main/resources/installation_scripts/install_after_cuda rename to src/main/resources/installation_scripts/legacy_script_all_backends/install_after_cuda diff --git a/src/main/resources/installation_scripts/install_before_cuda b/src/main/resources/installation_scripts/legacy_script_all_backends/install_before_cuda similarity index 100% rename from src/main/resources/installation_scripts/install_before_cuda rename to src/main/resources/installation_scripts/legacy_script_all_backends/install_before_cuda diff --git a/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh b/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh new file mode 100644 index 0000000000000000000000000000000000000000..a13eb17e8baf3a3d4850efc94096267882329931 --- /dev/null +++ b/src/main/resources/installation_scripts/mxnet_gluon_installation_script.sh @@ -0,0 +1,33 @@ +sudo apt-get update -y +sudo apt-get install -y build-essential git openjdk-8-jdk maven ninja-build ccache libopenblas-dev libblas-dev \ + liblapack-dev libopencv-dev libarmadillo-dev cmake python2.7 python-dev \ + python-numpy python3-pip swig unzip libboost-all-dev + +sudo update-alternatives --config java + +pip3 install --user --upgrade "cmake>=3.13.2" + +wget https://bootstrap.pypa.io/get-pip.py +python get-pip.py +pip install --user h5py matplotlib numpy==1.16.5 
mxnet==1.5.1.post0 #We pin mxnet==1.5.1.post0 deliberately: the current standard version (v1.7.0) cannot be run with python2, + #which is the current standard of the EMADL2CPP generator, because the numpy version it needs is no longer supported for python2 + #(python2 itself will no longer be supported). Furthermore, not all tests work with mxnet v1.6.0, and with plain v1.5.1 or v1.6.0 + #you cannot compile against the libmxnet.so needed for the cpp prediction part. This works with 1.5.1.post0 and with versions + #later than 1.6.0 (i.e. 1.7.0): the library was compressed for a while, which was later dropped again by the developers of mxnet. + #Alternatively, you could use python 3.6 instead of 2.7; then you could also use the newest numpy version. + #Note that you then have to set the PYTHON_PATH accordingly, or specify the python path for all applications in their build scripts; + #tests currently only run on the python specified in the PYTHON_PATH. + #If you want to use mxnet with cuda, install e.g. mxnet-cu100 for cuda 10.0 (as long as v1.5.1 is the newest version of that package, + #specifying post0 is not necessary; otherwise there is no guarantee for the numpy dependency, see above). + #Of course you then have to install cuda and cudnn beforehand. + +git clone --recursive https://github.com/apache/incubator-mxnet.git mxnet +cd mxnet && git checkout tags/1.5.0 && git submodule update --recursive --init && cd .. +cd mxnet && mkdir build && cd build && cmake -DUSE_CPP_PACKAGE=1 -DUSE_CUDA=0 -GNinja .. && ninja -v && cd ../.. +cd mxnet && cp -r include/mxnet /usr/include/mxnet && cp -r cpp-package/include/mxnet-cpp /usr/include/ && cp -r 3rdparty/tvm/nnvm/include/nnvm /usr/include/ && cp -r 3rdparty/dmlc-core/include/dmlc /usr/include/ && cd .. + +#you have to have armadillo-9.600.6.zip in your current folder +unzip armadillo-9.600.6.zip -d . +cd armadillo-9.600.6 && cmake . 
&& make && make install + +mkdir -p /root/.config/matplotlib +echo "backend : Agg" > /root/.config/matplotlib/matplotlibrc \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java index 6c988f46bf6bd1d9aa26316b1fbea73de47629f2..e14e7ed20ed5fd93bf855618ada9f884b2b8ec80 100644 --- a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java @@ -42,7 +42,7 @@ public class GenerationTest extends AbstractSymtabTest { "cifar10_cifar10Classifier.cpp", "cifar10_cifar10Classifier.h", "CNNCreator_cifar10_cifar10Classifier_net.py", - "CNNBufferFile.h", + "CNNModelLoader.h", "CNNPredictor_cifar10_cifar10Classifier_net.h", "cifar10_cifar10Classifier_net.h", "CNNTranslator.h", @@ -106,6 +106,13 @@ public class GenerationTest extends AbstractSymtabTest { assertTrue(Log.getFindings().isEmpty()); } + @Test + public void testEpisodicMemorySimpleGeneration() throws IOException, TemplateException { + Log.getFindings().clear(); + String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n"}; + EMADLGeneratorCli.main(args); + } + @Test public void testMultipleInstances() throws IOException, TemplateException { try { @@ -175,7 +182,7 @@ public class GenerationTest extends AbstractSymtabTest { Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code/gluon"), Arrays.asList( - "CNNBufferFile.h", + "CNNModelLoader.h", "CNNNet_mnist_mnistClassifier_net.py", "mnist_mnistClassifier.cpp", "mnist_mnistClassifier.h", @@ -183,7 +190,6 @@ public class GenerationTest extends AbstractSymtabTest { "CNNPredictor_mnist_mnistClassifier_net.h", "CNNDataLoader_mnist_mnistClassifier_net.py", "CNNSupervisedTrainer_mnist_mnistClassifier_net.py", - "mnist_mnistClassifier_net.h", "HelperA.h", "CNNTranslator.h", "mnist_mnistClassifier_calculateClass.h", @@ -215,7 +221,7 @@ public class GenerationTest extends AbstractSymtabTest { "cartpole_master_dqn.h", "cartpole_master_policy.h", "CMakeLists.txt", - "CNNBufferFile.h", + "CNNModelLoader.h", "CNNCreator_cartpole_master_dqn.py", "CNNNet_cartpole_master_dqn.py", "CNNPredictor_cartpole_master_dqn.h", @@ -260,7 +266,7 @@ public class GenerationTest extends AbstractSymtabTest { "mountaincar_master.h", "mountaincar_master_actor.h", "CMakeLists.txt", - "CNNBufferFile.h", + "CNNModelLoader.h", "CNNCreator_mountaincar_master_actor.py", "CNNNet_mountaincar_master_actor.py", "CNNPredictor_mountaincar_master_actor.h", @@ -300,9 +306,6 @@ public class GenerationTest extends AbstractSymtabTest { "CNNTrainer_defaultGAN_defaultGANConnector_predictor.py", "defaultGAN_defaultGANConnector.cpp", "defaultGAN_defaultGANConnector.h", - "defaultGAN_defaultGANConnector_predictor.h", - "defaultGAN_defaultGANConnector.cpp", - "defaultGAN_defaultGANConnector.h", "defaultGAN_defaultGANConnector_predictor.h" ) ); @@ -361,7 +364,7 @@ public class GenerationTest extends AbstractSymtabTest { EMADLGeneratorCli.main(args); assertEquals(Log.getFindings().size(), 1); assertEquals(Log.getFindings().get(0).toString(), - "Tagging info for symbol was found, ignoring data_paths.txt: src/test/resources/models"); + "Tagging info for DataPath symbol was found, ignoring data_paths.txt: src/test/resources/models"); assertTrue(Log.getErrorCount() == 0); } diff --git a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java 
b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java index 1a64af83fbff4ee85b56691f6274218c6f94595e..91afa7469ebb5a9c8e94dd0e473d52cbd61dfede 100644 --- a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java +++ b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationGluonTest.java @@ -70,6 +70,16 @@ public class IntegrationGluonTest extends IntegrationTest { assertTrue(Log.getFindings().isEmpty()); } + @Test + public void testEpisodicMemorySimple() { + Log.getFindings().clear(); + + deleteHashFile(Paths.get("./target/generated-sources-emadl/episodicMemorySimple/episodicMemorySimple.training_hash")); + + String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON"}; + EMADLGeneratorCli.main(args); + } + @Test public void testGluonPreprocessingWithSupervised() { Log.getFindings().clear(); diff --git a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java index 4334b7ea3f16c1a883d4177b6a160369a8d7c2a6..6e18f1455e6edc06ca637ac53c3d40a67f0e4d4d 100644 --- a/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java +++ b/src/test/java/de/monticore/lang/monticar/emadl/IntegrationPythonWrapperTest.java @@ -27,7 +27,7 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest { Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs"), Arrays.asList( "CMakeLists.txt", - "CNNBufferFile.h", + "CNNModelLoader.h", "torcs_agent_torcsAgent.cpp", "torcs_agent_torcsAgent.h", "torcs_agent_torcsAgent_dqn.h", @@ -82,7 +82,7 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest { Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs_td3"), Arrays.asList( "CMakeLists.txt", - "CNNBufferFile.h", + "CNNModelLoader.h", "torcs_agent_torcsAgent.cpp", "torcs_agent_torcsAgent.h", "torcs_agent_torcsAgent_actor.h", diff --git a/src/test/resources/docker/mxnet/Dockerfile b/src/test/resources/docker/mxnet/Dockerfile index e374b94a22f412f2c85a73c88d2f681c91e190e2..7c5945181003075b454ec478a3367be2a97aedc4 100644 --- a/src/test/resources/docker/mxnet/Dockerfile +++ b/src/test/resources/docker/mxnet/Dockerfile @@ -1,19 +1,22 @@ FROM maven:3-jdk-8 -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - git \ - libgtk2.0-dev \ - python-subprocess32 \ - python-tk \ - wget python gcc \ - build-essential cmake \ - liblapack-dev libblas-dev libboost-dev libarmadillo-dev && \ - rm -rf /var/lib/apt/lists/* -RUN git clone https://github.com/apache/incubator-mxnet.git mxnet-source && \ - cd mxnet-source && git checkout tags/1.4.0 && cd .. 
&& \ - cp -r mxnet-source/include/mxnet /usr/include/mxnet && \ - rm -r mxnet-source +RUN apt-get update +RUN apt-get install -y build-essential git ninja-build ccache libopenblas-dev libblas-dev liblapack-dev libopencv-dev libarmadillo-dev cmake python2.7 python-dev python-numpy python3-pip python3-pip swig unzip libboost-all-dev + +RUN pip3 install --user --upgrade "cmake>=3.13.2" + RUN wget https://bootstrap.pypa.io/get-pip.py RUN python get-pip.py -RUN pip install mxnet h5py opencv-python matplotlib +RUN pip install --user h5py matplotlib numpy==1.16.5 mxnet==1.5.1.post0 + +RUN git clone --recursive https://github.com/apache/incubator-mxnet.git mxnet +RUN cd mxnet && git checkout tags/1.5.0 && git submodule update --recursive --init +RUN cd mxnet && mkdir build && cd build && cmake -DUSE_CPP_PACKAGE=1 -DUSE_CUDA=0 -GNinja .. && ninja -v +RUN cd mxnet && cp -r include/mxnet /usr/include/mxnet && cp -r cpp-package/include/mxnet-cpp /usr/include/ && cp -r 3rdparty/tvm/nnvm/include/nnvm /usr/include/ && cp -r 3rdparty/dmlc-core/include/dmlc /usr/include/ + +ADD armadillo-9.600.6.zip /root/armadillo.zip +RUN unzip /root/armadillo.zip -d /root/armadillo +RUN cd /root/armadillo/armadillo-9.600.6 && cmake . && make && make install + +RUN mkdir -p /root/.config/matplotlib +RUN echo "backend : Agg" > /root/.config/matplotlib/matplotlibrc diff --git a/src/test/resources/docker/mxnet/armadillo-9.600.6.zip b/src/test/resources/docker/mxnet/armadillo-9.600.6.zip new file mode 100644 index 0000000000000000000000000000000000000000..e053f3bc1670da20fb0a40eea7ee27a8ede111eb Binary files /dev/null and b/src/test/resources/docker/mxnet/armadillo-9.600.6.zip differ diff --git a/src/test/resources/models/episodicMemorySimple/Network.cnnt b/src/test/resources/models/episodicMemorySimple/Network.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..af61f45850c4c3529e2f3f0aac40bce4ed862389 --- /dev/null +++ b/src/test/resources/models/episodicMemorySimple/Network.cnnt @@ -0,0 +1,13 @@ +/* (c) https://github.com/MontiCore/monticore */ +configuration Network{ + num_epoch:1 + batch_size:5 + normalize:false + context:cpu + load_checkpoint:false + loss:cross_entropy + optimizer:adam{ + learning_rate:0.00003 + weight_decay:0.01 + } +} diff --git a/src/test/resources/models/episodicMemorySimple/Network.emadl b/src/test/resources/models/episodicMemorySimple/Network.emadl new file mode 100644 index 0000000000000000000000000000000000000000..17341b277eaebbd1e4714fe2ec222f8764a66281 --- /dev/null +++ b/src/test/resources/models/episodicMemorySimple/Network.emadl @@ -0,0 +1,16 @@ +/* (c) https://github.com/MontiCore/monticore */ +package episodicMemorySimple; + +component Network{ + ports in Z(0:oo)^{10} data, + out Q(0:1)^{33} softmax; + + implementation CNN { + data -> + EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, replayMemoryStoreProb=0.5, localAdaptionGradientSteps=30, maxStoredSamples=-1, localAdaptionK=32, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) -> + LoadNetwork(networkDir="tag:simple", networkPrefix="simple_embedding-", numInputs=1, outputShape=(1,768)) -> + FullyConnected(units=33) -> + Softmax() -> + softmax; + } +} diff --git a/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag b/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag new file mode 100644 index 0000000000000000000000000000000000000000..0b5b8796a5b2257d0264e17928a0aa55b2ec6828 --- /dev/null +++ 
b/src/test/resources/models/episodicMemorySimple/episodicMemorySimple.tag @@ -0,0 +1,8 @@ +/* (c) https://github.com/MontiCore/monticore */ +package episodicMemorySimple; +conforms to dltag.DataPathTagSchema, dltag.LayerPathParameterTagSchema; + +tags episodic { +tag Network with DataPath = {path = src/test/resources/training_data/episodicMemorySimple, type = HDF5}; +tag Network with LayerPathParameter = {path = src/test/resources/pretrained/episodicMemorySimple, id = simple}; +} diff --git a/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params new file mode 100644 index 0000000000000000000000000000000000000000..3288444a7f261cd83beb3be9c9655c2e5a376e3f Binary files /dev/null and b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-0000.params differ diff --git a/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json new file mode 100644 index 0000000000000000000000000000000000000000..31c823adc6c18b5861cc3a3bd99ae6ad078b69ef --- /dev/null +++ b/src/test/resources/pretrained/episodicMemorySimple/simple_embedding-symbol.json @@ -0,0 +1,18 @@ +{ + "nodes": [ + { + "op": "null", + "name": "data", + "inputs": [] + }, + { + "op": "_copy", + "name": "simpleembedding0_identity0", + "inputs": [[0, 0, 0]] + } + ], + "arg_nodes": [0], + "node_row_ptr": [0, 1, 2], + "heads": [[1, 0, 0]], + "attrs": {"mxnet_version": ["int", 10501]} +} \ No newline at end of file diff --git a/src/test/resources/target_code/CNNBufferFile.h b/src/test/resources/target_code/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- a/src/test/resources/target_code/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include -#include -#include - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... 
"<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/CNNModelLoader.h b/src/test/resources/target_code/CNNModelLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4 --- /dev/null +++ b/src/test/resources/target_code/CNNModelLoader.h @@ -0,0 +1,141 @@ +#ifndef CNNMODELLOADER +#define CNNMODELLOADER + +#include + +#include +#include +#include + +using namespace mxnet::cpp; + +// Read files to load moddel symbol and parameters +class ModelLoader { +private: + Context ctx = Context::cpu(); + std::vector network_symbol_list; + std::vector> network_param_map_list; + + std::vector query_symbol_list; + std::vector> query_param_map_list; + + std::vector> replay_memory; + + std::vector loss_symbol; + std::vector> loss_param_map; + + + void checkFile(std::string file_path){ + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + return; + } + + int length_; + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector &symbols_list, + std::vector> ¶m_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map params; + NDArray::Load(param_path, 0, ¶ms); + param_map_list.push_back(processParamMap(params)); + } + + std::map processParamMap(std::map param_map){ + std::map processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? 
<- need to make sure) + processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector GetQuerySymbols() { + return query_symbol_list; + } + + std::vector> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/CNNBufferFile.h b/src/test/resources/target_code/gluon/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- a/src/test/resources/target_code/gluon/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include -#include -#include - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... 
"<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py index 21ce4267c47a7ef296f301abf76932775e07bbfa..814f8ebaa5b866cf328ff459946f0af3a4571be2 100644 --- a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py +++ b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_mnist_mnistClassifier_net import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_mnist_mnistClassifier_net: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_mnist_mnistClassifier_net: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + 
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_mnist_mnistClassifier_net: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(memEpoch == lastEpoch for memEpoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, 
force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/CNNModelLoader.h b/src/test/resources/target_code/gluon/CNNModelLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4 --- /dev/null +++ b/src/test/resources/target_code/gluon/CNNModelLoader.h @@ -0,0 +1,141 @@ +#ifndef CNNMODELLOADER +#define CNNMODELLOADER + +#include <mxnet-cpp/MxNetCpp.h> + +#include <string> +#include <fstream> +#include <iostream> + +using namespace mxnet::cpp; + +// Read files to load model symbol and parameters +class ModelLoader { +private: + Context ctx = Context::cpu(); + std::vector<Symbol> network_symbol_list; + std::vector<std::map<std::string, NDArray>> network_param_map_list; + + std::vector<Symbol> query_symbol_list; + std::vector<std::map<std::string, NDArray>> query_param_map_list; + + std::vector<std::map<std::string, NDArray>> replay_memory; + + std::vector<Symbol> loss_symbol; + std::vector<std::map<std::string, NDArray>> loss_param_map; + + + void checkFile(std::string file_path){ + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + return; + } + + int length_; + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector<Symbol> &symbols_list, + std::vector<std::map<std::string, NDArray>> &param_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map<std::string, NDArray> params; + NDArray::Load(param_path, 0, &params); + param_map_list.push_back(processParamMap(params)); + } + + std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){ + std::map<std::string, NDArray> processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //strip the first four characters, which encode the parameter type ("arg:" or "aux:"); aux parameters are not expected here, but this should be verified + processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector<Symbol> GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map<std::string, NDArray> GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector<Symbol> GetQuerySymbols() { + return query_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector<std::map<std::string, NDArray>> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py index e376a095669594ec94a489b736580cf0aebb6c26..415132dc1f96139e94a977f7ab3ac87790f66e78 100644 --- a/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py +++ b/src/test/resources/target_code/gluon/CNNNet_mnist_mnistClassifier_net.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == 
-1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(keys, shape=(0, 0, -1)) + values = F.Reshape(values, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = args[0] #mask is expected as an additional input when use_mask is set + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + else: + weights = F.softmax(score) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = 
num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh] + sub_keys2_resh = [sub_keys2_resh] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, 
shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), 
dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, 
mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -146,5 +562,5 @@ class Net_0(gluon.HybridBlock): softmax3_ = F.softmax(fc3_, axis=-1) predictions_ = F.identity(softmax3_) - return predictions_ + return [[predictions_]] diff --git a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h index 75d7c61ecdf6ad99e934c09dbcee574020c2d693..55a9d6320cde26d83702a786447c4bca4e84234c 100644 --- a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h +++ b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h @@ -1,107 +1,149 @@ #ifndef CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET #define CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_mnist_mnistClassifier_net_0{ public: - const std::string json_file = "model/mnist.LeNetNetwork/model_0_newest-symbol.json"; - const std::string param_file = "model/mnist.LeNetNetwork/model_0_newest-0000.params"; - const std::vector input_keys = { + const std::string file_prefix = "model/mnist.LeNetNetwork/model_0_newest"; + + //network + const std::vector network_input_keys = { "data" }; - const std::vector> input_shapes = {{1, 1, 28, 28}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector> network_input_shapes = {{1, 1, 28, 28}}; + std::vector network_input_sizes; + std::vector> network_arg_names; + std::vector network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
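+    //Note: ctx starts on the CPU; init() below switches it to Context::gpu() when the generated optimizer was created for a GPU context.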
+ + explicit CNNPredictor_mnist_mnistClassifier_net_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_mnist_mnistClassifier_net_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector &in_image_, std::vector &out_predictions_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_image_.data(), static_cast(in_image_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - assert(size == out_predictions_.size()); - MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_image_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector output = network_handles.back()->outputs; + std::vector curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in the in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_predictions_.size()) || (curr_output_size == out_predictions_[0])); + output[0].SyncCopyToCPU(&out_predictions_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map ¶m_map, + const std::vector &exec_input_keys, + const std::vector> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string ¶m_file, - const std::vector &input_keys, - const std::vector> &input_shapes, - const bool &use_gpu){ + std::vector param_arrays; + std::vector grad_array; + std::vector grad_reqs; + std::vector aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, ¶m_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map(), + std::map(), aux_map); - int dev_type = use_gpu ? 
2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector getSizesOfShapes(const std::vector> shapes){ + std::vector sizes; + for(std::vector shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector &network_input_keys, + const std::vector> &network_input_shapes){ + + CNNLAOptimizer_mnist_mnistClassifier_net optimizer_creator = CNNLAOptimizer_mnist_mnistClassifier_net(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector network_symbols = model_loader.GetNetworkSymbols(); + std::vector> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast(json_data.GetBuffer()), - static_cast(param_data.GetBuffer()), - static_cast(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector> in_shapes; + std::vector> aux_shapes; + std::vector> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_MNIST_MNISTCLASSIFIER_NET diff --git a/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py index b6f1a59372febc71f4311d93c99c8165570c861c..7403768f7343f1cfe11767ed31de2966d3e85924 100644 --- a/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py +++ b/src/test/resources/target_code/gluon/CNNSupervisedTrainer_mnist_mnistClassifier_net.py @@ -7,7 +7,13 @@ import shutil import pickle import math import sys +import inspect from mxnet import gluon, autograd, nd +try: + import AdamW +except: + pass + class CrossEntropyLoss(gluon.loss.Loss): def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs): @@ -54,7 +60,7 @@ class 
SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss): loss = -(pred * label).sum(axis=self._axis, keepdims=True) # ignore some indices for loss, e.g. tokens in NLP applications for i in self._ignore_indices: - loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label)) + loss = F.broadcast_mul(loss, F.logical_not(F.broadcast_equal(F.argmax(pred, axis=1), F.ones_like(F.argmax(pred, axis=1))*i) * F.broadcast_equal(F.argmax(pred, axis=1), label))) return loss.mean(axis=self._batch_axis, exclude=True) class DiceLoss(gluon.loss.Loss): @@ -277,12 +283,21 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: shuffle_data=False, clip_global_grad_norm=None, preprocessing = False): + num_pus = 1 if context == 'gpu': - mx_context = mx.gpu() + num_pus = mx.context.num_gpus() + if num_pus >= 1: + if num_pus == 1: + mx_context = [mx.gpu(0)] + else: + mx_context = [mx.gpu(i) for i in range(num_pus)] + else: + logging.error("Context argument is '" + context + "'. But no gpu is present in the system.") elif context == 'cpu': - mx_context = mx.cpu() + mx_context = [mx.cpu()] else: logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.") + single_pu_batch_size = int(batch_size/num_pus) if preprocessing: preproc_lib = "CNNPreprocessor_mnist_mnistClassifier_net_executor" @@ -327,7 +342,10 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: if not os.path.isdir(self._net_creator._model_dir_): raise - trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0] + if optimizer == "adamw": + trainers = [mx.gluon.Trainer(network.collect_params(), AdamW.AdamW(**optimizer_params)) for network in self._networks.values() if len(network.collect_params().values()) != 0] + else: + trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0] margin = loss_params['margin'] if 'margin' in loss_params else 1.0 sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True @@ -372,9 +390,16 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: loss_function = LogCoshLoss() else: logging.error("Invalid loss parameter.") + + loss_function.hybridize() + + tic = None + avg_speed = 0 + n = 0 + for epoch in range(begin_epoch, begin_epoch + num_epoch): if shuffle_data: if preprocessing: @@ -389,31 +414,36 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: loss_total = 0 train_iter.reset() for batch_i, batch in enumerate(train_iter): + + with autograd.record(): - labels = [batch.label[i].as_in_context(mx_context) for i in range(1)] + labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False) for i in range(1)] - image_ = batch.data[0].as_in_context(mx_context) + image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False) - predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context) + predictions_ = [mx.nd.zeros((single_pu_batch_size, 10,), ctx=context) for context in mx_context] nd.waitall() - lossList = [] + lossList = [[] for i in range(num_pus)] - predictions_ = self._networks[0](image_) + net_ret = [self._networks[0](image_[i]) for i in range(num_pus)] + predictions_ = [net_ret[i][0][0] for i in range(num_pus)] + 
[lossList[i].append(loss_function(predictions_[i], labels[0][i])) for i in range(num_pus)] - lossList.append(loss_function(predictions_, labels[0])) - - loss = 0 - for element in lossList: - loss = loss + element - loss.backward() + losses = [0]*num_pus + for i in range(num_pus): + for element in lossList[i]: + losses[i] = losses[i] + element - loss_total += loss.sum().asscalar() + for loss in losses: + loss.backward() + loss_total += loss.sum().asscalar() + global_loss_train += loss.sum().asscalar() - global_loss_train += loss.sum().asscalar() train_batches += 1 if clip_global_grad_norm: @@ -426,7 +456,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: for trainer in trainers: trainer.step(batch_size) - + if tic is None: tic = time.time() else: @@ -440,36 +470,39 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: loss_total = 0 logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg)) - + + avg_speed += speed + n += 1 + tic = time.time() global_loss_train /= (train_batches * batch_size) tic = None - if eval_train: train_iter.reset() metric = mx.metric.create(eval_metric, **eval_metric_params) for batch_i, batch in enumerate(train_iter): - labels = [batch.label[i].as_in_context(mx_context) for i in range(1)] - - image_ = batch.data[0].as_in_context(mx_context) + labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)] + image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0] - predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context) + predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0]) nd.waitall() - outputs = [] lossList = [] + outputs = [] attentionList = [] - predictions_ = self._networks[0](image_) + net_ret = self._networks[0](image_) + predictions_ = net_ret[0][0] outputs.append(predictions_) lossList.append(loss_function(predictions_, labels[0])) + if save_attention_image == "True": import matplotlib matplotlib.use('Agg') @@ -510,7 +543,6 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: os.makedirs(target_dir) plt.savefig(target_dir + '/attention_train.png') plt.close() - predictions = [] for output_name in outputs: if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: @@ -518,7 +550,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: else: predictions.append(output_name) - metric.update(preds=predictions, labels=labels) + metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))]) + train_metric_score = metric.get()[1] else: train_metric_score = 0 @@ -529,25 +562,26 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: test_iter.reset() metric = mx.metric.create(eval_metric, **eval_metric_params) for batch_i, batch in enumerate(test_iter): - if True: - labels = [batch.label[i].as_in_context(mx_context) for i in range(1)] - - image_ = batch.data[0].as_in_context(mx_context) + if True: + labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)] + image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0] - predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context) + predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0]) nd.waitall() - outputs = [] lossList = [] + outputs = [] attentionList = [] - predictions_ = self._networks[0](image_) + net_ret = self._networks[0](image_) + predictions_ = net_ret[0][0] outputs.append(predictions_) 
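+                    # note: evaluation runs on a single device; the [0] index above keeps only the first shard returned by split_and_load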
lossList.append(loss_function(predictions_, labels[0])) + if save_attention_image == "True": if not eval_train: import matplotlib @@ -594,26 +628,40 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net: loss = loss + element global_loss_test += loss.sum().asscalar() + test_batches += 1 predictions = [] for output_name in outputs: predictions.append(output_name) - metric.update(preds=predictions, labels=labels) + metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))]) + test_metric_score = metric.get()[1] - global_loss_test /= (test_batches * batch_size) + global_loss_test /= (test_batches * single_pu_batch_size) logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test)) - if (epoch - begin_epoch) % checkpoint_period == 0: + if (epoch+1) % checkpoint_period == 0: for i, network in self._networks.items(): network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params') + if hasattr(network, 'episodic_sub_nets'): + for j, net in enumerate(network.episodic_sub_nets): + episodic_layers[i][j].save_memory(self.parameter_path(i) + "_episodic_memory_sub_net_" + str(j + 1) + "-" + str(epoch).zfill(4)) for i, network in self._networks.items(): - network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params') + network.save_parameters(self.parameter_path(i) + '-' + str((num_epoch-1) + begin_epoch).zfill(4) + '.params') network.export(self.parameter_path(i) + '_newest', epoch=0) + + if hasattr(network, 'episodic_sub_nets'): + network.episodicsubnet0_.export(self.parameter_path(i) + '_newest_episodic_sub_net_' + str(0), epoch=0) + for j, net in enumerate(network.episodic_sub_nets): + net.export(self.parameter_path(i) + '_newest_episodic_sub_net_' + str(j+1), epoch=0) + episodic_query_networks[i][j].export(self.parameter_path(i) + '_newest_episodic_query_net_' + str(j+1), epoch=0) + episodic_layers[i][j].save_memory(self.parameter_path(i) + "_episodic_memory_sub_net_" + str(j + 1) + "-" + str((num_epoch - 1) + begin_epoch).zfill(4)) + episodic_layers[i][j].save_memory(self.parameter_path(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + loss_function.export(self.parameter_path(i) + '_newest_loss', epoch=0) def parameter_path(self, index): return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index) diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py index 9e5e33e65f52240401b74b96bd273e106bab9efb..fafc619d8d2a31e4f2145a10ff9b824cb35bcc20 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNCreator_defaultGAN_defaultGANConnector_predictor.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_defaultGAN_defaultGANConnector_predictor import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = 
[None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_defaultGAN_defaultGANConnector_predictor: for i, network in 
self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(mem_epoch == lastEpoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 100,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 100,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py index 42fa27f0786c314fa7ac0df5c545cc92027fc71a..9cfe8224c5bbd0a2f9db82adbd4fbd8bfdd4a921 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py @@ -184,16 +184,16 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: del discriminator_optimizer_params['learning_rate_decay'] if normalize: - self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std) + self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std) else: - 
self._net_creator_dis.construct(mx_context) + self._net_creator_dis.construct([mx_context]) - self._net_creator_gen.construct(mx_context) + self._net_creator_gen.construct([mx_context]) if self.use_qnet: - self._net_creator_qnet.construct(mx_context) + self._net_creator_qnet.construct([mx_context]) if load_checkpoint: - self._net_creator_qnet.load(mx_context) + self._net_creator_qnet.load([mx_context]) else: if os.path.isdir(self._net_creator_qnet._model_dir_): shutil.rmtree(self._net_creator_qnet._model_dir_) @@ -206,8 +206,8 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: begin_epoch = 0 if load_checkpoint: - begin_epoch = self._net_creator_dis.load(mx_context) - self._net_creator_gen.load(mx_context) + begin_epoch = self._net_creator_dis.load([mx_context]) + self._net_creator_gen.load([mx_context]) else: if os.path.isdir(self._net_creator_dis._model_dir_): shutil.rmtree(self._net_creator_dis._model_dir_) @@ -351,9 +351,9 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: gen_input, exp_qnet_output = create_generator_input(batch) with autograd.record(): - fake_data = gen_net(*gen_input) + fake_data = gen_net(*gen_input)[0][0] fake_data.detach() - discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input) + discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_fake_dis, _ = discriminated_fake_dis @@ -361,7 +361,7 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context) loss_resultF = dis_loss(discriminated_fake_dis, fake_labels) - discriminated_real_dis = dis_net(real_data, *dis_conditional_input) + discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_real_dis, _ = discriminated_real_dis loss_resultR = dis_loss(discriminated_real_dis, real_labels) @@ -372,8 +372,8 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: if batch_i % k_value == 0: with autograd.record(): - fake_data = gen_net(*gen_input) - discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input) + fake_data = gen_net(*gen_input)[0][0] + discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_fake_gen, features = discriminated_fake_gen loss_resultG = dis_loss(discriminated_fake_gen, real_labels) @@ -381,7 +381,7 @@ class CNNGanTrainer_defaultGAN_defaultGANConnector_predictor: condition = batch.data[traindata_to_index[generator_target_name + "_"]] loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition) if self.use_qnet: - qnet_discriminated = [q_net(features)] + qnet_discriminated = [q_net(features)[0][0]] for i, qnet_out in enumerate(qnet_discriminated): loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i]) loss_resultG.backward() diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py index 32d8a55ff8131a54a11955dfd652338f43d97037..4333be26146ebdadfa61306efc8bf610778201cb 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNNet_defaultGAN_defaultGANConnector_predictor.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os 
+import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(keys, shape=(0, 0, -1)) + values = F.Reshape(values, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def 
sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = 
F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = 
nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames 
= [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -177,5 +593,5 @@ class Net_0(gluon.HybridBlock): tanh5_ = self.tanh5_(upconvolution5_) data_ = F.identity(tanh5_) - return data_ + return [[data_]] diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h index a99f9c1b8c799648903e501c686773b03aea1dbf..d7ad1aab8082d22192f169560e8ff50016ab9704 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNPredictor_defaultGAN_defaultGANConnector_predictor.h @@ -1,108 +1,149 @@ #ifndef CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR #define CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_defaultGAN_defaultGANConnector_predictor_0{ public: - const std::string json_file = "model/defaultGAN.DefaultGANGenerator/model_0_newest-symbol.json"; - const std::string param_file = "model/defaultGAN.DefaultGANGenerator/model_0_newest-0000.params"; - const std::vector input_keys = { + const std::string file_prefix = "model/defaultGAN.DefaultGANGenerator/model_0_newest"; + + //network + const std::vector network_input_keys = { "data" }; - const std::vector> input_shapes = {{1, 100}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector> network_input_shapes = {{1, 100}}; + std::vector network_input_sizes; + std::vector> network_arg_names; + std::vector network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
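+    //Note: predict() below copies the input into a device-side NDArray, runs Forward on the first executor handle and syncs the outputs of the last handle back to host memory.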
+ + explicit CNNPredictor_defaultGAN_defaultGANConnector_predictor_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_defaultGAN_defaultGANConnector_predictor_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector &in_noise_, std::vector &out_data_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_noise_.data(), static_cast(in_noise_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - - assert(size == out_data_.size()); - MXPredGetOutput(handle, output_index, &(out_data_[0]), out_data_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_noise_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector output = network_handles.back()->outputs; + std::vector curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in the in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_data_.size()) || (curr_output_size == out_data_[0])); + output[0].SyncCopyToCPU(&out_data_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map ¶m_map, + const std::vector &exec_input_keys, + const std::vector> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string ¶m_file, - const std::vector &input_keys, - const std::vector> &input_shapes, - const bool &use_gpu){ + std::vector param_arrays; + std::vector grad_array; + std::vector grad_reqs; + std::vector aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, ¶m_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map(), + std::map(), aux_map); - int dev_type = use_gpu ? 
2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector getSizesOfShapes(const std::vector> shapes){ + std::vector sizes; + for(std::vector shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector &network_input_keys, + const std::vector> &network_input_shapes){ + + CNNLAOptimizer_defaultGAN_defaultGANConnector_predictor optimizer_creator = CNNLAOptimizer_defaultGAN_defaultGANConnector_predictor(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector network_symbols = model_loader.GetNetworkSymbols(); + std::vector> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast(json_data.GetBuffer()), - static_cast(param_data.GetBuffer()), - static_cast(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector> in_shapes; + std::vector> aux_shapes; + std::vector> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_DEFAULTGAN_DEFAULTGANCONNECTOR_PREDICTOR diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py index 82954430835cdbe977415d4c694240a29e7e92c0..7119c688509d054a73c38bfad34796854df9bd61 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/CNNTrainer_defaultGAN_defaultGANConnector_predictor.py @@ -55,5 +55,3 @@ if __name__ == "__main__": log_period=10, print_images=True, ) - - diff --git a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py 
b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py index 0e7501600ef4a5ae20cb0af65aa78f5090a93745..1920ecee6a5e558930fb6e2e3259ceed2ed98a2b 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_defaultGAN_defaultGANDiscriminator import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + 
mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(mem_epoch == lastEpoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 1,64,64,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 1,64,64,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git 
a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py index 027ed194713072192cd56ece3f55216888ea8c94..6c220b51e34c969cf20defd9fee40498055e854a 100644 --- a/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py +++ b/src/test/resources/target_code/gluon/ganModel/defaultGAN/gan/CNNNet_defaultGAN_defaultGANDiscriminator.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class 
EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], 
transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = 
nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = 
[[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -172,5 +588,5 @@ class Net_0(gluon.HybridBlock): sigmoid5_ = self.sigmoid5_(conv5_) dis_ = F.identity(sigmoid5_) - return dis_ + return [[dis_]] diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py index ba115b47978ef21143be826f7ab0b6b639075682..c84c36bf315d49197e604ef09981a6d78073a932 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNCreator_infoGAN_infoGANConnector_predictor.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_infoGAN_infoGANConnector_predictor import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANConnector_predictor: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANConnector_predictor: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + 
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANConnector_predictor: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and 
self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastMemEpoch[j] == lastEpoch for j in range(num_episodic_sub_nets)) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 62,), ctx=context), mx.nd.zeros((1, 10,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 62,), ctx=context[0]), mx.nd.zeros((1, 10,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py index b49f11dc6de390d416e705a14d481baf90e49a51..426cc112a4ed84ab3f4158d47bbcbfdfc276e82e 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNGanTrainer_infoGAN_infoGANConnector_predictor.py @@ -184,16 +184,16 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: del discriminator_optimizer_params['learning_rate_decay'] if normalize: - self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std) + self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std) else: - self._net_creator_dis.construct(mx_context) + self._net_creator_dis.construct([mx_context]) - self._net_creator_gen.construct(mx_context) + self._net_creator_gen.construct([mx_context]) if self.use_qnet: - self._net_creator_qnet.construct(mx_context) + self._net_creator_qnet.construct([mx_context]) if load_checkpoint: - self._net_creator_qnet.load(mx_context) + self._net_creator_qnet.load([mx_context]) else: if 
os.path.isdir(self._net_creator_qnet._model_dir_): shutil.rmtree(self._net_creator_qnet._model_dir_) @@ -206,8 +206,8 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: begin_epoch = 0 if load_checkpoint: - begin_epoch = self._net_creator_dis.load(mx_context) - self._net_creator_gen.load(mx_context) + begin_epoch = self._net_creator_dis.load([mx_context]) + self._net_creator_gen.load([mx_context]) else: if os.path.isdir(self._net_creator_dis._model_dir_): shutil.rmtree(self._net_creator_dis._model_dir_) @@ -351,9 +351,9 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: gen_input, exp_qnet_output = create_generator_input(batch) with autograd.record(): - fake_data = gen_net(*gen_input) + fake_data = gen_net(*gen_input)[0][0] fake_data.detach() - discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input) + discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_fake_dis, _ = discriminated_fake_dis @@ -361,7 +361,7 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context) loss_resultF = dis_loss(discriminated_fake_dis, fake_labels) - discriminated_real_dis = dis_net(real_data, *dis_conditional_input) + discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_real_dis, _ = discriminated_real_dis loss_resultR = dis_loss(discriminated_real_dis, real_labels) @@ -372,8 +372,8 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: if batch_i % k_value == 0: with autograd.record(): - fake_data = gen_net(*gen_input) - discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input) + fake_data = gen_net(*gen_input)[0][0] + discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0] if self.use_qnet: discriminated_fake_gen, features = discriminated_fake_gen loss_resultG = dis_loss(discriminated_fake_gen, real_labels) @@ -381,7 +381,7 @@ class CNNGanTrainer_infoGAN_infoGANConnector_predictor: condition = batch.data[traindata_to_index[generator_target_name + "_"]] loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition) if self.use_qnet: - qnet_discriminated = [q_net(features)] + qnet_discriminated = [q_net(features)[0][0]] for i, qnet_out in enumerate(qnet_discriminated): loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i]) loss_resultG.backward() diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py index 4aeb33b350c6d3e898cabddec1499b3ad07a7f10..ae8d9f10d1a696e3f65391c6b0100078753ab836 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNNet_infoGAN_infoGANConnector_predictor.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, 
self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with 
self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), 
dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = 
nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + 
[("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -186,5 +602,5 @@ class Net_0(gluon.HybridBlock): tanh7_ = self.tanh7_(upconvolution7_) data_ = F.identity(tanh7_) - return data_ + return [[data_]] diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h index 76240ebe330a66765750cec5cbfa1623022c779d..e0968814cfe44e76d9bbc2669f1ba1827a975443 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNPredictor_infoGAN_infoGANConnector_predictor.h @@ -1,109 +1,152 @@ #ifndef CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR #define CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_infoGAN_infoGANConnector_predictor_0{ public: - const std::string json_file = "model/infoGAN.InfoGANGenerator/model_0_newest-symbol.json"; - const std::string param_file = "model/infoGAN.InfoGANGenerator/model_0_newest-0000.params"; - const std::vector<std::string> input_keys = { + const std::string file_prefix = "model/infoGAN.InfoGANGenerator/model_0_newest"; + + //network + const std::vector<std::string> network_input_keys = { "data0", "data1" }; - const std::vector<std::vector<mx_uint>> input_shapes = {{1, 62}, {1, 10}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 62}, {1, 10}}; + std::vector<mx_uint> network_input_sizes; + std::vector<std::vector<std::string>> network_arg_names; + std::vector<Executor *> network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) + + explicit CNNPredictor_infoGAN_infoGANConnector_predictor_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_infoGAN_infoGANConnector_predictor_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector<float> &in_noise_, const std::vector<float> &in_c1_, std::vector<float> &out_data_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_noise_.data(), static_cast<mx_uint>(in_noise_.size())); - MXPredSetInput(handle, input_keys[1].c_str(), in_c1_.data(), static_cast<mx_uint>(in_c1_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - - assert(size == out_data_.size()); - MXPredGetOutput(handle, output_index, &(out_data_[0]), out_data_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_noise_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + input_temp = NDArray(network_input_shapes[1], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_c1_.data(), network_input_sizes[1]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[1]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector<NDArray> output = network_handles.back()->outputs; + std::vector<mx_uint> curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_data_.size()) || (curr_output_size == out_data_[0])); + output[0].SyncCopyToCPU(&out_data_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map<std::string, NDArray> &param_map, + const std::vector<std::string> &exec_input_keys, + const std::vector<std::vector<mx_uint>> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string &param_file, - const std::vector<std::string> &input_keys, - const std::vector<std::vector<mx_uint>> &input_shapes, - const bool &use_gpu){ + std::vector<NDArray> param_arrays; + std::vector<NDArray> grad_array; + std::vector<OpReqType> grad_reqs; + std::vector<NDArray> aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map<std::string, NDArray>(), + std::map<std::string, OpReqType>(), aux_map); - int dev_type = use_gpu ? 2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){ + std::vector<mx_uint> sizes; + for(std::vector<mx_uint> shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << ": " << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector<std::string> &network_input_keys, + const std::vector<std::vector<mx_uint>> &network_input_shapes){ + + CNNLAOptimizer_infoGAN_infoGANConnector_predictor optimizer_creator = CNNLAOptimizer_infoGAN_infoGANConnector_predictor(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols(); + std::vector<std::map<std::string, NDArray>> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map<std::string, std::vector<mx_uint>> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast<const char*>(json_data.GetBuffer()), - static_cast<const char*>(param_data.GetBuffer()), - static_cast<int>(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector<std::vector<mx_uint>> in_shapes; + std::vector<std::vector<mx_uint>> aux_shapes; + std::vector<std::vector<mx_uint>> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_INFOGAN_INFOGANCONNECTOR_PREDICTOR diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py index dda1486b1cd310b030c00eec90cdecd1833928cc..97e91e467658d9a2033f55e920f2269eee6b2971 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/CNNTrainer_infoGAN_infoGANConnector_predictor.py @@ -58,5 +58,3 @@ if __name__ == "__main__": log_period=10, print_images=True, ) - - diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py 
b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py index 1254ae9a51cef6d70f8d6f9bc09161cd6ab00ce5..41c442dd044f4b6758382ab1572c39a1b5eb96d7 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANDiscriminator.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_infoGAN_infoGANDiscriminator import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANDiscriminator: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANDiscriminator: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file 
is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANDiscriminator: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastMemEpoch[j] == lastEpoch for j in range(num_episodic_sub_nets)) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 1,28,28,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git 
a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py index 7316199fc4c5080033b8e146763c5fe797b5d782..de6c36acc4ed22f7f8b5f93c3f21833c19669b6c 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNCreator_infoGAN_infoGANQNetwork.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_infoGAN_infoGANQNetwork import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_infoGAN_infoGANQNetwork: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_infoGAN_infoGANQNetwork: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + 
lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] is not None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_infoGAN_infoGANQNetwork: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] is not None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 512,4,4,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 512,4,4,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git
a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py index d6000f4999d620a77e4361db82f3dadfe97b9d9b..0147901ba984e585ccab730ef067fe184f7b84fa 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANDiscriminator.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0, -1)) + keys = F.Reshape(keys, shape=(0, 0, -1)) + values = F.Reshape(values, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + # a length mask is assumed to be passed as the first extra positional argument + mask = args[0] + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + else: + weights = F.softmax(score) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): +
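# Abstract base for blocks that keep an episodic replay memory: implementations + # store incoming samples and replay them every replay_interval batches. +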
__metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh] + sub_keys2_resh = [sub_keys2_resh] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k,
ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + 
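# memory buffers start empty on the CPU; store_samples resets value_memory and + # label_memory to per-input/per-output lists on its first call +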
self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for 
ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -173,5 +589,5 @@ class Net_0(gluon.HybridBlock): dis_ = F.identity(sigmoid5_1_) features_ = F.identity(leakyrelu4_) - return dis_, features_ + return [[dis_, features_]] diff --git a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py index a39b202096c47a7ad603f371922f4f34cb66c5cf..df2d89c866211410aed9bc3cf84fe87d439b6b09 100644 --- a/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py +++ b/src/test/resources/target_code/gluon/ganModel/infoGAN/gan/CNNNet_infoGAN_infoGANQNetwork.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + 
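# scale_factor == -1 selects the default scaling sqrt(dim_keys) applied to the attention scores +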
if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0, -1)) + keys = F.Reshape(keys, shape=(0, 0, -1)) + values = F.Reshape(values, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + # a length mask is assumed to be passed as the first extra positional argument + mask = args[0] + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + else: + weights = F.softmax(score) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] /
2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh] + sub_keys2_resh = [sub_keys2_resh] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist =
F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + 
tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = 
mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -120,5 +536,5 @@ class Net_0(gluon.HybridBlock): softmax2_ = F.softmax(fc2_, axis=-1) c1_ = F.identity(softmax2_) - return c1_ + return [[c1_]] diff --git a/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h b/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h index 8b21dbcdab7bd81d53191332941887949309ffe1..fdb1bb174d04ac75a4db001bab69a27555142fac 100644 --- a/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h +++ b/src/test/resources/target_code/gluon/mnist_mnistClassifier_net.h @@ -20,8 +20,10 @@ predictions=colvec(classes); } void execute(){ vector<float> image_ = CNNTranslator::translate(image); + vector<float> predictions_(10); + _predictor_0_.predict(image_, predictions_); predictions = CNNTranslator::translateToCol(predictions_, std::vector<size_t> {10}); diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include <string> -#include <fstream> -#include <iostream> - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... 
"<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py index 0f6f1d26d08fc8bbe3ca19df9bc975ab4a16fa0e..000d6bb92d64ba83f6e4606797a8bd4adaf57be6 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNCreator_cartpole_master_dqn.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_cartpole_master_dqn import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_cartpole_master_dqn: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_cartpole_master_dqn: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and 
self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] is not None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_cartpole_master_dqn: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] is not None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._weights_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") +
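# silence initializer warnings (e.g. if parameters were already initialized) +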
self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 4,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 4,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4 --- /dev/null +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNModelLoader.h @@ -0,0 +1,141 @@ +#ifndef CNNMODELLOADER +#define CNNMODELLOADER + +#include <mxnet-cpp/MxNetCpp.h> + +#include <string> +#include <fstream> +#include <iostream> + +using namespace mxnet::cpp; + +// Read files to load model symbol and parameters +class ModelLoader { +private: + Context ctx = Context::cpu(); + std::vector<Symbol> network_symbol_list; + std::vector<std::map<std::string, NDArray>> network_param_map_list; + + std::vector<Symbol> query_symbol_list; + std::vector<std::map<std::string, NDArray>> query_param_map_list; + + std::vector<std::map<std::string, NDArray>> replay_memory; + + std::vector<Symbol> loss_symbol; + std::vector<std::map<std::string, NDArray>> loss_param_map; + + + void checkFile(std::string file_path){ + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + return; + } + + int length_; + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector<Symbol> &symbols_list, + std::vector<std::map<std::string, NDArray>> &param_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map<std::string, NDArray> params; + NDArray::Load(param_path, 0, &params); + param_map_list.push_back(processParamMap(params)); + } + + std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){ + std::map<std::string, NDArray> processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //strip the type prefix ("arg:" or "aux:"); aux parameters are not expected here
+ processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector<Symbol> GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map<std::string, NDArray> GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector<Symbol> GetQuerySymbols() { + return query_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector<std::map<std::string, NDArray>> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py index 98fc737d351baf0c83dc46c3a29cc9955faaef60..acc3c897d0c4784b2120f2b7267c4a19f32dcddc 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNNet_cartpole_master_dqn.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model =
dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0, -1)) + keys = F.Reshape(keys, shape=(0, 0, -1)) + values = F.Reshape(values, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + # a length mask is assumed to be passed as the first extra positional argument + mask = args[0] + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + else: + weights = F.softmax(score) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+        super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+        self.use_replay = use_replay
+        self.replay_interval = replay_interval
+        self.replay_batch_size = replay_batch_size
+        self.replay_steps = replay_steps
+        self.replay_gradient_steps = replay_gradient_steps
+        self.num_heads = num_heads
+
+    @abc.abstractmethod
+    def store_samples(self, data, y, query_network, store_prob, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def sample_memory(self, batch_size, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def get_query_network(self, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def save_memory(self, path):
+        pass
+
+    @abc.abstractmethod
+    def load_memory(self, path):
+        pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+    def __init__(self,
+                 sub_key_size,
+                 query_size,
+                 query_act,
+                 dist_measure,
+                 k,
+                 num_heads,
+                 values_dim,
+                 **kwargs):
+        super(LargeMemory, self).__init__(**kwargs)
+        with self.name_scope():
+            #Memory parameters
+            self.dist_measure = dist_measure
+            self.k = k
+            self.num_heads = num_heads
+            self.query_act = query_act
+            self.query_size = query_size
+
+            #Batch norm sub-layer
+            self.batch_norm = gluon.nn.BatchNorm()
+
+            #Memory sub-layer
+            self.sub_key_size = sub_key_size
+            sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+            if values_dim == -1:
+                values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+            else:
+                values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+            self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+            self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+            self.values = self.params.get("values", shape=values_shape, differentiable=True)
+            self.label_memory = nd.array([])
+
+            self.get_query_network()
+
+    def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+        x = self.batch_norm(x)
+
+        x = F.reshape(x, shape=(0, -1))
+
+        q = self.query_network(x)
+
+        q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+        q_split = F.split(q, num_outputs=2, axis=-1)
+
+        if self.dist_measure == "l2":
+            q_split_resh = F.reshape(q_split[0], shape=(0, 0, 1, -1))
+            sub_keys1_resh = F.reshape(sub_keys1, shape=(1, 0, 0, -1), reverse=True)
+            q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+            q1_dist = F.norm(q1_diff, axis=-1)
+            q_split_resh = F.reshape(q_split[1], shape=(0, 0, 1, -1))
+            sub_keys2_resh = F.reshape(sub_keys2, shape=(1, 0, 0, -1), reverse=True)
+            q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+            q2_dist = F.norm(q2_diff, axis=-1)
+        else:
+            q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+            q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+            sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+            sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+            if self.num_heads == 1:
+                q1 = [q1]
+                q2 = [q2]
+                sub_keys1_resh = [sub_keys1_resh]
+                sub_keys2_resh = [sub_keys2_resh]
+
+            q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+            q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+            for h in range(1, self.num_heads):
+                # per-head distances: head h of the query against sub-key set h
+                q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+                q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+        i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+        i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+        # Calculate cross product for keys at indices I1 and I2
+
+        # def head_take(data, state):
+        #    return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+        #
+        # i1 = F.transpose(i1, axes=(1,0,2))
+        # i2 = F.transpose(i2, axes=(1, 0, 2))
+        # st = F.zeros(1)
+        # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
+        # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
+        # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
+        i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
+        i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
+        sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+        sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+        if self.num_heads == 1:
+            i1 = [i1]
+            i2 = [i2]
+            sub_keys1 = [sub_keys1]
+            sub_keys2 = [sub_keys2]
+
+        k1 = F.take(sub_keys1[0], i1[0])
+        k2 = F.take(sub_keys2[0], i2[0])
+        for h in range(1, self.num_heads):
+            k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1)
+            k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1)
+
+        k1 = F.tile(k1, (1, 1, self.k, 1))
+        k2 
= F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = 
nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in 
enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -121,5 +537,5 @@ class Net_0(gluon.HybridBlock): fc3_ = self.fc3_(tanh2_) qvalues_ = F.identity(fc3_) - return qvalues_ + return [[qvalues_]] diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h index 9c02b06612cc9982ba9f91a7b1e29b51b206dcfb..b1dd0cf661c18568e8cff572e5bd64d6fa3f54e6 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNPredictor_cartpole_master_dqn.h @@ -1,107 +1,149 @@ #ifndef CNNPREDICTOR_CARTPOLE_MASTER_DQN #define CNNPREDICTOR_CARTPOLE_MASTER_DQN -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_cartpole_master_dqn_0{ public: - const std::string json_file = "model/cartpole.agent.CartPoleDQN/model_0_newest-symbol.json"; - const std::string param_file = "model/cartpole.agent.CartPoleDQN/model_0_newest-0000.params"; - const std::vector input_keys = { + const std::string file_prefix = "model/cartpole.agent.CartPoleDQN/model_0_newest"; + + //network + const std::vector network_input_keys = { "data" }; - const std::vector> input_shapes = {{1, 4}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector> network_input_shapes = {{1, 4}}; + std::vector network_input_sizes; + std::vector> network_arg_names; + std::vector network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
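+
+    //Editorial sketch of the call flow below (not generated code): init() loads
+    //"<file_prefix>-symbol.json" / "<file_prefix>-0000.params" through ModelLoader,
+    //infers the input shapes and binds one Executor per network. predict() then
+    //  1. copies the input vector into an NDArray (SyncCopyFromCPU),
+    //  2. writes it into the executor's "data" argument and runs Forward(false),
+    //  3. copies outputs[0] back into the output vector (SyncCopyToCPU).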
+ + explicit CNNPredictor_cartpole_master_dqn_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_cartpole_master_dqn_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector &in_state_, std::vector &out_qvalues_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast(in_state_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - assert(size == out_qvalues_.size()); - MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_state_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector output = network_handles.back()->outputs; + std::vector curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in the in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_qvalues_.size()) || (curr_output_size == out_qvalues_[0])); + output[0].SyncCopyToCPU(&out_qvalues_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map ¶m_map, + const std::vector &exec_input_keys, + const std::vector> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string ¶m_file, - const std::vector &input_keys, - const std::vector> &input_shapes, - const bool &use_gpu){ + std::vector param_arrays; + std::vector grad_array; + std::vector grad_reqs; + std::vector aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, ¶m_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map(), + std::map(), aux_map); - int dev_type = use_gpu ? 
2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector getSizesOfShapes(const std::vector> shapes){ + std::vector sizes; + for(std::vector shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector &network_input_keys, + const std::vector> &network_input_shapes){ + + CNNLAOptimizer_cartpole_master_dqn optimizer_creator = CNNLAOptimizer_cartpole_master_dqn(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector network_symbols = model_loader.GetNetworkSymbols(); + std::vector> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast(json_data.GetBuffer()), - static_cast(param_data.GetBuffer()), - static_cast(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector> in_shapes; + std::vector> aux_shapes; + std::vector> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_CARTPOLE_MASTER_DQN diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNTrainer_cartpole_master_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNTrainer_cartpole_master_dqn.py index f650e1da41ac8ce7abd500bcf7b108912bae8665..a831588507c27ee8136d6fe6e0559c747483e5c5 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNTrainer_cartpole_master_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/CNNTrainer_cartpole_master_dqn.py @@ -58,6 +58,7 @@ if __name__ == "__main__": 'state_dtype': 'float32', 'action_dtype': 'uint8', 'rewards_dtype': 'float32' + }, 'strategy_params': { 'method':'epsgreedy', diff --git a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/cartpole_master_dqn.h 
b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/cartpole_master_dqn.h index 8ecc9075a034b051632dc6b29447261c31430f20..9269801b6aa22cc1172f144a5055a0c7b345e957 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/cartpole/cartpole_master_dqn.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/cartpole/cartpole_master_dqn.h @@ -19,8 +19,10 @@ qvalues=colvec(2); } void execute(){ vector state_ = CNNTranslator::translate(state); + vector qvalues_(2); + _predictor_0_.predict(state_, qvalues_); qvalues = CNNTranslator::translateToCol(qvalues_, std::vector {2}); diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNBufferFile.h b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include -#include -#include - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNCreator_mountaincar_master_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNCreator_mountaincar_master_actor.py index 3ad337081f314950702b462536784d28e364b9db..74e18c536e3976e454924d793af5c7421a5fdc42 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNCreator_mountaincar_master_actor.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNCreator_mountaincar_master_actor.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_mountaincar_master_actor import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_mountaincar_master_actor: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_mountaincar_master_actor: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in 
range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_mountaincar_master_actor: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = 
file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
                         epoch = int(epochStr)
-                        if epoch > lastEpoch:
+                        if epoch >= lastEpoch:
                             lastEpoch = epoch
                             param_file = file
+                    elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
+                        relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
+                        memSubNet = int(relMemPathInfo[0])
+                        memEpochStr = relMemPathInfo[1]
+                        memEpoch = int(memEpochStr)
+                        if memEpoch >= lastMemEpoch[memSubNet-1]:
+                            lastMemEpoch[memSubNet-1] = memEpoch
+                            mem_files[memSubNet-1] = file
+
                logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
                network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
+                if hasattr(network, 'episodic_sub_nets'):
+                    assert all(mem_epoch == lastEpoch for mem_epoch in lastMemEpoch)
+                    for j, sub_net in enumerate(network.episodic_sub_nets):
+                        if mem_files[j] != None:
+                            logging.info("Loading pretrained Replay Memory: " + mem_files[j])
+                            mem_layer = \
+                                [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
+                                 param[0].startswith("memory")][0][1]
+                            mem_layer.load_memory(self._model_dir_ + mem_files[j])
            else:
                logging.info("No pretrained weights available at: " + self._weights_dir_ + str(param_file))
 
    def construct(self, context, data_mean=None, data_std=None):
-        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
-        self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
+        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[0].hybridize()
-        self.networks[0](mx.nd.zeros((1, 2,), ctx=context))
+        self.networks[0](mx.nd.zeros((1, 2,), ctx=context[0]))
 
        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNModelLoader.h b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNModelLoader.h
new file mode 100644
index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4
--- /dev/null
+++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNModelLoader.h
@@ -0,0 +1,141 @@
+#ifndef CNNMODELLOADER
+#define CNNMODELLOADER
+
+#include <mxnet-cpp/MxNetCpp.h>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+using namespace mxnet::cpp;
+
+// Read files to load model symbol and parameters
+class ModelLoader {
+private:
+    Context ctx = Context::cpu();
+    std::vector<Symbol> network_symbol_list;
+    std::vector<std::map<std::string, NDArray>> network_param_map_list;
+
+    std::vector<Symbol> query_symbol_list;
+    std::vector<std::map<std::string, NDArray>> query_param_map_list;
+
+    std::vector<std::map<std::string, NDArray>> replay_memory;
+
+    std::vector<Symbol> loss_symbol;
+    std::vector<std::map<std::string, NDArray>> loss_param_map;
+
+
+    void checkFile(std::string file_path){
+        std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
+        if (!ifs) {
+            std::cerr << "Can't open the file. Please check " << file_path << ". \n";
+            return;
+        }
+
+        int length_;
+        ifs.seekg(0, std::ios::end);
+        length_ = ifs.tellg();
+        ifs.seekg(0, std::ios::beg);
+        std::cout << file_path.c_str() << " ... 
"<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector &symbols_list, + std::vector> ¶m_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map params; + NDArray::Load(param_path, 0, ¶ms); + param_map_list.push_back(processParamMap(params)); + } + + std::map processParamMap(std::map param_map){ + std::map processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? <- need to make sure) + processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector GetQuerySymbols() { + return query_symbol_list; + } + + std::vector> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNNet_mountaincar_master_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNNet_mountaincar_master_actor.py index 2bec1be40248799ff69cbda6ea132844ab974a32..8f389630a7afc5162c90cb542574979cd8001a44 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNNet_mountaincar_master_actor.py +++ 
b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNNet_mountaincar_master_actor.py
@@ -1,7 +1,10 @@
 import mxnet as mx
 import numpy as np
 import math
-from mxnet import gluon
+import os
+import abc
+import warnings
+from mxnet import gluon, nd


 class ZScoreNormalization(gluon.HybridBlock):
@@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock):
         output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
         return output, F.swapaxes(state0, 0, 1)
 
+
+class DotProductSelfAttention(gluon.HybridBlock):
+    def __init__(self,
+                 scale_factor,
+                 num_heads,
+                 dim_model,
+                 dim_keys,
+                 dim_values,
+                 use_proj_bias,
+                 use_mask,
+                 **kwargs):
+        super(DotProductSelfAttention, self).__init__(**kwargs)
+        with self.name_scope():
+            self.num_heads = num_heads
+            self.dim_model = dim_model
+            self.use_proj_bias = use_proj_bias
+            self.use_mask = use_mask
+
+            if dim_keys == -1:
+                self.dim_keys = int(dim_model / self.num_heads)
+            else:
+                self.dim_keys = dim_keys
+            if dim_values == -1:
+                self.dim_values = int(dim_model / self.num_heads)
+            else:
+                self.dim_values = dim_values
+
+            if scale_factor == -1:
+                self.scale_factor = math.sqrt(self.dim_keys)
+            else:
+                self.scale_factor = scale_factor
+
+            self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+            self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False)
+            self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False)
+            self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False)
+
+    def hybrid_forward(self, F, queries, keys, values, *args, **kwargs):
+
+        queries = F.reshape(queries, shape=(0, 0, -1))
+        keys = F.reshape(keys, shape=(0, 0, -1))
+        values = F.reshape(values, shape=(0, 0, -1))
+
+        head_queries = self.proj_q(queries)
+        head_keys = self.proj_k(keys)
+        head_values = self.proj_v(values)
+
+        head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1))
+        head_queries = F.transpose(head_queries, axes=(0, 2, 1, 3))
+        head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True)
+
+        head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1))
+        head_keys = F.transpose(head_keys, axes=(0, 2, 1, 3))
+        head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True)
+
+        score = F.batch_dot(head_queries, head_keys, transpose_b=True)
+        score = score * self.scale_factor
+        if self.use_mask:
+            # the sequence-length mask is expected as the first additional argument
+            mask = args[0]
+            mask = F.tile(mask, self.num_heads)
+            mask = F.repeat(mask, self.dim_model)
+            mask = F.reshape(mask, shape=(-1, self.dim_model))
+            weights = F.softmax(score, mask, use_length=True)
+        else:
+            weights = F.softmax(score)
+
+        head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
+        head_values = F.transpose(head_values, axes=(0, 2, 1, 3))
+        head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True)
+
+        ret = F.batch_dot(weights, head_values)
+        ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True)
+        ret = F.transpose(ret, axes=(0, 2, 1, 3))
+        ret = F.reshape(ret, shape=(0, 0, -1))
+
+        ret = self.proj_o(ret)
+
+        return ret
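+
+# Editor's note (a sketch of the intended semantics, inferred from the parameter
+# names; not generated code): implementations of the interface below let the
+# training loop store a subset of incoming samples (store_samples, keyed by a
+# query network) and, every `replay_interval` batches, draw `replay_steps`
+# replay mini-batches of size `replay_batch_size` via sample_memory() for
+# `replay_gradient_steps` additional update steps.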
+
+
+class EpisodicReplayMemoryInterface(gluon.HybridBlock):
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
+        super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
+
+        self.use_replay = use_replay
+        self.replay_interval = replay_interval
+        self.replay_batch_size = replay_batch_size
+        self.replay_steps = replay_steps
+        self.replay_gradient_steps = replay_gradient_steps
+        self.num_heads = num_heads
+
+    @abc.abstractmethod
+    def store_samples(self, data, y, query_network, store_prob, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def sample_memory(self, batch_size, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def get_query_network(self, mx_context):
+        pass
+
+    @abc.abstractmethod
+    def save_memory(self, path):
+        pass
+
+    @abc.abstractmethod
+    def load_memory(self, path):
+        pass
+
+#Memory layer
+class LargeMemory(gluon.HybridBlock):
+    def __init__(self,
+                 sub_key_size,
+                 query_size,
+                 query_act,
+                 dist_measure,
+                 k,
+                 num_heads,
+                 values_dim,
+                 **kwargs):
+        super(LargeMemory, self).__init__(**kwargs)
+        with self.name_scope():
+            #Memory parameters
+            self.dist_measure = dist_measure
+            self.k = k
+            self.num_heads = num_heads
+            self.query_act = query_act
+            self.query_size = query_size
+
+            #Batch norm sub-layer
+            self.batch_norm = gluon.nn.BatchNorm()
+
+            #Memory sub-layer
+            self.sub_key_size = sub_key_size
+            sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2))
+
+            if values_dim == -1:
+                values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1])
+            else:
+                values_shape = (self.sub_key_size*self.sub_key_size, values_dim)
+
+            self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True)
+            self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True)
+            self.values = self.params.get("values", shape=values_shape, differentiable=True)
+            self.label_memory = nd.array([])
+
+            self.get_query_network()
+
+    def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values):
+        x = self.batch_norm(x)
+
+        x = F.reshape(x, shape=(0, -1))
+
+        q = self.query_network(x)
+
+        q = F.reshape(q, shape=(0, self.num_heads, -1))
+
+        q_split = F.split(q, num_outputs=2, axis=-1)
+
+        if self.dist_measure == "l2":
+            q_split_resh = F.reshape(q_split[0], shape=(0, 0, 1, -1))
+            sub_keys1_resh = F.reshape(sub_keys1, shape=(1, 0, 0, -1), reverse=True)
+            q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
+            q1_dist = F.norm(q1_diff, axis=-1)
+            q_split_resh = F.reshape(q_split[1], shape=(0, 0, 1, -1))
+            sub_keys2_resh = F.reshape(sub_keys2, shape=(1, 0, 0, -1), reverse=True)
+            q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
+            q2_dist = F.norm(q2_diff, axis=-1)
+        else:
+            q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
+            q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
+            sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+            sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
+            if self.num_heads == 1:
+                q1 = [q1]
+                q2 = [q2]
+                sub_keys1_resh = [sub_keys1_resh]
+                sub_keys2_resh = [sub_keys2_resh]
+
+            q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
+            q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
+            for h in range(1, self.num_heads):
+                # per-head distances: head h of the query against sub-key set h
+                q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1)
+                q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1)
+
+        i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
+        i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
+
+        # Calculate cross product for keys at indices I1 and I2
+
+        # def head_take(data, state):
+        #    return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
+        #
+        # i1 = F.transpose(i1, axes=(1,0,2))
+        # i2 = F.transpose(i2, axes=(1, 0, 2))
+        # st = F.zeros(1)
+        # (k1, k2), _ = 
F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= 
self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + 
symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -123,5 +539,5 @@ class Net_0(gluon.HybridBlock): tanh3_ = self.tanh3_(fc3_) action_ = F.identity(tanh3_) - return action_ + return [[action_]] diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNPredictor_mountaincar_master_actor.h b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNPredictor_mountaincar_master_actor.h index 4a59ca9475e9acc4678817d1baedb380a693eac1..7f843f8242d7e0e2f8b9e47573eaccb879d2bd2d 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNPredictor_mountaincar_master_actor.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNPredictor_mountaincar_master_actor.h @@ -1,107 +1,149 @@ #ifndef CNNPREDICTOR_MOUNTAINCAR_MASTER_ACTOR #define CNNPREDICTOR_MOUNTAINCAR_MASTER_ACTOR -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_mountaincar_master_actor_0{ public: - const std::string json_file = "model/mountaincar.agent.MountaincarActor/model_0_newest-symbol.json"; - const std::string param_file = "model/mountaincar.agent.MountaincarActor/model_0_newest-0000.params"; - const std::vector input_keys = { + const std::string file_prefix = "model/mountaincar.agent.MountaincarActor/model_0_newest"; + + //network + const std::vector network_input_keys = { "data" }; - const std::vector> input_shapes = {{1, 2}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector> network_input_shapes = {{1, 2}}; + std::vector network_input_sizes; + std::vector> network_arg_names; + std::vector network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
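+
+    //Editorial note: ctx stays Context::cpu() unless the generated CNNLAOptimizer
+    //reports "gpu" in init(), and dtype 0 selects float32 for every NDArray the
+    //predictor allocates (a sketch of the intended behaviour, not generated code).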
+ + explicit CNNPredictor_mountaincar_master_actor_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_mountaincar_master_actor_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector &in_state_, std::vector &out_action_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast(in_state_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - assert(size == out_action_.size()); - MXPredGetOutput(handle, output_index, &(out_action_[0]), out_action_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_state_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector output = network_handles.back()->outputs; + std::vector curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in the in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_action_.size()) || (curr_output_size == out_action_[0])); + output[0].SyncCopyToCPU(&out_action_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map ¶m_map, + const std::vector &exec_input_keys, + const std::vector> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string ¶m_file, - const std::vector &input_keys, - const std::vector> &input_shapes, - const bool &use_gpu){ + std::vector param_arrays; + std::vector grad_array; + std::vector grad_reqs; + std::vector aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, ¶m_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map(), + std::map(), aux_map); - int dev_type = use_gpu ? 
2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector getSizesOfShapes(const std::vector> shapes){ + std::vector sizes; + for(std::vector shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector &network_input_keys, + const std::vector> &network_input_shapes){ + + CNNLAOptimizer_mountaincar_master_actor optimizer_creator = CNNLAOptimizer_mountaincar_master_actor(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector network_symbols = model_loader.GetNetworkSymbols(); + std::vector> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast(json_data.GetBuffer()), - static_cast(param_data.GetBuffer()), - static_cast(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector> in_shapes; + std::vector> aux_shapes; + std::vector> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_MOUNTAINCAR_MASTER_ACTOR diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNTrainer_mountaincar_master_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNTrainer_mountaincar_master_actor.py index a2827e4244ebb33a3dfc59a953450ac106c65c48..198280f40934d0f15546db98f30bd4b0d1547a72 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNTrainer_mountaincar_master_actor.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/CNNTrainer_mountaincar_master_actor.py @@ -61,6 +61,7 @@ if __name__ == "__main__": 'state_dtype': 'float32', 'action_dtype': 'float32', 'rewards_dtype': 'float32' + }, 'strategy_params': { 'method':'ornstein_uhlenbeck', diff --git 
a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/mountaincar_master_actor.h b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/mountaincar_master_actor.h index 4cc18f44efd5f36240259fcb9db324557fc91e32..39061da1abe9ef352f840e1f5b8cd218a1efaf33 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/mountaincar_master_actor.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/mountaincar_master_actor.h @@ -19,8 +19,10 @@ action=colvec(1); } void execute(){ vector state_ = CNNTranslator::translate(state); + vector action_(1); + _predictor_0_.predict(state_, action_); action = CNNTranslator::translateToCol(action_, std::vector {1}); diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py index 2f3f2bda7dfdd36b7a335aeae38bb62a91e78dab..0a13ecf55f50a2b74bcbce205025c13e58d5f7f5 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNCreator_mountaincar_agent_mountaincarCritic.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_mountaincar_agent_mountaincarCritic import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_mountaincar_agent_mountaincarCritic: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_mountaincar_agent_mountaincarCritic: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except 
OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_mountaincar_agent_mountaincarCritic: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in 
inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 2,), ctx=context), mx.nd.zeros((1, 1,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 2,), ctx=context[0]), mx.nd.zeros((1, 1,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py index fe452eb86e91a366fcd26e0db1fa27a6165a28fa..9e90a181998dfd7013ab7e77b51285bcb92e6443 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/mountaincar/reinforcement_learning/CNNNet_mountaincar_agent_mountaincarCritic.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = 
F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + 
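# The l2 branch here measures, per head, how close each half of the query is
# to every sub-key (the product-key memory lookup): the query halves are
# reshaped to (batch, heads, 1, d/2) and the sub-keys to
# (1, heads, sub_key_size, d/2), so broadcast_sub followed by norm yields a
# (batch, heads, sub_key_size) distance matrix in one step. A standalone
# sketch of the same broadcasting idea (hypothetical sizes B, H, S, D):
#   q = nd.zeros((B, H, 1, D)); keys = nd.zeros((1, H, S, D))
#   dists = nd.norm(nd.broadcast_sub(q, keys), axis=-1)  # shape (B, H, S)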
sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = 
gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], 
to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -134,5 +550,5 @@ class Net_0(gluon.HybridBlock): fc4_ = self.fc4_(relu4_) qvalues_ = F.identity(fc4_) - return qvalues_ + return [[qvalues_]] diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNBufferFile.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- 
a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include -#include -#include - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNCreator_torcs_agent_torcsAgent_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNCreator_torcs_agent_torcsAgent_dqn.py index 508fc9381b7e9c6ce386a8c027b2a8e0d325f9d8..0632c4040bd1da6180a878f726251c686d1fefd1 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNCreator_torcs_agent_torcsAgent_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNCreator_torcs_agent_torcsAgent_dqn.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_torcs_agent_torcsAgent_dqn import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_torcs_agent_torcsAgent_dqn: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_torcs_agent_torcsAgent_dqn: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass 
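# The try/os.remove/except-OSError blocks in this cleanup sequence all
# implement "delete the checkpoint file if it exists". A minimal equivalent
# helper (an illustrative sketch; remove_if_exists is hypothetical and not
# part of the generated creator class):
#
#   import os, contextlib
#
#   def remove_if_exists(path):
#       # best-effort delete; a missing file is not an error here
#       with contextlib.suppress(OSError):
#           os.remove(path)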
+ try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_torcs_agent_torcsAgent_dqn: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch) + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != 
None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 5,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 5,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNModelLoader.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNModelLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4 --- /dev/null +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNModelLoader.h @@ -0,0 +1,141 @@ +#ifndef CNNMODELLOADER +#define CNNMODELLOADER + +#include <mxnet-cpp/MxNetCpp.h> + +#include <string> +#include <fstream> +#include <vector> + +using namespace mxnet::cpp; + +// Read files to load model symbol and parameters +class ModelLoader { +private: + Context ctx = Context::cpu(); + std::vector<Symbol> network_symbol_list; + std::vector<std::map<std::string, NDArray>> network_param_map_list; + + std::vector<Symbol> query_symbol_list; + std::vector<std::map<std::string, NDArray>> query_param_map_list; + + std::vector<std::map<std::string, NDArray>> replay_memory; + + std::vector<Symbol> loss_symbol; + std::vector<std::map<std::string, NDArray>> loss_param_map; + + + void checkFile(std::string file_path){ + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + return; + } + + int length_; + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector<Symbol> &symbols_list, + std::vector<std::map<std::string, NDArray>> &param_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map<std::string, NDArray> params; + NDArray::Load(param_path, 0, &params); + param_map_list.push_back(processParamMap(params)); + } + + std::map<std::string, NDArray> processParamMap(std::map<std::string, NDArray> param_map){ + std::map<std::string, NDArray> processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? 
<- need to make sure) + processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector<Symbol> GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map<std::string, NDArray> GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector<Symbol> GetQuerySymbols() { + return query_symbol_list; + } + + std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector<std::map<std::string, NDArray>> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNNet_torcs_agent_torcsAgent_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNNet_torcs_agent_torcsAgent_dqn.py index a13ab817cce7fff4653da898a81c8a4b96dbb2c5..a743e0b1bc864f674d74278766daedfe8928c38a 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNNet_torcs_agent_torcsAgent_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNNet_torcs_agent_torcsAgent_dqn.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + 
self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k 
+ self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[h], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[h], sub_keys2_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, 
self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = 
nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in 
enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -121,5 +537,5 @@ class Net_0(gluon.HybridBlock): fc3_ = self.fc3_(tanh2_) qvalues_ = F.identity(fc3_) - return qvalues_ + return [[qvalues_]] diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNPredictor_torcs_agent_torcsAgent_dqn.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNPredictor_torcs_agent_torcsAgent_dqn.h index 8e45765adf37fe02d84395b2c6524bd26308a7dd..adb17d99866ca87a2f66f1f6dfbd9339632d7609 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNPredictor_torcs_agent_torcsAgent_dqn.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNPredictor_torcs_agent_torcsAgent_dqn.h @@ -1,107 +1,149 @@ #ifndef CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN #define CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN -#include <mxnet/c_predict_api.h> +#include <mxnet-cpp/MxNetCpp.h> #include <cassert> #include <string> #include <vector> + +#include <CNNModelLoader.h> +#include <CNNLAOptimizer_torcs_agent_torcsAgent_dqn.h> -#include <CNNBufferFile.h> - +using namespace mxnet::cpp; + class CNNPredictor_torcs_agent_torcsAgent_dqn_0{ public: - const std::string json_file = "model/torcs.agent.dqn.TorcsDQN/model_0_newest-symbol.json"; - const std::string param_file = "model/torcs.agent.dqn.TorcsDQN/model_0_newest-0000.params"; - const std::vector<std::string> input_keys = { + const std::string file_prefix = "model/torcs.agent.dqn.TorcsDQN/model_0_newest"; + + //network + const std::vector<std::string> network_input_keys = { "data" }; - const std::vector<std::vector<mx_uint>> input_shapes = {{1, 5}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector<std::vector<mx_uint>> network_input_shapes = {{1, 5}}; + std::vector<mx_uint> network_input_sizes; + std::vector<std::vector<std::string>> network_arg_names; + std::vector<Executor *> network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
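// The members above replace the old C predict API (PredictorHandle and the
// MXPred* calls) with mxnet-cpp Executors: init() loads the symbol and the
// parameters through ModelLoader and binds one Executor per network, while
// predict() copies the input into the executor's arg_dict and runs Forward.
// A usage sketch (assumes the model files exist under model/... and the
// binary is linked against mxnet-cpp):
//   CNNPredictor_torcs_agent_torcsAgent_dqn_0 predictor;
//   std::vector<float> state(5, 0.0f);     // matches input shape {1, 5}
//   std::vector<float> qvalues(30, 0.0f);  // one entry per discrete action
//   predictor.predict(state, qvalues);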
+ + explicit CNNPredictor_torcs_agent_torcsAgent_dqn_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_torcs_agent_torcsAgent_dqn_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector<float> &in_state_, std::vector<float> &out_qvalues_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast<mx_uint>(in_state_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - assert(size == out_qvalues_.size()); - MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_state_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector<NDArray> output = network_handles.back()->outputs; + std::vector<mx_uint> curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix due to a bug in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_qvalues_.size()) || (curr_output_size == out_qvalues_[0])); + output[0].SyncCopyToCPU(&out_qvalues_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map<std::string, NDArray> &param_map, + const std::vector<std::string> &exec_input_keys, + const std::vector<std::vector<mx_uint>> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string &param_file, - const std::vector<std::string> &input_keys, - const std::vector<std::vector<mx_uint>> &input_shapes, - const bool &use_gpu){ + std::vector<NDArray> param_arrays; + std::vector<NDArray> grad_array; + std::vector<OpReqType> grad_reqs; + std::vector<NDArray> aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map<std::string, NDArray>(), + std::map<std::string, OpReqType>(), aux_map); - int dev_type = use_gpu ? 
2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){ + std::vector<mx_uint> sizes; + for(std::vector<mx_uint> shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << ": " << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector<std::string> &network_input_keys, + const std::vector<std::vector<mx_uint>> &network_input_shapes){ + + CNNLAOptimizer_torcs_agent_torcsAgent_dqn optimizer_creator = CNNLAOptimizer_torcs_agent_torcsAgent_dqn(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols(); + std::vector<std::map<std::string, NDArray>> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map<std::string, std::vector<mx_uint>> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast<const char*>(json_data.GetBuffer()), - static_cast<const char*>(param_data.GetBuffer()), - static_cast<int>(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector<std::vector<mx_uint>> in_shapes; + std::vector<std::vector<mx_uint>> aux_shapes; + std::vector<std::vector<mx_uint>> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNTrainer_torcs_agent_torcsAgent_dqn.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNTrainer_torcs_agent_torcsAgent_dqn.py index 4c697411469769c2cf281640582ee715b4a9b26e..7c8fd23517dd323eb7c5ea4e89211e80173edc77 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNTrainer_torcs_agent_torcsAgent_dqn.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/CNNTrainer_torcs_agent_torcsAgent_dqn.py @@ -65,6 +65,7 @@ if __name__ == "__main__": 'state_dtype': 'float32', 'action_dtype': 'uint8', 'rewards_dtype': 'float32' + }, 'strategy_params': { 'method':'epsgreedy', diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs/torcs_agent_torcsAgent_dqn.h 
b/src/test/resources/target_code/gluon/reinforcementModel/torcs/torcs_agent_torcsAgent_dqn.h index f80f2b168a1b27477b9eb2d4ad6d33b28918b171..93514655d6c0a820ce384a87d47481b40a426c3f 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs/torcs_agent_torcsAgent_dqn.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs/torcs_agent_torcsAgent_dqn.h @@ -20,8 +20,10 @@ qvalues=colvec(discrete_actions); } void execute(){ vector state_ = CNNTranslator::translate(state); + vector qvalues_(30); + _predictor_0_.predict(state_, qvalues_); qvalues = CNNTranslator::translateToCol(qvalues_, std::vector {30}); diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNBufferFile.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNBufferFile.h deleted file mode 100644 index c0d8dd9cbe6878e07be976dda5ce9046e6c05606..0000000000000000000000000000000000000000 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNBufferFile.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CNNBUFFERFILE_H -#define CNNBUFFERFILE_H - -#include -#include -#include - -// Read file to buffer -class BufferFile { - public : - std::string file_path_; - int length_; - char* buffer_; - - explicit BufferFile(std::string file_path) - :file_path_(file_path) { - - std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); - if (!ifs) { - std::cerr << "Can't open the file. Please check " << file_path << ". \n"; - length_ = 0; - buffer_ = NULL; - return; - } - - ifs.seekg(0, std::ios::end); - length_ = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; - - buffer_ = new char[sizeof(char) * length_]; - ifs.read(buffer_, length_); - ifs.close(); - } - - int GetLength() { - return length_; - } - char* GetBuffer() { - return buffer_; - } - - ~BufferFile() { - if (buffer_) { - delete[] buffer_; - buffer_ = NULL; - } - } -}; - -#endif // CNNBUFFERFILE_H diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNCreator_torcs_agent_torcsAgent_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNCreator_torcs_agent_torcsAgent_actor.py index 6a1b7fa3956b8e16d20702e6c716b71770f0c5bc..092dfdb9f515254a471317d98c73bdd85f12edc6 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNCreator_torcs_agent_torcsAgent_actor.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNCreator_torcs_agent_torcsAgent_actor.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_torcs_agent_torcsAgent_actor import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_torcs_agent_torcsAgent_actor: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_torcs_agent_torcsAgent_actor: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + 
except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_torcs_agent_torcsAgent_actor: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr 
= file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert lastEpoch == lastMemEpoch + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 29,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 29,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNModelLoader.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNModelLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e03e9ccd51c9d37e3793d556ed044b4dd6af4 --- /dev/null +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNModelLoader.h @@ -0,0 +1,141 @@ +#ifndef CNNMODELLOADER +#define CNNMODELLOADER + +#include + +#include +#include +#include + +using namespace mxnet::cpp; + +// Read files to load moddel symbol and parameters +class ModelLoader { +private: + Context ctx = Context::cpu(); + std::vector network_symbol_list; + std::vector> network_param_map_list; + + std::vector query_symbol_list; + std::vector> query_param_map_list; + + std::vector> replay_memory; + + std::vector loss_symbol; + std::vector> loss_param_map; + + + void checkFile(std::string file_path){ + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + return; + } + + int length_; + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... 
"<< length_ << " bytes\n"; + ifs.close(); + } + + void loadComponent(std::string json_path, + std::string param_path, + std::vector &symbols_list, + std::vector> ¶m_map_list){ + checkFile(json_path); + symbols_list.push_back(Symbol::Load(json_path)); + checkFile(param_path); + std::map params; + NDArray::Load(param_path, 0, ¶ms); + param_map_list.push_back(processParamMap(params)); + } + + std::map processParamMap(std::map param_map){ + std::map processed_param_map; + if(!param_map.empty()){ + for (const auto &pair : param_map) { + std::string name = pair.first.substr(4); //the first four letters would be the type (arg: or aux:, but we don't have aux parameters? <- need to make sure) + processed_param_map[name] = pair.second.Copy(ctx); + } + } + return processed_param_map; + } + +public: + explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){ + + ctx = ctx_param; + std::string network_json_path; + std::string network_param_path; + std::string query_json_path; + std::string query_param_path; + std::string memory_path; + std::string loss_json_path; + std::string loss_param_path; + + //Load network + if(!num_subnets){ + network_json_path = file_prefix + "-symbol.json"; + network_param_path = file_prefix + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + }else{ + for(int i=0; i < num_subnets; i++){ + network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json"; + network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params"; + loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list); + if(i >= 1){ + query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json"; + query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params"; + loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list); + + memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000"; + checkFile(memory_path); + + std::map mem_map = NDArray::LoadToMap(memory_path); + for(auto &mem : mem_map){ + mem.second = mem.second.Copy(ctx); + } + replay_memory.push_back(mem_map); + } + } + } + + //Load Loss + loss_json_path = file_prefix + "_loss-symbol.json"; + loss_param_path = file_prefix + "_loss-0000.params"; + loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map); + + NDArray::WaitAll(); + } + + std::vector GetNetworkSymbols() { + return network_symbol_list; + } + + std::vector> GetNetworkParamMaps() { + return network_param_map_list; + } + + Symbol GetLoss() { + return loss_symbol[0]; + } + + std::map GetLossParamMap() { + return loss_param_map[0]; + } + + std::vector GetQuerySymbols() { + return query_symbol_list; + } + + std::vector> GetQueryParamMaps() { + return query_param_map_list; + } + + std::vector> GetReplayMemory(){ + return replay_memory; + } +}; +#endif // CNNMODELLOADER diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNNet_torcs_agent_torcsAgent_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNNet_torcs_agent_torcsAgent_actor.py index 26e6138dbd870292d52a68c41c86424be9171be8..3194456bdb63050199ebf1b20e04631dbf04adb1 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNNet_torcs_agent_torcsAgent_actor.py +++ 
b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNNet_torcs_agent_torcsAgent_actor.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + 
self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + + q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = 
F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= 
self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if self.replay_batch_size == -1: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu()) + else: + sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu()) + + num_outputs = len(self.label_memory) + + sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind] + sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)] + + return sample_batches + + def get_query_network(self, context): + lastEpoch = 0 + for file in os.listdir(self.query_net_dir): + if self.query_net_prefix in file and ".json" in file: + 
symbolFile = file + + if self.query_net_prefix in file and ".param" in file: + epochStr = file.replace(".params", "").replace(self.query_net_prefix, "") + epoch = int(epochStr) + if epoch >= lastEpoch: + lastEpoch = epoch + weightFile = file + + inputNames = [] + if self.query_net_num_inputs == 1: + inputNames.append("data") + else: + for i in range(self.query_net_num_inputs): + inputNames.append("data" + str(i)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0]) + net.hybridize() + return net + + def save_memory(self, path): + mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)] + mem_dict = {entry[0]:entry[1] for entry in mem_arr} + nd.save(path, mem_dict) + + def load_memory(self, path): + mem_dict = nd.load(path) + self.value_memory = [] + self.label_memory = [] + for key in sorted(mem_dict.keys()): + if key == "keys": + self.key_memory = mem_dict[key] + elif key.startswith("values_"): + self.value_memory.append(mem_dict[key]) + elif key.startswith("labels_"): + self.label_memory.append(mem_dict[key]) + + +#Stream 0 class Net_0(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): + def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs): super(Net_0, self).__init__(**kwargs) with self.name_scope(): if data_mean: @@ -123,5 +539,5 @@ class Net_0(gluon.HybridBlock): tanh3_ = self.tanh3_(fc3_) commands_ = F.identity(tanh3_) - return commands_ + return [[commands_]] diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNPredictor_torcs_agent_torcsAgent_actor.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNPredictor_torcs_agent_torcsAgent_actor.h index b196a59964430ca3df802e898a0ce93703367a91..dc3ee21a8174fd2e8029b3c94b34fa9e266b3db0 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNPredictor_torcs_agent_torcsAgent_actor.h +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNPredictor_torcs_agent_torcsAgent_actor.h @@ -1,108 +1,149 @@ #ifndef CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_ACTOR #define CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_ACTOR -#include +#include #include #include #include + +#include +#include -#include - +using namespace mxnet::cpp; + class CNNPredictor_torcs_agent_torcsAgent_actor_0{ public: - const std::string json_file = "model/torcs.agent.network.TorcsActor/model_0_newest-symbol.json"; - const std::string param_file = "model/torcs.agent.network.TorcsActor/model_0_newest-0000.params"; - const std::vector input_keys = { + const std::string file_prefix = "model/torcs.agent.network.TorcsActor/model_0_newest"; + + //network + const std::vector network_input_keys = { "data" }; - const std::vector> input_shapes = {{1, 29}}; - const bool use_gpu = false; - - PredictorHandle handle; - + const std::vector> network_input_shapes = {{1, 29}}; + std::vector network_input_sizes; + std::vector> network_arg_names; + std::vector network_handles; + + + //misc + Context ctx = Context::cpu(); //Will be updated later in init according to use_gpu + int dtype = 0; //use data type (float32=0 float64=1 ...) 
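+ // Editorial sketch (added commentary, not generated code): the members above replace the
+ // old MXPred C-API state (PredictorHandle, BufferFile). init() below resolves the context
+ // from the generated CNNLAOptimizer, loads "<file_prefix>-symbol.json" and
+ // "<file_prefix>-0000.params" through ModelLoader, and binds one Executor per network.
+ // A minimal caller sketch, with sizes taken from network_input_shapes above and the
+ // actor's declared output width of 3:
+ //
+ //   CNNPredictor_torcs_agent_torcsAgent_actor_0 predictor;  // constructor runs init()
+ //   std::vector<float> state(29, 0.0f);
+ //   std::vector<float> commands(3, 0.0f);
+ //   predictor.predict(state, commands);  // forward pass; commands now holds the output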
+ + explicit CNNPredictor_torcs_agent_torcsAgent_actor_0(){ - init(json_file, param_file, input_keys, input_shapes, use_gpu); + init(file_prefix, network_input_keys, network_input_shapes); } ~CNNPredictor_torcs_agent_torcsAgent_actor_0(){ - if(handle) MXPredFree(handle); + for(Executor * handle : network_handles){ + delete handle; + } + MXNotifyShutdown(); } void predict(const std::vector<float> &in_state_, std::vector<float> &out_commands_){ - MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast<mx_uint>(in_state_.size())); - - MXPredForward(handle); - mx_uint output_index; - mx_uint *shape = 0; - mx_uint shape_len; - size_t size; - - output_index = 0; - MXPredGetOutputShape(handle, output_index, &shape, &shape_len); - size = 1; - for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; - - assert(size == out_commands_.size()); - MXPredGetOutput(handle, output_index, &(out_commands_[0]), out_commands_.size()); + NDArray input_temp; + input_temp = NDArray(network_input_shapes[0], ctx, false, dtype); + input_temp.SyncCopyFromCPU(in_state_.data(), network_input_sizes[0]); + input_temp.CopyTo(&(network_handles[0]->arg_dict()[network_input_keys[0]])); + NDArray::WaitAll(); + + network_handles[0]->Forward(false); + CheckMXNetError("Forward, predict, handle ind. 0"); + + + std::vector<NDArray> output = network_handles.back()->outputs; + std::vector<mx_uint> curr_output_shape; + size_t curr_output_size; + curr_output_shape = output[0].GetShape(); + curr_output_size = 1; + for (mx_uint i : curr_output_shape) curr_output_size *= i; + //Fix for a bug in how the output arrays are initialized when there are multiple outputs + assert((curr_output_size == out_commands_.size()) || (curr_output_size == out_commands_[0])); + output[0].SyncCopyToCPU(&out_commands_); + } + + + + Executor* initExecutor(Symbol &sym, + std::map<std::string, NDArray> &param_map, + const std::vector<std::string> &exec_input_keys, + const std::vector<std::vector<mx_uint>> &exec_input_shapes){ + + const mx_uint num_exec_input_nodes = exec_input_keys.size(); + for(mx_uint i = 0; i < num_exec_input_nodes; i++){ + param_map[exec_input_keys[i]] = NDArray(exec_input_shapes[i], ctx, false, dtype); + } - void init(const std::string &json_file, - const std::string &param_file, - const std::vector<std::string> &input_keys, - const std::vector<std::vector<mx_uint>> &input_shapes, - const bool &use_gpu){ + std::vector<NDArray> param_arrays; + std::vector<NDArray> grad_array; + std::vector<OpReqType> grad_reqs; + std::vector<NDArray> aux_arrays; + std::map< std::string, NDArray> aux_map; - BufferFile json_data(json_file); - BufferFile param_data(param_file); + sym.InferExecutorArrays(ctx, &param_arrays, &grad_array, &grad_reqs, + &aux_arrays, param_map, std::map<std::string, NDArray>(), + std::map<std::string, OpReqType>(), aux_map); - int dev_type = use_gpu ? 2 : 1; - int dev_id = 0; + Executor *handle = new Executor(sym, ctx, param_arrays, grad_array, grad_reqs, aux_arrays); + assert(handle); + return handle; + } - if (json_data.GetLength() == 0 || - param_data.GetLength() == 0) { - std::exit(-1); + std::vector<mx_uint> getSizesOfShapes(const std::vector<std::vector<mx_uint>> shapes){ + std::vector<mx_uint> sizes; + for(std::vector<mx_uint> shape : shapes){ + mx_uint val = 1; + for(mx_uint i: shape){ + val *= i; + } + sizes.push_back(val); } + return sizes; + } - const mx_uint num_input_nodes = input_keys.size(); - - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); + void CheckMXNetError(std::string loc){ + const char* err = MXGetLastError(); + if (err && err[0] != 0) { + std::cout << "MXNet error at " << loc << ": " << err << std::endl; + exit(-1); } - - mx_uint shape_data_size = 0; - mx_uint input_shape_indptr[input_shapes.size() + 1]; - input_shape_indptr[0] = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - shape_data_size += input_shapes[i].size(); - input_shape_indptr[i+1] = shape_data_size; + } + + void init(const std::string &file_prefix, + const std::vector<std::string> &network_input_keys, + const std::vector<std::vector<mx_uint>> &network_input_shapes){ + + CNNLAOptimizer_torcs_agent_torcsAgent_actor optimizer_creator = CNNLAOptimizer_torcs_agent_torcsAgent_actor(); + + if(optimizer_creator.getContextName() == "gpu"){ + ctx = Context::gpu(); } - - mx_uint input_shape_data[shape_data_size]; - mx_uint index = 0; - for(mx_uint i = 0; i < input_shapes.size(); i++){ - for(mx_uint j = 0; j < input_shapes[i].size(); j++){ - input_shape_data[index] = input_shapes[i][j]; - index++; - } + + network_input_sizes = getSizesOfShapes(network_input_shapes); + + ModelLoader model_loader(file_prefix, 0, ctx); + + std::vector<Symbol> network_symbols = model_loader.GetNetworkSymbols(); + std::vector<std::map<std::string, NDArray>> network_param_maps; + network_param_maps = model_loader.GetNetworkParamMaps(); + + //Init handles + std::map<std::string, std::vector<mx_uint>> in_shape_map; + for(mx_uint i=0; i < network_input_keys.size(); i++){ + in_shape_map[network_input_keys[i]] = network_input_shapes[i]; } - - MXPredCreate(static_cast<const char*>(json_data.GetBuffer()), - static_cast<const char*>(param_data.GetBuffer()), - static_cast<int>(param_data.GetLength()), - dev_type, - dev_id, - num_input_nodes, - input_keys_ptr, - input_shape_indptr, - input_shape_data, - &handle); - assert(handle); + std::vector<std::vector<mx_uint>> in_shapes; + std::vector<std::vector<mx_uint>> aux_shapes; + std::vector<std::vector<mx_uint>> out_shapes; + network_symbols[0].InferShape(in_shape_map, &in_shapes, &aux_shapes, &out_shapes); + network_handles.push_back(initExecutor(network_symbols[0], network_param_maps[0], network_input_keys, network_input_shapes)); + } }; - #endif // CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_ACTOR diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNTrainer_torcs_agent_torcsAgent_actor.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNTrainer_torcs_agent_torcsAgent_actor.py index c11dfc7ee4ae9d67f6fd07b89f5cd690ee490916..62ea6ad05e0050a51aea546d5d7be37470fad439 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNTrainer_torcs_agent_torcsAgent_actor.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/CNNTrainer_torcs_agent_torcsAgent_actor.py @@ -68,6 +68,7 @@ if __name__ == "__main__": 'state_dtype': 'float32', 'action_dtype': 'float32', 'rewards_dtype': 'float32' + }, 'strategy_params': { 'method':'ornstein_uhlenbeck', diff --git 
a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNCreator_torcs_agent_network_torcsCritic.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNCreator_torcs_agent_network_torcsCritic.py index 62e7853e7c18fcdef792e17609e77eef4d86d4bf..14449c6c0ca15150602cb649c8952849f8ea105f 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNCreator_torcs_agent_network_torcsCritic.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNCreator_torcs_agent_network_torcsCritic.py @@ -2,6 +2,8 @@ import mxnet as mx import logging import os import shutil +import warnings +import inspect from CNNNet_torcs_agent_network_torcsCritic import Net_0 @@ -20,6 +22,10 @@ class CNNCreator_torcs_agent_network_torcsCritic: for i, network in self.networks.items(): lastEpoch = 0 param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0]*num_episodic_sub_nets + mem_files = [None]*num_episodic_sub_nets try: os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") @@ -30,22 +36,77 @@ class CNNCreator_torcs_agent_network_torcsCritic: except OSError: pass + if hasattr(network, 'episodic_sub_nets'): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json") + except OSError: + pass + + for j in range(len(network.episodic_sub_nets)): + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json") + except OSError: + pass + try: + os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000") + except OSError: + pass + if os.path.isdir(self._model_dir_): for file in os.listdir(self._model_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: - epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: + epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", 
"").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + if param_file is None: earliestLastEpoch = 0 else: logging.info("Loading checkpoint: " + param_file) network.load_parameters(self._model_dir_ + param_file) + if hasattr(network, 'episodic_sub_nets'): + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading Replay Memory: " + mem_files[j]) + mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) - if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: - earliestLastEpoch = lastEpoch + if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch: + earliestLastEpoch = lastEpoch + 1 return earliestLastEpoch @@ -56,27 +117,52 @@ class CNNCreator_torcs_agent_network_torcsCritic: for i, network in self.networks.items(): # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" param_file = None + if hasattr(network, 'episodic_sub_nets'): + num_episodic_sub_nets = len(network.episodic_sub_nets) + lastMemEpoch = [0] * num_episodic_sub_nets + mem_files = [None] * num_episodic_sub_nets + if os.path.isdir(self._weights_dir_): lastEpoch = 0 for file in os.listdir(self._weights_dir_): - if ".params" in file and self._model_prefix_ + "_" + str(i) in file: + if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file: epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epoch = int(epochStr) - if epoch > lastEpoch: + if epoch >= lastEpoch: lastEpoch = epoch param_file = file + elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file: + relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_").split("-") + memSubNet = int(relMemPathInfo[0]) + memEpochStr = relMemPathInfo[1] + memEpoch = int(memEpochStr) + if memEpoch >= lastMemEpoch[memSubNet-1]: + lastMemEpoch[memSubNet-1] = memEpoch + mem_files[memSubNet-1] = file + logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) + if hasattr(network, 'episodic_sub_nets'): + assert lastEpoch == lastMemEpoch + for j, sub_net in enumerate(network.episodic_sub_nets): + if mem_files[j] != None: + logging.info("Loading pretrained Replay Memory: " + mem_files[j]) + mem_layer = \ + [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if + param[0].startswith("memory")][0][1] + mem_layer.load_memory(self._model_dir_ + mem_files[j]) else: logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) def construct(self, context, data_mean=None, data_std=None): - self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std) - self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.networks[0].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context) self.networks[0].hybridize() - self.networks[0](mx.nd.zeros((1, 
29,), ctx=context), mx.nd.zeros((1, 3,), ctx=context)) + self.networks[0](mx.nd.zeros((1, 29,), ctx=context[0]), mx.nd.zeros((1, 3,), ctx=context[0])) if not os.path.exists(self._model_dir_): os.makedirs(self._model_dir_) diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNNet_torcs_agent_network_torcsCritic.py b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNNet_torcs_agent_network_torcsCritic.py index 506ddf9c24bff1750b0eac873049e934feb98899..0f0274058916fda106bc9a1e8b0fe53ec87a418a 100644 --- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNNet_torcs_agent_network_torcsCritic.py +++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reinforcement_learning/CNNNet_torcs_agent_network_torcsCritic.py @@ -1,7 +1,10 @@ import mxnet as mx import numpy as np import math -from mxnet import gluon +import os +import abc +import warnings +from mxnet import gluon, nd class ZScoreNormalization(gluon.HybridBlock): @@ -86,9 +89,422 @@ class CustomGRU(gluon.HybridBlock): output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)]) return output, F.swapaxes(state0, 0, 1) + +class DotProductSelfAttention(gluon.HybridBlock): + def __init__(self, + scale_factor, + num_heads, + dim_model, + dim_keys, + dim_values, + use_proj_bias, + use_mask, + **kwargs): + super(DotProductSelfAttention, self).__init__(**kwargs) + with self.name_scope(): + self.num_heads = num_heads + self.dim_model = dim_model + self.use_proj_bias = use_proj_bias + self.use_mask = use_mask + + if dim_keys == -1: + self.dim_keys = int(dim_model / self.num_heads) + else: + self.dim_keys = dim_keys + if dim_values == -1: + self.dim_values = int(dim_model / self.num_heads) + else: + self.dim_values = dim_values + + if scale_factor == -1: + self.scale_factor = math.sqrt(self.dim_keys) + else: + self.scale_factor = scale_factor + + self.proj_q = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_k = gluon.nn.Dense(self.num_heads*self.dim_keys, use_bias=self.use_proj_bias, flatten=False) + self.proj_v = gluon.nn.Dense(self.num_heads*self.dim_values, use_bias=self.use_proj_bias, flatten=False) + self.proj_o = gluon.nn.Dense(self.dim_model, use_bias=self.use_proj_bias, flatten=False) + + def hybrid_forward(self, F, queries, keys, values, *args, **kwargs): + + queries = F.Reshape(queries, shape=(0, 0,-1)) + keys = F.Reshape(queries, shape=(0, 0, -1)) + values = F.Reshape(queries, shape=(0, 0, -1)) + + head_queries = self.proj_q(queries) + head_keys = self.proj_k(keys) + head_values = self.proj_v(values) + + head_queries = F.reshape(head_queries, shape=(0, 0, self.num_heads, -1)) + head_queries = F.transpose(head_queries, axes=(0,2,1,3)) + head_queries = F.reshape(head_queries, shape=(-1, 0, 0), reverse=True) + + head_keys = F.reshape(head_keys, shape=(0, 0, self.num_heads, -1)) + head_keys = F.transpose(head_keys, axes=(0,2,1,3)) + head_keys = F.reshape(head_keys, shape=(-1, 0, 0), reverse=True) + + score = F.batch_dot(head_queries, head_keys, transpose_b=True) + score = score * self.scale_factor + if self.use_mask: + mask = F.tile(mask, self.num_heads) + mask = F.repeat(mask, self.dim_model) + mask = F.reshape(mask, shape=(-1, self.dim_model)) + weights = F.softmax(score, mask, use_length=self.use_mask) + + head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1)) + head_values = F.transpose(head_values, axes=(0,2,1,3)) + 
head_values = F.reshape(head_values, shape=(-1, 0, 0), reverse=True) + + ret = F.batch_dot(weights, head_values) + ret = F.reshape(ret, shape=(-1, self.num_heads, 0, 0), reverse=True) + ret = F.transpose(ret, axes=(0, 2, 1, 3)) + ret = F.reshape(ret, shape=(0, 0, -1)) + + ret = self.proj_o(ret) + + return ret + + +class EpisodicReplayMemoryInterface(gluon.HybridBlock): + __metaclass__ = abc.ABCMeta + + def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs): + super(EpisodicReplayMemoryInterface, self).__init__(**kwargs) + + self.use_replay = use_replay + self.replay_interval = replay_interval + self.replay_batch_size = replay_batch_size + self.replay_steps = replay_steps + self.replay_gradient_steps = replay_gradient_steps + self.num_heads = num_heads + + @abc.abstractmethod + def store_samples(self, data, y, query_network, store_prob, mx_context): + pass + + @abc.abstractmethod + def sample_memory(self, batch_size, mx_context): + pass + + @abc.abstractmethod + def get_query_network(self, mx_context): + pass + + @abc.abstractmethod + def save_memory(self, path): + pass + + @abc.abstractmethod + def load_memory(self, path): + pass + +#Memory layer +class LargeMemory(gluon.HybridBlock): + def __init__(self, + sub_key_size, + query_size, + query_act, + dist_measure, + k, + num_heads, + values_dim, + **kwargs): + super(LargeMemory, self).__init__(**kwargs) + with self.name_scope(): + #Memory parameters + self.dist_measure = dist_measure + self.k = k + self.num_heads = num_heads + self.query_act = query_act + self.query_size = query_size + self.num_heads = num_heads + + #Batch norm sub-layer + self.batch_norm = gluon.nn.BatchNorm() + + #Memory sub-layer + self.sub_key_size = sub_key_size + sub_key_shape = (self.num_heads, self.sub_key_size, int(query_size[-1] / 2)) + + if values_dim == -1: + values_shape = (self.sub_key_size * self.sub_key_size, self.query_size[-1]) + else: + values_shape = (self.sub_key_size*self.sub_key_size, values_dim) + + self.sub_keys1 = self.params.get("sub_keys1", shape=sub_key_shape, differentiable=True) + self.sub_keys2 = self.params.get("sub_keys2", shape=sub_key_shape, differentiable=True) + self.values = self.params.get("values", shape=values_shape, differentiable=True) + self.label_memory = nd.array([]) + + self.get_query_network() + + def hybrid_forward(self, F, x, sub_keys1, sub_keys2, values): + x = self.batch_norm(x) + + x = F.reshape(x, shape=(0, -1)) + + q = self.query_network(x) + + q = F.reshape(q, shape=(0, self.num_heads, -1)) + + q_split = F.split(q, num_outputs=2, axis=-1) + + if self.dist_measure == "l2": + q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1)) + sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True) + q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh) + q1_dist = F.norm(q1_diff, axis=-1) + q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1)) + sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True) + q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh) + q2_dist = F.norm(q2_diff, axis=-1) + else: + q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1) + q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1) + sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + q1 = [q1] + q2 = [q2] + sub_keys1_resh = [sub_keys1_resh ] + sub_keys2_resh = [sub_keys2_resh ] + 
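+ # Editorial note (added commentary): the block below is the product-key lookup step.
+ # Each half of the query is scored against its own per-head sub-key table (dot product
+ # here, L2 distance in the branch above); the top-k indices from the two halves are
+ # combined further down into k*k candidate keys, so sub_key_size^2 memory slots become
+ # addressable while only 2*sub_key_size sub-keys are ever scored directly.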
+ q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True) + q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True) + for h in range(1, self.num_heads): + q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1) + q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1) + + i1 = F.topk(q1_dist, k=self.k, ret_typ="indices") + i2 = F.topk(q2_dist, k=self.k, ret_typ="indices") + + # Calculate cross product for keys at indices I1 and I2 + + # def head_take(data, state): + # return [F.take(data[0], data[2]), F.take(data[1], data[3])], state, + # + # i1 = F.transpose(i1, axes=(1,0,2)) + # i2 = F.transpose(i2, axes=(1, 0, 2)) + # st = F.zeros(1) + # (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st) + # k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True) + # k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True) + i1 = F.split(i1, num_outputs=self.num_heads, axis=1) + i2 = F.split(i2, num_outputs=self.num_heads, axis=1) + sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + sub_keys2 = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True) + if self.num_heads == 1: + i1 = [i1] + i2 = [i2] + sub_keys1 = [sub_keys1] + sub_keys2 = [sub_keys2] + + k1 = F.take(sub_keys1[0], i1[0]) + k2 = F.take(sub_keys2[0], i2[0]) + for h in range(1, self.num_heads): + k1 = F.concat(k1, F.take(sub_keys1[h], i1[h]), dim=1) + k2 = F.concat(k2, F.take(sub_keys2[h], i2[h]), dim=1) + + k1 = F.tile(k1, (1, 1, self.k, 1)) + k2 = F.repeat(k2, self.k, 2) + c_cart = F.concat(k1, k2, dim=3) + + q = F.reshape(q, shape=(-1,0), reverse=True) + q = F.reshape(q, shape=(0, 1, -1)) + c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True) + if self.dist_measure == "l2": + k_diff = F.broadcast_sub(q, c_cart) + k_dist = F.norm(k_diff, axis=-1) + else: + k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist) + k_dist = F.reshape(k_dist, shape=(0, -1)) + + i = F.topk(k_dist, k=self.k, ret_typ="both") + + w = F.softmax(i[0]) + w = F.reshape(w, shape=(0,1,-1)) + vi = F.take(values, i[1]) + aggr_value = F.batch_dot(w, vi) #F.contrib.foreach(loop_batch_dot, [w, vi], init_states=state_batch_dist) + + ret = F.reshape(aggr_value, shape=(-1, self.num_heads, 0), reverse=True) + one_vec = F.ones((1, 1, self.num_heads)) + one_vec = F.broadcast_like(one_vec, ret, lhs_axes=0, rhs_axes=0) + ret = F.batch_dot(one_vec, ret) + ret = F.reshape(ret, shape=(-1, 0), reverse=True) + + return ret + + def get_query_network(self): + if hasattr(self, 'query_network'): + return self.query_network + else: + self.query_network = gluon.nn.HybridSequential() + for size in self.query_size: + if self.query_act == "linear": + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, flatten=False)) + else: + self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False)) + return self.query_network + + +#EpisodicMemory layer +class EpisodicMemory(EpisodicReplayMemoryInterface): + def __init__(self, + replay_interval, + replay_batch_size, + replay_steps, + replay_gradient_steps, + store_prob, + max_stored_samples, + memory_replacement_strategy, + use_replay, + query_net_dir, + query_net_prefix, + query_net_num_inputs, + **kwargs): + super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs) + with self.name_scope(): + #Replay 
parameters + self.store_prob = store_prob + self.max_stored_samples = max_stored_samples + self.memory_replacement_strategy = memory_replacement_strategy + + self.query_net_dir = query_net_dir + self.query_net_prefix = query_net_prefix + self.query_net_num_inputs = query_net_num_inputs + + #Memory + self.key_memory = nd.array([]) + self.value_memory = nd.array([]) + self.label_memory = nd.array([]) + + def hybrid_forward(self, F, *args): + #propagate the input as the rest is only used for replay + return [args, []] + + def store_samples(self, data, y, query_network, store_prob, context): + if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples): + num_pus = len(data) + sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)] + num_inputs = len(data[0][0]) + num_outputs = len(y) + mx_context = context[0] + + if len(self.key_memory) == 0: + self.key_memory = nd.empty(0, ctx=mx.cpu()) + self.value_memory = [] + self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu()) + + ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)] + + max_inds = [nd.max(ind[i]) for i in range(num_pus)] + if any(max_inds): + to_store_values = [] + for i in range(num_inputs): + tmp_values = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_values, list): + tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]) + else: + tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0) + to_store_values.append(tmp_values) + + to_store_labels = [] + for i in range(num_outputs): + tmp_labels = [] + for j in range(0, num_pus): + if max_inds[j]: + if isinstance(tmp_labels, list): + tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]) + else: + tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0) + to_store_labels.append(tmp_labels) + + to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs]) + + if self.key_memory.shape[0] == 0: + self.key_memory = to_store_keys.as_in_context(mx.cpu()) + for i in range(num_inputs): + self.value_memory.append(to_store_values[i].as_in_context(mx.cpu())) + for i in range(num_outputs): + self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu())) + elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples: + num_to_store = to_store_keys.shape[0] + self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + else: + self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0) + for i in range(num_inputs): + self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0) + for i in range(num_outputs): + self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0) + + def sample_memory(self, batch_size): + num_stored_samples = self.key_memory.shape[0] + if 
+    def sample_memory(self, batch_size):
+        num_stored_samples = self.key_memory.shape[0]
+        if self.replay_batch_size == -1:
+            sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, batch_size), ctx=mx.cpu())
+        else:
+            sample_ind = nd.random.randint(0, num_stored_samples, (self.replay_steps, self.replay_batch_size), ctx=mx.cpu())
+
+        num_outputs = len(self.label_memory)
+
+        sample_labels = [[self.label_memory[i][ind] for i in range(num_outputs)] for ind in sample_ind]
+        sample_batches = [[[self.value_memory[j][ind] for j in range(len(self.value_memory))], sample_labels[i]] for i, ind in enumerate(sample_ind)]
+
+        return sample_batches
+
+    def get_query_network(self, context):
+        # pick the symbol file and the params file with the highest epoch number
+        lastEpoch = 0
+        for file in os.listdir(self.query_net_dir):
+            if self.query_net_prefix in file and ".json" in file:
+                symbolFile = file
+
+            if self.query_net_prefix in file and ".params" in file:
+                epochStr = file.replace(".params", "").replace(self.query_net_prefix, "")
+                epoch = int(epochStr)
+                if epoch >= lastEpoch:
+                    lastEpoch = epoch
+                    weightFile = file
+
+        inputNames = []
+        if self.query_net_num_inputs == 1:
+            inputNames.append("data")
+        else:
+            for i in range(self.query_net_num_inputs):
+                inputNames.append("data" + str(i))
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            net = mx.gluon.nn.SymbolBlock.imports(self.query_net_dir + symbolFile, inputNames, self.query_net_dir + weightFile, ctx=context[0])
+        net.hybridize()
+        return net
+
+    def save_memory(self, path):
+        mem_arr = [("keys", self.key_memory)] + [("values_" + str(k), v) for (k, v) in enumerate(self.value_memory)] + [("labels_" + str(k), v) for (k, v) in enumerate(self.label_memory)]
+        mem_dict = {entry[0]: entry[1] for entry in mem_arr}
+        nd.save(path, mem_dict)
+
+    def load_memory(self, path):
+        mem_dict = nd.load(path)
+        self.value_memory = []
+        self.label_memory = []
+        for key in sorted(mem_dict.keys()):
+            if key == "keys":
+                self.key_memory = mem_dict[key]
+            elif key.startswith("values_"):
+                self.value_memory.append(mem_dict[key])
+            elif key.startswith("labels_"):
+                self.label_memory.append(mem_dict[key])
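+
+    # Example round trip (hypothetical path): the memory is persisted via
+    # nd.save as a flat dict with keys "keys", "values_<i>" and "labels_<i>",
+    # which load_memory reassembles in sorted key order.
+    #   mem.save_memory("model/episodic-memory-0001.params")
+    #   mem.load_memory("model/episodic-memory-0001.params")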
+
+
+#Stream 0
 class Net_0(gluon.HybridBlock):
-    def __init__(self, data_mean=None, data_std=None, **kwargs):
+    def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
         super(Net_0, self).__init__(**kwargs)
         with self.name_scope():
             if data_mean:
@@ -130,5 +546,5 @@ class Net_0(gluon.HybridBlock):
         fc5_ = self.fc5_(relu4_)
         qvalues_ = F.identity(fc5_)
 
-        return qvalues_
+        return [[qvalues_]]
 
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.cpp b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.cpp
index f162dd4eaba50d87f2ef491388fc6fc6f34558a5..799093acd890132e2ceb66ecb98280550a8cee62 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.cpp
+++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.cpp
@@ -14,4 +14,4 @@ torcs_agent_network_reward_output torcs_agent_network_reward_executor::execute(t
 
     output.reward = instance.reward;
     return output;
-}
\ No newline at end of file
+}
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.h
index 5e42cf3f67f632498ecf0e10debf9f93d6c9fb08..84b105fb0988fde7ad702f32b3d9f7e8a343bbf8 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.h
+++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.h
@@ -19,4 +19,4 @@ public:
     void init();
     torcs_agent_network_reward_output execute(torcs_agent_network_reward_input input);
 };
-#endif
\ No newline at end of file
+#endif
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.i b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.i
index a593d0235eb364e6997c83d6f69daa5b57646645..a04ca846b801ff11dd7933cce9c02c388841f78f 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.i
+++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/reward/pylib/torcs_agent_network_reward_executor.i
@@ -6,4 +6,4 @@
 %}
 
 %include "armanpy/armanpy.i"
-%include "torcs_agent_network_reward_executor.h"
\ No newline at end of file
+%include "torcs_agent_network_reward_executor.h"
diff --git a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/torcs_agent_torcsAgent_actor.h b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/torcs_agent_torcsAgent_actor.h
index 14cd12da5e72c7b0eb2997a77c9708a98b960431..6cc687569b01d1c8947b2f1ab56144e3b81e30a6 100644
--- a/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/torcs_agent_torcsAgent_actor.h
+++ b/src/test/resources/target_code/gluon/reinforcementModel/torcs_td3/torcs_agent_torcsAgent_actor.h
@@ -19,8 +19,10 @@
 commands=colvec(3);
 }
 void execute(){
     vector state_ = CNNTranslator::translate(state);
+    vector commands_(3);
+    _predictor_0_.predict(state_, commands_);
     commands = CNNTranslator::translateToCol(commands_, std::vector {3});
diff --git a/src/test/resources/training_data/episodicMemorySimple/test.h5 b/src/test/resources/training_data/episodicMemorySimple/test.h5
new file mode 100644
index 0000000000000000000000000000000000000000..e7a7d697033b4c691f92dbf780f87c1261bd4dd6
Binary files /dev/null and b/src/test/resources/training_data/episodicMemorySimple/test.h5 differ
diff --git a/src/test/resources/training_data/episodicMemorySimple/train.h5 b/src/test/resources/training_data/episodicMemorySimple/train.h5
new file mode 100644
index 0000000000000000000000000000000000000000..e7a7d697033b4c691f92dbf780f87c1261bd4dd6
Binary files /dev/null and b/src/test/resources/training_data/episodicMemorySimple/train.h5 differ