Commit af72aeef authored by danielkisov's avatar danielkisov
Browse files

adding custom layer template and new net definition mode for validation

parent b5689ec9
......@@ -7,7 +7,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
PREDICTION_PARAMETER,
FORWARD_FUNCTION;
FORWARD_FUNCTION,
PARAMETER_VALIDATION;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
switch(netDefinitionMode) {
......@@ -17,6 +18,8 @@ public enum NetDefinitionMode {
return FORWARD_FUNCTION;
case "PREDICTION_PARAMETER":
return PREDICTION_PARAMETER;
case "PARAMETER_VALIDATION":
return PARAMETER_VALIDATION;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
......
......@@ -5,6 +5,10 @@ import os
import shutil
import warnings
import inspect
import sys
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
......@@ -207,3 +211,8 @@ class ${tc.fileNameWithoutEnding}:
</#list>
</#list>
return outputs
def validate_parameters(self):
<#list tc.architecture.networkInstructions as networkInstruction>
${tc.include(networkInstruction.body, "PARAMETER_VALIDATION")}
</#list>
\ No newline at end of file
......@@ -5,8 +5,12 @@ import math
import os
import abc
import warnings
import sys
from mxnet import gluon, nd
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
class ZScoreNormalization(gluon.HybridBlock):
def __init__(self, data_mean, data_std, **kwargs):
......@@ -557,6 +561,8 @@ ${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
</#list>
class Net_${networkInstruction?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
super(Net_${networkInstruction?index}, self).__init__(**kwargs)
......@@ -579,6 +585,7 @@ ${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
</#if>
pass
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
<#if networkInstruction.body.episodicSubNetworks?has_content>
<#list networkInstruction.body.episodicSubNetworks as elements>
......
......@@ -14,7 +14,8 @@ try:
import AdamW
except:
pass
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_optimizers import *
class CrossEntropyLoss(gluon.loss.Loss):
def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
......
<#-- (c) https://github.com/MontiCore/monticore -->
import logging
import mxnet as mx
<#list configurations as config>
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
......@@ -15,6 +16,7 @@ if __name__ == "__main__":
<#list configurations as config>
${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
${config.instanceName}_creator.validate_parameters()
${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
${config.instanceName}_trainer = CNNSupervisedTrainer_${config.instanceName}.CNNSupervisedTrainer_${config.instanceName}(
${config.instanceName}_loader,
......
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = ${element.element.arguments[0].rhs.value?remove_beginning("Optional[")?remove_ending("]")}.${element.element.arguments[0].rhs.value?remove_beginning("Optional[")?remove_ending("]")}(
<#list element.element.arguments[1].rhs.value?split(",") as parameter>
${parameter?remove_beginning("Optional[")?remove_ending("]")}<#sep>, </#sep>
</#list>)
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
<#elseif mode == "PARAMETER_VALIDATION">
${element.name}temp = ${element.element.arguments[0].rhs.value?remove_beginning("Optional[")?remove_ending("]")}.${element.element.arguments[0].rhs.value?remove_beginning("Optional[")?remove_ending("]")}()
parameters_with_type = ${element.name}temp.get_parameters()
<#list element.element.arguments[1].rhs.value?split(", ") as parameter>
if '${parameter?remove_beginning("Optional[")?remove_ending("]")?keep_before("=")}' in parameters_with_type:
if isinstance(${parameter?remove_beginning("Optional[")?remove_ending("]")?keep_after("=")},parameters_with_type['${parameter?remove_beginning("Optional[")?remove_ending("]")?keep_before("=")}']) == False:
raise TypeError('Wrong ' + str(type(${parameter?remove_beginning("Optional[")?remove_ending("]")?keep_after("=")})) + ' of parameter \'${parameter?remove_beginning("Optional[")?remove_ending("]")?keep_before("=")}\' given in the model')
else:
raise AttributeError('Parameter of Layer not added to get_parameters function')
</#list>
</#if>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment