Commit 2ae0c453 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'ba_mann' into 'master'

Merge ba_mann into master

See merge request !46
parents 1c37d990 47860bdf
Pipeline #553487 passed with stages
in 1 minute and 16 seconds
<?xml version="1.0" encoding="UTF-8"?>
<!-- (c) https://github.com/MontiCore/monticore -->
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>
......@@ -9,15 +9,14 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.4.10-SNAPSHOT</version>
<version>0.4.11-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch2X.version>0.4.9-SNAPSHOT</CNNArch2X.version>
<CNNArch2X.version>0.4.10-SNAPSHOT</CNNArch2X.version>
<EMADL2PythonWrapper.version>0.0.3-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
......
......@@ -46,6 +46,8 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
supportedLayerList.add(AllPredefinedLayers.LAYERNORM_NAME);
supportedLayerList.add(AllPredefinedLayers.UP_CONVOLUTION3D_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION3D_NAME);
supportedLayerList.add(AllPredefinedLayers.AdaNet_Name);
}
......
......@@ -234,12 +234,15 @@ class ${tc.fileNameWithoutEnding}:
qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), discriminator_optimizer, discriminator_optimizer_params)
dis_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
dis_loss.hybridize()
if not generator_loss == None:
if generator_loss == "l2":
generator_loss_func = mx.gluon.loss.L2Loss()
generator_loss_func.hybridize()
elif generator_loss == "l1":
generator_loss_func = mx.gluon.loss.L1Loss()
generator_loss_func.hybridize()
else:
logging.error("Invalid generator loss parameter")
......@@ -328,6 +331,9 @@ class ${tc.fileNameWithoutEnding}:
gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
dis_net.export(self.parameter_path_dis() + '_newest', epoch=0)
if not generator_loss == None:
generator_loss_func.export(self.parameter_path_gen() + '_newest_loss', epoch=0)
dis_loss.export(self.parameter_path_dis() + '_newest_loss', epoch=0)
def parameter_path_gen(self):
return self._net_creator_gen._model_dir_ + self._net_creator_gen._model_prefix_ + '_' + str(0)
......
......@@ -105,10 +105,17 @@ public:
}
}
//Load Loss
loss_json_path = file_prefix + "_loss-symbol.json";
loss_param_path = file_prefix + "_loss-0000.params";
loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
//Check if loss files exists. If not, they arent necessary and are skipped
std::ifstream f(file_prefix + "_loss-symbol.json");
if (f.good()){
//Load Loss
loss_json_path = file_prefix + "_loss-symbol.json";
loss_param_path = file_prefix + "_loss-0000.params";
loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
}
else {
std::cerr << "Can't open the file. Please check if " << file_prefix << "_loss-symbol.json" << exists". \n";
}
NDArray::WaitAll();
}
......
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = gluon.nn.Conv3D(channels=${element.channels?c},
kernel_size=(${tc.join(element.kernel, ",")}),
strides=(${tc.join(element.stride, ",")}),
use_bias=${element.noBias?string("False", "True")},
padding=(${tc.join(element.padding, ",")}))
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.padding??>
self.${element.name}padding = (${tc.join(element.transPadding, ",")})
</#if>
self.${element.name} = gluon.nn.Conv3DTranspose(channels=${element.channels?c},
kernel_size=(${tc.join(element.kernel, ",")}),
strides=(${tc.join(element.stride, ",")}),
padding=self.${element.name}padding,
use_bias=${element.noBias?string("False", "True")})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
......@@ -123,4 +123,11 @@
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(${tc.getNameAsArray(tc.getName(output))}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
</#if>
</#if>
<#if shape?size == 4>
<#if (output.ioDeclaration.type.domain.isNaturalNumber() || output.ioDeclaration.type.domain.isWholeNumber())>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToIntCube(${tc.getNameAsArray(tc.getName(output))}, std::vector<size_t> {${shape[1]?c}, ${shape[2]?c}, ${shape[3]?c}});
<#else>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(${tc.getNameAsArray(tc.getName(output))}, std::vector<size_t> {${shape[1]?c}, ${shape[2]?c}, ${shape[3]?c}});
</#if>
</#if>
</#list>
......@@ -103,10 +103,17 @@ public:
}
}
//Load Loss
loss_json_path = file_prefix + "_loss-symbol.json";
loss_param_path = file_prefix + "_loss-0000.params";
loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
//Check if loss files exists. If not, they arent necessary and are skipped
std::ifstream f(file_prefix + "_loss-symbol.json");
if (f.good()){
//Load Loss
loss_json_path = file_prefix + "_loss-symbol.json";
loss_param_path = file_prefix + "_loss-0000.params";
loadComponent(loss_json_path, loss_param_path, loss_symbol, loss_param_map);
}
else {
std::cerr << "Can't open the file. Please check if " << file_prefix << "_loss-symbol.json" << exists". \n";
}
NDArray::WaitAll();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment