Commit 68ee368d authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !30
parents 6c1c21bd d157159f
Pipeline #325057 passed with stage in 4 minutes and 4 seconds
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.11-SNAPSHOT</version>
<version>0.2.12-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,9 +17,9 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.6-SNAPSHOT</CNNArch2X.version>
<CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.7-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
......@@ -71,15 +71,18 @@ public class CNNArch2Gluon extends CNNArchGenerator {
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = controller.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
temp = controller.process("CNNModelLoader", Target.CPP);
fileContentMap.put("CNNModelLoader.h", temp.getValue());
return fileContentMap;
}
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
architecture.processForEpisodicReplayMemory();
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
......
......@@ -41,6 +41,10 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
supportedLayerList.add(AllPredefinedLayers.LARGE_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
}
}
......@@ -87,6 +87,17 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
setCurrentElement(previousElement);
}
public void include(SerialCompositeElementSymbol compositeElement, Integer episodicSubNetIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getEpisodicSubNetworks().get(episodicSubNetIndex)){
include(element, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode);
......@@ -106,6 +117,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElementSymbol, Integer episodicSubNetIndex, String netDefinitionMode) {
include(architectureElementSymbol, episodicSubNetIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
......@@ -113,10 +128,21 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode);
}
public void include(ArchitectureElementSymbol architectureElement, Integer episodicSubNetIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include((SerialCompositeElementSymbol) architectureElement, episodicSubNetIndex, getWriter(), netDefinitionMode);
}
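For orientation, a minimal sketch of how a template call site could drive the new episodic overloads; archTc and stream are assumed to be in scope, and the element passed together with a sub-net index must be a SerialCompositeElementSymbol, since the overload above casts unconditionally.
// Hedged sketch: render episodic sub-network 1 of a stream in forward-function mode.
archTc.include(stream, 1, "FORWARD_FUNCTION");
// Equivalent call using the enum constant directly:
archTc.include(stream, 1, NetDefinitionMode.FORWARD_FUNCTION);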
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
return getStreamInputs(stream, outputAsArray).keySet();
}
public Set<String> getSubnetInputNames(List<ArchitectureElementSymbol> subNet) {
return getSubnetInputs(subNet).keySet();
}
public ArrayList<String> getStreamInputVariableNames(SerialCompositeElementSymbol stream, boolean outputAsArray) {
ArrayList<String> inputVariableNames = new ArrayList<String>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
......@@ -226,6 +252,29 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return outputNames;
}
public Set<String> getSubnetOutputNames(List<ArchitectureElementSymbol> subNet){
Set<String> outputNames = new LinkedHashSet<>();
for (ArchitectureElementSymbol element : subNet.get(subNet.size()-1).getLastAtomicElements()) {
String name = getName(element);
outputNames.add(name);
}
return outputNames;
}
public int getSubnetOutputSize(List<ArchitectureElementSymbol> subNet){
int outputSize = 0;
for (ArchitectureElementSymbol element : subNet.get(subNet.size()-1).getLastAtomicElements()) {
outputSize += element.getOutputTypes().size();
}
if(outputSize == 0){
outputSize = 1;
}
return outputSize;
}
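The subnet getters mirror the stream getters but operate on a single episodic slice; a hedged sketch using only the accessors defined above:
// Hedged sketch: inspect the first episodic sub-network of a stream.
List<ArchitectureElementSymbol> subNet = stream.getEpisodicSubNetworks().get(0);
Set<String> outputNames = archTc.getSubnetOutputNames(subNet); // names of the last atomic elements
int outputSize = archTc.getSubnetOutputSize(subNet);           // summed output count, at least 1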
public List<String> getUnrollOutputNames(UnrollInstructionSymbol unroll, String variable) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(unroll.getBody(), true));
Map<String, String> pairs = getUnrollPairs(unroll.getBody(), unroll.getResolvedBodies().get(0), variable);
......@@ -403,6 +452,31 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return inputs;
}
public Map<String, List<String>> getSubnetInputs(List<ArchitectureElementSymbol> subNet) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : subNet.get(0).getFirstAtomicElements()) {
if (element instanceof ConstantSymbol) {
inputs.put(getName(element), Arrays.asList("1"));
}
else {
List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
List<String> dimensions = new ArrayList<>();
for (Integer intDimension : intDimensions) {
dimensions.add(intDimension.toString());
}
String name = getName(element);
inputs.put(name, dimensions);
}
}
return inputs;
}
public Map<String, List<String>> getStreamOutputs(SerialCompositeElementSymbol stream, boolean outputAsArray) {
Map<String, List<String>> outputs = new LinkedHashMap<>();
......@@ -519,6 +593,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return dimensions;
}
public List<Integer> cutDimensionsInteger(List<Integer> dimensions) {
while (dimensions.size() > 1 && dimensions.get(dimensions.size() - 1).equals(1)) {
dimensions.remove(dimensions.size() - 1);
}
return dimensions;
}
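The integer variant trims trailing singleton dimensions in place but always keeps at least one entry; for example:
// Hedged sketch of cutDimensionsInteger behavior:
cutDimensionsInteger(new ArrayList<>(Arrays.asList(64, 1, 1))); // -> [64]
cutDimensionsInteger(new ArrayList<>(Arrays.asList(1, 1, 1)));  // -> [1]
cutDimensionsInteger(new ArrayList<>(Arrays.asList(3, 1, 2)));  // -> [3, 1, 2] (no trailing 1 to strip)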
public boolean hasUnrollInstructions() {
for (NetworkInstructionSymbol networkInstruction : getArchitecture().getNetworkInstructions()) {
if (networkInstruction.isUnroll()) {
......
......@@ -130,9 +130,27 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
Map<String, String> fileContentMap = new HashMap<>();
// Context information and optimizer for local adaptation during prediction with the replay memory layer (the latter only applicable to supervised learning)
String cnnTrainLAOptimizerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNLAOptimizer.ftl");
fileContentMap.put("CNNLAOptimizer_" + getInstanceName() + ".h", cnnTrainLAOptimizerTemplateContent);
// Emit the AdamW optimizer implementation if it is used for training
if(configuration.getOptimizer() != null) {
String optimizerName = configuration.getOptimizer().getName();
Optional<OptimizerSymbol> criticOptimizer = configuration.getCriticOptimizer();
String criticOptimizerName = "";
if (criticOptimizer.isPresent()) {
criticOptimizerName = criticOptimizer.get().getName();
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")) {
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
}
}
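In short, AdamW.py is emitted exactly when either the main optimizer or the critic optimizer is named adamw; note that the C++ local-adaptation optimizer generated by the CNNLAOptimizer template below falls back to plain adam in that case. A hedged sketch of the emission rule:
// Hedged sketch: condition under which AdamW.py is generated.
boolean emitAdamW = "adamw".equals(optimizerName) || "adamw".equals(criticOptimizerName);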
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
String cnnTrainTrainerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTrainerTemplateContent);
} else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if (!configuration.getDiscriminatorNetwork().isPresent()) {
......@@ -189,7 +207,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
......@@ -6,6 +6,7 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
*/
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
PREDICTION_PARAMETER,
FORWARD_FUNCTION;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
......@@ -14,6 +15,8 @@ public enum NetDefinitionMode {
return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION":
return FORWARD_FUNCTION;
case "PREDICTION_PARAMETER":
return PREDICTION_PARAMETER;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
......
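A minimal usage sketch of the extended enum, with the mode string as the templates pass it in:
// Hedged sketch: PREDICTION_PARAMETER round-trips through fromString; unknown names throw.
NetDefinitionMode mode = NetDefinitionMode.fromString("PREDICTION_PARAMETER");
assert mode == NetDefinitionMode.PREDICTION_PARAMETER;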
<#-- (c) https://github.com/MontiCore/monticore -->
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#include <stdio.h>
#include <iostream>
#include <fstream>
// Read file to buffer
class BufferFile {
public :
std::string file_path_;
int length_;
char* buffer_;
explicit BufferFile(std::string file_path)
:file_path_(file_path) {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
length_ = 0;
buffer_ = NULL;
return;
}
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
buffer_ = new char[sizeof(char) * length_];
ifs.read(buffer_, length_);
ifs.close();
}
int GetLength() {
return length_;
}
char* GetBuffer() {
return buffer_;
}
~BufferFile() {
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};
#endif // CNNBUFFERFILE_H
......@@ -184,16 +184,16 @@ class ${tc.fileNameWithoutEnding}:
del discriminator_optimizer_params['learning_rate_decay']
if normalize:
self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
self._net_creator_dis.construct([mx_context], data_mean=data_mean, data_std=data_std)
else:
self._net_creator_dis.construct(mx_context)
self._net_creator_dis.construct([mx_context])
self._net_creator_gen.construct(mx_context)
self._net_creator_gen.construct([mx_context])
if self.use_qnet:
self._net_creator_qnet.construct(mx_context)
self._net_creator_qnet.construct([mx_context])
if load_checkpoint:
self._net_creator_qnet.load(mx_context)
self._net_creator_qnet.load([mx_context])
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
......@@ -206,8 +206,8 @@ class ${tc.fileNameWithoutEnding}:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator_dis.load(mx_context)
self._net_creator_gen.load(mx_context)
begin_epoch = self._net_creator_dis.load([mx_context])
self._net_creator_gen.load([mx_context])
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
......@@ -255,9 +255,9 @@ class ${tc.fileNameWithoutEnding}:
gen_input, exp_qnet_output = create_generator_input(batch)
with autograd.record():
fake_data = gen_net(*gen_input)
fake_data = gen_net(*gen_input)[0][0]
fake_data.detach()
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
......@@ -265,7 +265,7 @@ class ${tc.fileNameWithoutEnding}:
real_labels = mx.nd.ones(discriminated_fake_dis.shape, ctx=mx_context)
loss_resultF = dis_loss(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = dis_loss(discriminated_real_dis, real_labels)
......@@ -276,8 +276,8 @@ class ${tc.fileNameWithoutEnding}:
if batch_i % k_value == 0:
with autograd.record():
fake_data = gen_net(*gen_input)
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
fake_data = gen_net(*gen_input)[0][0]
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)[0][0]
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = dis_loss(discriminated_fake_gen, real_labels)
......@@ -285,7 +285,7 @@ class ${tc.fileNameWithoutEnding}:
condition = batch.data[traindata_to_index[generator_target_name + "_"]]
loss_resultG = loss_resultG + gen_loss_weight * generator_loss_func(fake_data, condition)
if self.use_qnet:
qnet_discriminated = [q_net(features)]
qnet_discriminated = [q_net(features)[0][0]]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
......
<#-- (c) https://github.com/MontiCore/monticore -->
<#list configurations as config>
#ifndef CNNLAOPTIMIZER_${config.instanceName?upper_case}
#define CNNLAOPTIMIZER_${config.instanceName?upper_case}
#include <mxnet-cpp/MxNetCpp.h>
#include <string>
#include <vector>
#include <memory>
using namespace mxnet::cpp;
class CNNLAOptimizer_${config.instanceName}{
private:
Optimizer *optimizerHandle;
<#if (config.context)??>
std::string context_name = "${config.context}";
<#else>
std::string context_name = "cpu";
</#if>
public:
explicit CNNLAOptimizer_${config.instanceName}(){
<#if (config.configuration.optimizer)??>
<#if config.optimizerName == "adamw">
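// "adamw" is mapped to the built-in adam optimizer here; a dedicated AdamW entry is assumed unavailable in the MxNetCpp OptimizerRegistry.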
optimizerHandle = OptimizerRegistry::Find("adam");
<#else>
optimizerHandle = OptimizerRegistry::Find("${config.optimizerName}");
</#if>
<#list config.optimizerParams?keys as param>
<#if param == "learning_rate">
optimizerHandle->SetParam("lr", ${config.optimizerParams[param]});
<#elseif param == "weight_decay">
optimizerHandle->SetParam("wd", ${config.optimizerParams[param]});
<#elseif param == "learning_rate_decay">
<#assign learningRateDecay = config.optimizerParams[param]>
<#elseif param == "learning_rate_minimum">
<#assign minLearningRate = config.optimizerParams[param]>
<#elseif param == "step_size">
<#assign stepSize = config.optimizerParams[param]>
<#else>
optimizerHandle->SetParam("${param}", ${config.optimizerParams[param]});
</#if>
</#list>
<#if (learningRateDecay)?? && (stepSize)??>
<#if !(minLearningRate)??>
<#assign minLearningRate = "1e-08">
</#if>
std::unique_ptr<LRScheduler> lrScheduler(new FactorScheduler(${stepSize}, ${learningRateDecay}, ${minLearningRate}));
optimizerHandle->SetLRScheduler(std::move(lrScheduler));
</#if>
<#else>
optimizerHandle = OptimizerRegistry::Find("adam");
optimizerHandle->SetParam("lr", 0.001);
</#if>
<#if (config.clipGlobalGradNorm)??>
//clip_global_grad_norm=${config.clipGlobalGradNorm},
</#if>
}
Optimizer *getOptimizer(){
return optimizerHandle;
}
std::string getContextName(){
return context_name;
}
};
#endif // CNNLAOPTIMIZER_${config.instanceName?upper_case}
</#list>
#ifndef CNNMODELLOADER
#define CNNMODELLOADER
#include <mxnet-cpp/MxNetCpp.h>
#include <stdio.h>
#include <iostream>
#include <fstream>
using namespace mxnet::cpp;
// Read files to load model symbol and parameters
class ModelLoader {
private:
Context ctx = Context::cpu();
std::vector<Symbol> network_symbol_list;
std::vector<std::map<std::string, NDArray>> network_param_map_list;
std::vector<Symbol> query_symbol_list;
std::vector<std::map<std::string, NDArray>> query_param_map_list;
std::vector<std::map<std::string, NDArray>> replay_memory;
std::vector<Symbol> loss_symbol;
std::vector<std::map<std::string, NDArray>> loss_param_map;
void checkFile(std::string file_path){
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
return;
}
int length_;
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
ifs.close();
}
void loadComponent(std::string json_path,
std::string param_path,
std::vector<Symbol> &symbols_list,
std::vector<std::map<std::string, NDArray>> &param_map_list){
checkFile(json_path);
symbols_list.push_back(Symbol::Load(json_path));
checkFile(param_path);
std::map<std::string, NDArray> params;
NDArray::Load(param_path, 0, &params);
param_map_list.push_back(processParamMap(params));
}
std::map<std::string, NDArray> processParamMap(const std::map<std::string, NDArray> &param_map){
std::map<std::string, NDArray> processed_param_map;
if(!param_map.empty()){
for (const auto &pair : param_map) {
std::string name = pair.first.substr(4); // strip the 4-character type prefix ("arg:" or "aux:"); aux parameters are assumed not to occur here
processed_param_map[name] = pair.second.Copy(ctx);
}
}
return processed_param_map;
}
public:
explicit ModelLoader(std::string file_prefix, mx_uint num_subnets, Context ctx_param){
ctx = ctx_param;
std::string network_json_path;
std::string network_param_path;
std::string query_json_path;
std::string query_param_path;
std::string memory_path;
std::string loss_json_path;
std::string loss_param_path;
//Load network
if(!num_subnets){
network_json_path = file_prefix + "-symbol.json";
network_param_path = file_prefix + "-0000.params";
loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
}else{
for(mx_uint i = 0; i < num_subnets; i++){
network_json_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-symbol.json";
network_param_path = file_prefix + "_episodic_sub_net_" + std::to_string(i) + "-0000.params";
loadComponent(network_json_path, network_param_path, network_symbol_list, network_param_map_list);
if(i >= 1){
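// Every sub-net after the first carries an episodic query net and stored replay memory.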
query_json_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-symbol.json";
query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params";
loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list);
memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000";
checkFile(memory_path);
std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path);
for(auto &mem : mem_map){
mem.second = mem.second.Copy(ctx);
}
replay_memory.push_back(mem_map);
}
}
}
//Load Loss
loss_json_path = file_prefix + "_loss-symbol.json";
loss_param_path = file_prefix + "_loss-0000.params";
loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
NDArray::WaitAll();
}
std::vector<Symbol> GetNetworkSymbols() {
return network_symbol_list;
}
std::vector<std::map<std::string, NDArray>> GetNetworkParamMaps() {
return network_param_map_list;
}
Symbol GetLoss() {
return loss_symbol[0];
}
std::map<std::string, NDArray> GetLossParamMap() {
return loss_param_map[0];
}
std::vector<Symbol> GetQuerySymbols() {
return query_symbol_list;
}
std::vector<std::map<std::string, NDArray>> GetQueryParamMaps() {
return query_param_map_list;
}
std::vector<std::map<std::string, NDArray>> GetReplayMemory(){
return replay_memory;
}