Commit 68ee368d authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'develop' into 'master'

Develop

See merge request !30
parents 6c1c21bd d157159f
Pipeline #325057 passed with stage
in 4 minutes and 4 seconds
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId> <artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.11-SNAPSHOT</version> <version>0.2.12-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version> <CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version> <CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.6-SNAPSHOT</CNNArch2X.version> <CNNArch2X.version>0.0.7-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator> <embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version> <EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
...@@ -71,15 +71,18 @@ public class CNNArch2Gluon extends CNNArchGenerator { ...@@ -71,15 +71,18 @@ public class CNNArch2Gluon extends CNNArchGenerator {
temp = controller.process("execute", Target.CPP); temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue()); fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = controller.process("CNNBufferFile", Target.CPP); temp = controller.process("CNNModelLoader", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue()); fileContentMap.put("CNNModelLoader.h", temp.getValue());
return fileContentMap; return fileContentMap;
} }
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) { private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration(); TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
architecture.processForEpisodicReplayMemory();
Map<String, String> fileContentMap = new HashMap<>(); Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController( CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration); architecture, templateConfiguration);
......
...@@ -41,6 +41,10 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker { ...@@ -41,6 +41,10 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME); supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME); supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME); // supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
supportedLayerList.add(AllPredefinedLayers.LARGE_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
} }
} }
...@@ -87,6 +87,17 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -87,6 +87,17 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
setCurrentElement(previousElement); setCurrentElement(previousElement);
} }
/**
 * Includes all elements belonging to one episodic sub-network of the given
 * composite stream element, then restores the previously current element.
 *
 * @param compositeElement    stream whose episodic sub-networks are walked
 * @param episodicSubNetIndex index into getEpisodicSubNetworks() selecting the sub-net
 * @param writer              target writer for the generated template output
 * @param netDefinitionMode   mode (architecture definition / forward function / ...)
 */
public void include(SerialCompositeElementSymbol compositeElement, Integer episodicSubNetIndex, Writer writer, NetDefinitionMode netDefinitionMode){
    ArchitectureElementData previousElement = getCurrentElement();
    setCurrentElement(compositeElement);
    // Resolve the selected sub-network once, then emit each of its elements.
    List<ArchitectureElementSymbol> subNetElements = compositeElement.getEpisodicSubNetworks().get(episodicSubNetIndex);
    for (int idx = 0; idx < subNetElements.size(); idx++) {
        include(subNetElements.get(idx), writer, netDefinitionMode);
    }
    setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){ public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){ if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode); include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode);
...@@ -106,6 +117,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -106,6 +117,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode)); include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
} }
// Convenience overload for templates: parses the net definition mode from its
// string form before delegating to the episodic-sub-net include variant.
public void include(ArchitectureElementSymbol architectureElementSymbol, Integer episodicSubNetIndex, String netDefinitionMode) {
    include(architectureElementSymbol, episodicSubNetIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){ public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){ if (getWriter() == null){
throw new IllegalStateException("missing writer"); throw new IllegalStateException("missing writer");
...@@ -113,10 +128,21 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -113,10 +128,21 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode); include(architectureElement, getWriter(), netDefinitionMode);
} }
// Episodic-sub-net include using the controller's current writer.
// NOTE(review): the argument is cast unconditionally to SerialCompositeElementSymbol;
// callers passing any other ArchitectureElementSymbol get a ClassCastException —
// presumably only stream bodies ever reach this overload, verify against templates.
public void include(ArchitectureElementSymbol architectureElement, Integer episodicSubNetIndex, NetDefinitionMode netDefinitionMode){
    if (getWriter() == null){
        // Same guard as the non-episodic variant: a writer must be set first.
        throw new IllegalStateException("missing writer");
    }
    include((SerialCompositeElementSymbol) architectureElement, episodicSubNetIndex, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean outputAsArray) { public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
return getStreamInputs(stream, outputAsArray).keySet(); return getStreamInputs(stream, outputAsArray).keySet();
} }
// Returns only the names of the sub-network's inputs; dimensions are available
// via getSubnetInputs(subNet) itself.
public Set<String> getSubnetInputNames(List<ArchitectureElementSymbol> subNet) {
    return getSubnetInputs(subNet).keySet();
}
public ArrayList<String> getStreamInputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) { public ArrayList<String> getStreamInputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
ArrayList<String> inputVariableNames = new ArrayList<String>(); ArrayList<String> inputVariableNames = new ArrayList<String>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) { for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
...@@ -226,6 +252,29 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -226,6 +252,29 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return outputNames; return outputNames;
} }
/**
 * Collects the names of all last atomic elements of the final element in the
 * given sub-network, preserving encounter order.
 *
 * @param subNet non-empty list of elements forming one episodic sub-network
 * @return ordered set of output names of the sub-network
 */
public Set<String> getSubnetOutputNames(List<ArchitectureElementSymbol> subNet){
    Set<String> outputNames = new LinkedHashSet<>();
    ArchitectureElementSymbol lastElement = subNet.get(subNet.size() - 1);
    for (ArchitectureElementSymbol element : lastElement.getLastAtomicElements()) {
        outputNames.add(getName(element));
    }
    return outputNames;
}
/**
 * Sums the number of output types over all last atomic elements of the final
 * sub-network element. Falls back to 1 when no output types are found, so the
 * caller can always rely on at least one output slot.
 *
 * @param subNet non-empty list of elements forming one episodic sub-network
 * @return total number of outputs, minimum 1
 */
public int getSubnetOutputSize(List<ArchitectureElementSymbol> subNet){
    int outputSize = 0;
    ArchitectureElementSymbol lastElement = subNet.get(subNet.size() - 1);
    for (ArchitectureElementSymbol element : lastElement.getLastAtomicElements()) {
        outputSize += element.getOutputTypes().size();
    }
    // Guarantee a non-zero size, matching the behavior expected by templates.
    return outputSize == 0 ? 1 : outputSize;
}
public List<String> getUnrollOutputNames(UnrollInstructionSymbol unroll, String variable) { public List<String> getUnrollOutputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(unroll.getBody(), true)); List<String> outputNames = new LinkedList<>(getStreamOutputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable); Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
...@@ -403,6 +452,31 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -403,6 +452,31 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return inputs; return inputs;
} }
// Maps each input of the given sub-network to its list of dimensions
// (stringified), keyed by element name in encounter order.
// Constants contribute a fixed single dimension of "1"; all other elements
// report the dimensions of their first output type.
// NOTE(review): assumes every non-constant first atomic element has at least
// one output type — getOutputTypes().get(0) would otherwise throw; confirm
// this invariant holds after symbol-table resolution.
public Map<String, List<String>> getSubnetInputs(List<ArchitectureElementSymbol> subNet) {
    Map<String, List<String>> inputs = new LinkedHashMap<>();
    for (ArchitectureElementSymbol element : subNet.get(0).getFirstAtomicElements()) {
        if (element instanceof ConstantSymbol) {
            // Constants are scalar placeholders: dimension list is just ["1"].
            inputs.put(getName(element), Arrays.asList("1"));
        }
        else {
            List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
            List<String> dimensions = new ArrayList<>();
            for (Integer intDimension : intDimensions) {
                dimensions.add(intDimension.toString());
            }
            String name = getName(element);
            inputs.put(name, dimensions);
        }
    }
    return inputs;
}
public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) { public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> outputs = new LinkedHashMap<>(); Map<String, List<String>> outputs = new LinkedHashMap<>();
...@@ -519,6 +593,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -519,6 +593,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return dimensions; return dimensions;
} }
/**
 * Strips trailing dimensions of size 1, keeping at least one entry.
 * The list is modified in place and returned for call chaining; callers must
 * therefore pass a mutable list.
 *
 * @param dimensions mutable list of dimensions, e.g. [3, 1, 1] becomes [3]
 * @return the same (now trimmed) list instance
 */
public List<Integer> cutDimensionsInteger(List<Integer> dimensions) {
    int lastIndex = dimensions.size() - 1;
    // Drop trailing 1-entries, but never empty the list entirely.
    while (lastIndex > 0 && dimensions.get(lastIndex).intValue() == 1) {
        dimensions.remove(lastIndex);
        lastIndex--;
    }
    return dimensions;
}
public boolean hasUnrollInstructions() { public boolean hasUnrollInstructions() {
for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) { for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) {
if (networkInstruction.isUnroll()) { if (networkInstruction.isUnroll()) {
......
...@@ -130,9 +130,27 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { ...@@ -130,9 +130,27 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
Map<String, String> fileContentMap = new HashMap<>(); Map<String, String> fileContentMap = new HashMap<>();
// Context information and optimizer for local adaptation during prediction with a replay-memory layer (the latter is only applicable to supervised learning)
String cnnTrainLAOptimizerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNLAOptimizer.ftl");
fileContentMap.put("CNNLAOptimizer_" + getInstanceName() + ".h", cnnTrainLAOptimizerTemplateContent);
//AdamW optimizer if used for training
if(configuration.getOptimizer() != null) {
String optimizerName = configuration.getOptimizer().getName();
Optional<OptimizerSymbol> criticOptimizer = configuration.getCriticOptimizer();
String criticOptimizerName = "";
if (criticOptimizer.isPresent()) {
criticOptimizerName = criticOptimizer.get().getName();
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")) {
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
}
}
if (configData.isSupervisedLearning()) { if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); String cnnTrainTrainerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent); fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTrainerTemplateContent);
} else if (configData.isGan()) { } else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName(); final String trainerName = "CNNTrainer_" + getInstanceName();
if (!configuration.getDiscriminatorNetwork().isPresent()) { if (!configuration.getDiscriminatorNetwork().isPresent()) {
...@@ -189,7 +207,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { ...@@ -189,7 +207,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl"); final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent); fileContentMap.put(trainerName + ".py", ganTrainerContent);
} else if (configData.isReinforcementLearning()) { } else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName(); final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm(); final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
...@@ -6,6 +6,7 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator; ...@@ -6,6 +6,7 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
*/ */
public enum NetDefinitionMode { public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION, ARCHITECTURE_DEFINITION,
PREDICTION_PARAMETER,
FORWARD_FUNCTION; FORWARD_FUNCTION;
public static NetDefinitionMode fromString(final String netDefinitionMode) { public static NetDefinitionMode fromString(final String netDefinitionMode) {
...@@ -14,6 +15,8 @@ public enum NetDefinitionMode { ...@@ -14,6 +15,8 @@ public enum NetDefinitionMode {
return ARCHITECTURE_DEFINITION; return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION": case "FORWARD_FUNCTION":
return FORWARD_FUNCTION; return FORWARD_FUNCTION;
case "PREDICTION_PARAMETER":
return PREDICTION_PARAMETER;
default: default:
throw new IllegalArgumentException("Unknown Net Definition Mode"); throw new IllegalArgumentException("Unknown Net Definition Mode");
} }
......
<#-- (c) https://github.com/MontiCore/monticore -->
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#include <stdio.h>
#include <iostream>
#include <fstream>
// Read file to buffer
// Loads a whole file into a heap-allocated byte buffer; on open failure the
// object degrades to length 0 / null buffer instead of throwing.
// NOTE(review): the class is copyable but owns buffer_ with a raw delete[] in
// the destructor — copying would double-free (rule of three not followed);
// confirm instances are never copied before reuse.
class BufferFile {
public :
std::string file_path_;
// NOTE(review): length_ is an int narrowed from std::streampos; files larger
// than INT_MAX would overflow — presumably model files stay small.
int length_;
char* buffer_;
explicit BufferFile(std::string file_path)
:file_path_(file_path) {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
// Best-effort: report and leave the object in the empty state.
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
length_ = 0;
buffer_ = NULL;
return;
}
// Determine file size by seeking to the end, then rewind and read it all.
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
buffer_ = new char[sizeof(char) * length_];
// NOTE(review): the result of read() is not checked — a short read would
// leave trailing bytes uninitialized.
ifs.read(buffer_, length_);
ifs.close();
}
int GetLength() {
return length_;
}
char* GetBuffer() {
return buffer_;
}
~BufferFile() {
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};
#endif // CNNBUFFERFILE_H
...@@ -3,6 +3,8 @@ import mxnet as mx ...@@ -3,6 +3,8 @@ import mxnet as mx
import logging import logging
import os import os
import shutil import shutil
import warnings
import inspect
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index} from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
...@@ -27,6 +29,10 @@ class ${tc.fileNameWithoutEnding}: ...@@ -27,6 +29,10 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items(): for i, network in self.networks.items():
lastEpoch = 0 lastEpoch = 0
param_file = None param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0]*num_episodic_sub_nets
mem_files = [None]*num_episodic_sub_nets
try: try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params") os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
...@@ -37,22 +43,77 @@ class ${tc.fileNameWithoutEnding}: ...@@ -37,22 +43,77 @@ class ${tc.fileNameWithoutEnding}:
except OSError: except OSError:
pass pass
if hasattr(network, 'episodic_sub_nets'):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
except OSError:
pass
for j in range(len(network.episodic_sub_nets)):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
except OSError:
pass
if os.path.isdir(self._model_dir_): if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_): for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file: if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr) epoch = int(epochStr)
if epoch > lastEpoch: if epoch >= lastEpoch:
lastEpoch = epoch lastEpoch = epoch
param_file = file param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
if param_file is None: if param_file is None:
earliestLastEpoch = 0 earliestLastEpoch = 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file) network.load_parameters(self._model_dir_ + param_file)
if hasattr(network, 'episodic_sub_nets'):
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] != None:
logging.info("Loading Replay Memory: " + mem_files[j])
mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._model_dir_ + mem_files[j])
if earliestLastEpoch == None or lastEpoch < earliestLastEpoch: if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
earliestLastEpoch = lastEpoch earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch return earliestLastEpoch
...@@ -63,28 +124,56 @@ class ${tc.fileNameWithoutEnding}: ...@@ -63,28 +124,56 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items(): for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params" # param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0] * num_episodic_sub_nets
mem_files = [None] * num_episodic_sub_nets
if os.path.isdir(self._weights_dir_): if os.path.isdir(self._weights_dir_):
lastEpoch = 0 lastEpoch = 0
for file in os.listdir(self._weights_dir_): for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file: if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","") epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr) epoch = int(epochStr)
if epoch > lastEpoch: if epoch >= lastEpoch:
lastEpoch = epoch lastEpoch = epoch
param_file = file param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file) logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True) network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
if hasattr(network, 'episodic_sub_nets'):
assert lastEpoch == lastMemEpoch
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] != None:
logging.info("Loading pretrained Replay Memory: " + mem_files[j])
mem_layer = \
[param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._model_dir_ + mem_files[j])
else: else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file) logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None): def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std) self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context) with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
self.networks[${networkInstruction?index}].hybridize() self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context)<#sep>, </#list>) self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
<#if networkInstruction.body.episodicSubNetworks?has_content>
self.networks[0].episodicsubnet0_(<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
</#if>
</#list> </#list>
if not os.path.exists(self._model_dir_): if not os.path.exists(self._model_dir_):
......
...@@ -184,16 +184,16 @@ class ${tc.fileNameWithoutEnding}: ...@@ -184,16 +184,16 @@ class ${tc.fileNameWithoutEnding}:
del discriminator_optimizer_params['learning_rate_decay'] del discriminator_optimizer_params['learning_rate_decay']
if normalize: