Commit 3052564d authored by Evgeny Kusmenko

Merge branch 'adaNet_luis_rickert' into 'master'

AdaNet Integration to CNNArch2Gluon

See merge request !44
parents 51eaec66 3b18a3f5
Pipeline #531549 passed with stages in 2 minutes and 14 seconds
@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
-<version>0.4.9-SNAPSHOT</version>
<version>0.4.10-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
@@ -17,13 +17,13 @@
<!-- .. SE-Libraries .................................................. -->
-<CNNArch2X.version>0.4.8-SNAPSHOT</CNNArch2X.version>
<CNNArch2X.version>0.4.9-SNAPSHOT</CNNArch2X.version>
<EMADL2PythonWrapper.version>0.0.3-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>25.1-jre</guava.version>
<junit.version>4.13.1</junit.version>
-<logback.version>1.2.0</logback.version>
<logback.version>1.1.2</logback.version>
<jscience.version>4.3.1</jscience.version>
<!-- .. Plugins ....................................................... -->
@@ -42,10 +42,11 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
        supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
        // supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
        supportedLayerList.add(AllPredefinedLayers.LARGE_MEMORY_NAME);
        supportedLayerList.add(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
        supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
        supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
        supportedLayerList.add(AllPredefinedLayers.LAYERNORM_NAME);
        supportedLayerList.add(AllPredefinedLayers.AdaNet_Name);
    }
}
@@ -16,12 +16,21 @@ import java.util.regex.Pattern;
public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
    public static final String NET_DEFINITION_MODE_KEY = "mode";
    public List<String> worked_list = new ArrayList<String>();

    public CNNArch2GluonTemplateController(ArchitectureSymbol architecture,
                                           TemplateConfiguration templateConfiguration) {
        super(architecture, templateConfiguration);
    }

    public String getDefinedOutputDimension() {
        // Calculates the output shape as defined in the .emadl file; used by the AdaNet layer.
        ArchTypeSymbol types = ((IODeclarationSymbol) this.getArchitecture().getOutputs().get(0).getDeclaration()).getType();
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("(");
        types.getDimensions().forEach(elem -> stringBuilder.append(elem).append(','));
        stringBuilder.append(")");
        return stringBuilder.toString();
    }

    public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){
@@ -30,12 +39,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
        ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
        ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
        ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode.toString());
-       if (this.getDataElement().getElement() instanceof LayerSymbol){
-           if (((LayerSymbol) (this.getDataElement().getElement())).getDeclaration() instanceof CustomLayerDeclaration){
-               templatePath = relativePath + "CustomLayer" + FTL_FILE_ENDING;
-           }
-       }
        getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer);
    }
@@ -70,12 +73,25 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
    public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){
        ArchitectureElementData previousElement = getCurrentElement();
        setCurrentElement(layer);
-       if (layer.isAtomic()){
        if (layer.getName().equals(AllPredefinedLayers.AdaNet_Name) && netDefinitionMode.equals(NetDefinitionMode.ADANET_CONSTRUCTION)){
            // construct the AdaNet layer
            include(TEMPLATE_ELEMENTS_DIR_PATH, "AdaNet", writer, netDefinitionMode);
        } else if (layer.isAtomic()){
            String templateName = layer.getDeclaration().getName();
            include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
-       }
-       else {
        } else if (layer.isArtificial() && this.containsAdaNet()){
            if (netDefinitionMode.equals(NetDefinitionMode.ARTIFICIAL_ARCH_CLASS)){
                boolean originalArtificialState = layer.isArtificial();
                layer.setArtificial(false);
                if (!this.worked_list.contains(layer.getName())){
                    include(TEMPLATE_ELEMENTS_DIR_PATH, "ArtificialArchClass", writer, netDefinitionMode);
                    this.worked_list.add(layer.getName());
                }
                layer.setArtificial(originalArtificialState);
            } else {
                include((ArchitectureElementSymbol) layer.getResolvedThis().get(), writer, netDefinitionMode);
            }
        } else {
            include((ArchitectureElementSymbol) layer.getResolvedThis().get(), writer, netDefinitionMode);
        }
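The getDefinedOutputDimension helper above renders the EMADL output type as a Python tuple literal with a trailing comma; the result is later pasted verbatim into the generated model_shape argument. A minimal Python sketch of the same logic, with made-up dimension values:

# Sketch of the Java logic above; dims would come from the .emadl output type.
def defined_output_dimension(dims):
    return "(" + "".join(str(d) + "," for d in dims) + ")"

assert defined_output_dimension([10, 1, 1]) == "(10,1,1,)"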
@@ -8,7 +8,9 @@ public enum NetDefinitionMode {
    ARCHITECTURE_DEFINITION,
    PREDICTION_PARAMETER,
    FORWARD_FUNCTION,
-   PARAMETER_VALIDATION;
    PARAMETER_VALIDATION,
    ARTIFICIAL_ARCH_CLASS,
    ADANET_CONSTRUCTION;

    public static NetDefinitionMode fromString(final String netDefinitionMode) {
        switch(netDefinitionMode) {
@@ -20,6 +22,10 @@ public enum NetDefinitionMode {
                return PREDICTION_PARAMETER;
            case "PARAMETER_VALIDATION":
                return PARAMETER_VALIDATION;
            case "ARTIFICIAL_ARCH_CLASS":
                return ARTIFICIAL_ARCH_CLASS;
            case "ADANET_CONSTRUCTION":
                return ADANET_CONSTRUCTION;
            default:
                throw new IllegalArgumentException("Unknown Net Definition Mode");
        }
@@ -15,7 +15,11 @@ from custom_layers import *
</#if>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}, DataClass_${networkInstruction?index}
<#else>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
</#if>
</#list>
class ${tc.fileNameWithoutEnding}:
@@ -25,6 +29,9 @@ class ${tc.fileNameWithoutEnding}:
    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        self.networks = {}
<#if tc.containsAdaNet()>
        self.dataClass = {}
</#if>
<#if (tc.weightsPath)??>
        self._weights_dir_ = "${tc.weightsPath}/"
<#else>
@@ -173,7 +180,12 @@ class ${tc.fileNameWithoutEnding}:
    def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}()
        self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}
<#else>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
</#if>
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
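In AdaNet mode, construct() thus stores an uninitialized dummy Net_i instance in self.networks and the DataClass_i class itself, not an instance, in self.dataClass; the trainer later hands that class to the AdaNet fit routine. A self-contained sketch of these two assignments, with inline stand-ins for the generated classes:

# Self-contained sketch of the AdaNet branch of construct() above; Net_0 and
# DataClass_0 stand in for the generated classes and are not the real ones.
import mxnet as mx
from mxnet import gluon

class Net_0(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(Net_0, self).__init__(**kwargs)
        with self.name_scope():
            self.dummy = gluon.nn.Dense(units=1)

    def hybrid_forward(self, F, x):
        return self.dummy(x)

class DataClass_0:
    pass

networks, dataClass = {}, {}
networks[0] = Net_0()        # dummy net instance, replaced after the AdaNet search
dataClass[0] = DataClass_0   # the class itself is stored, not an instance
networks[0].collect_params().initialize(mx.init.Normal(), force_reinit=False, ctx=mx.cpu())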
@@ -9,13 +9,19 @@ import abc
import warnings
import sys
from mxnet import gluon, nd
<#if tc.containsAdaNet()>
from mxnet.gluon import nn, HybridBlock
from numpy import log, product, prod, sqrt
from mxnet.ndarray import zeros, zeros_like
sys.path.insert(1, '${tc.architecture.getAdaNetUtils()}')
from AdaNetConfig import AdaNetConfig
import CoreAdaNet
</#if>
<#if tc.architecture.customPyFilesPath??>
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
</#if>
class ZScoreNormalization(gluon.HybridBlock):
    def __init__(self, data_mean, data_std, **kwargs):
        super(ZScoreNormalization, self).__init__(**kwargs)
@@ -525,8 +531,46 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
                    self.value_memory.append(mem_dict[key])
                elif key.startswith("labels_"):
                    self.label_memory.append(mem_dict[key])
<#if tc.containsAdaNet()>
# Blocks needed for AdaNet are generated below
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.containsAdaNet()>
${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")}
<#assign outblock = networkInstruction.body.getElements()[1].getDeclaration().getBlock("outBlock")>
<#assign block = networkInstruction.body.getElements()[1].getDeclaration().getBlock("block")>
<#assign inblock = networkInstruction.body.getElements()[1].getDeclaration().getBlock("inBlock")>
class Net_${networkInstruction?index}(gluon.HybridBlock):
    # Dummy network; it is overridden during the AdaNet generation and exists
    # only to avoid numerous <#if> tags in the .ftl files.
    def __init__(self, **kwargs):
        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
        with self.name_scope():
            self.AdaNet = True
            self.dummy = nn.Dense(units=1)

    def hybrid_forward(self, F, x):
        return self.dummy(x)

FullyConnected = AdaNetConfig.DEFAULT_BLOCK.value

DataClass_${networkInstruction?index} = CoreAdaNet.DataClass(
<#if outblock.isPresent()>
    outBlock=${outblock.get().name},
<#else>
    outBlock=None,
</#if>
<#if inblock.isPresent()>
    inBlock=${inblock.get().name},
<#else>
    inBlock=None,
</#if>
<#if block.isPresent()>
    block=${block.get().name},
<#else>
    block=None,
</#if>
    model_shape=${tc.getDefinedOutputDimension()})
</#if>
</#list>
<#else>
<#list tc.architecture.networkInstructions as networkInstruction>
#Stream ${networkInstruction?index}
<#list networkInstruction.body.episodicSubNetworks as elements>
@@ -565,8 +609,6 @@ ${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
</#list>
class Net_${networkInstruction?index}(gluon.HybridBlock):
    def __init__(self, data_mean=None, data_std=None, mx_context=None, **kwargs):
        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
@@ -609,4 +651,4 @@ ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
</#if>
</#if>
</#list>
</#if>
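CoreAdaNet itself is not part of this diff; judging from the call above, DataClass bundles the three candidate-building blocks and the expected output shape. A hypothetical stand-in with the same constructor signature:

# Hypothetical stand-in for CoreAdaNet.DataClass; the real class lives in the
# AdaNet utils referenced via getAdaNetUtils() and is not shown in this diff.
class DataClass:
    def __init__(self, outBlock=None, inBlock=None, block=None, model_shape=None):
        self.outBlock = outBlock          # optional output block of the candidate networks
        self.inBlock = inBlock            # optional input block of the candidate networks
        self.block = block                # building block the AdaNet search stacks and widens
        self.model_shape = model_shape    # defined output shape, e.g. (10,1,1,)

example = DataClass(block=None, model_shape=(10, 1, 1))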
@@ -12,6 +12,14 @@ import math
import sys
import inspect
from mxnet import gluon, autograd, nd
<#if tc.containsAdaNet()>
from typing import List
from mxnet.gluon.loss import Loss, SigmoidBCELoss
from mxnet.ndarray import add, concatenate
sys.path.insert(1, '${tc.architecture.getAdaNetUtils()}')
from adanet import fit
</#if>
try:
    import AdamW
except:
@@ -295,6 +303,10 @@ class ${tc.fileNameWithoutEnding}:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
<#if tc.containsAdaNet()>
        self._dataClass = net_constructor.dataClass
        self.AdaNet = ${tc.containsAdaNet()?string('True','False')}
</#if>
        self._networks = {}

    def train(self, batch_size=64,
@@ -424,9 +436,33 @@ class ${tc.fileNameWithoutEnding}:
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        loss_function.hybridize()

<#if tc.containsAdaNet()>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.containsAdaNet()>
        assert self._networks[${networkInstruction?index}].AdaNet, "passed model is not an AdaNet model"
        self._networks[${networkInstruction?index}] = fit(
            loss=loss_function,
            optimizer=optimizer,
            epochs=num_epoch,
            optimizer_params=optimizer_params,
            train_iter=train_iter,
            data_class=self._dataClass[${networkInstruction?index}],
            batch_size=batch_size,
            ctx=mx_context[0],
            logging=logging
        )
        logging.info(self._networks[${networkInstruction?index}])
        logging.info(f"node count: {self._networks[${networkInstruction?index}].get_node_count()}")
</#if>
        # update the trainers with the parameters of the networks built by AdaNet
        if optimizer == "adamw":
            trainers = [mx.gluon.Trainer(network.collect_params(), AdamW.AdamW(**optimizer_params)) for network in self._networks.values() if len(network.collect_params().values()) != 0]
        else:
            trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
</#list>
</#if>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.episodicSubNetworks?has_content>
<#assign episodicReplayVisited = true>
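The fit routine imported from adanet is likewise outside this diff. Assuming it follows the usual AdaNet scheme of iteratively training candidate subnetworks, scoring them with a complexity-regularized objective, and growing the ensemble with the winner, a compact self-contained sketch of that selection loop (all callables are illustrative, not the real CoreAdaNet API):

# Self-contained sketch of an AdaNet-style search loop; candidate construction,
# training, and scoring are reduced to plain callables for illustration.
def adanet_search(make_candidates, train_and_eval, complexity, rounds=3, lam=0.1):
    ensemble = []
    for _ in range(rounds):
        candidates = make_candidates(ensemble)
        # pick the candidate minimizing validation loss + lambda * complexity
        best = min(candidates, key=lambda c: train_and_eval(ensemble, c) + lam * complexity(c))
        ensemble.append(best)
    return ensemble

# toy usage: "candidates" are just layer widths, scored by a made-up objective
result = adanet_search(
    make_candidates=lambda ens: [8, 16, 32],
    train_and_eval=lambda ens, c: 1.0 / c,
    complexity=lambda c: c ** 0.5,
)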
<#if mode == "ADANET_CONSTRUCTION">
<#assign outBlock = element.element.getDeclaration().getBlock("outBlock")>
<#assign inBlock = element.element.getDeclaration().getBlock("inBlock")>
<#assign Block = element.element.getDeclaration().getBlock("block").get()>
<#if Block.isArtificial()>
#BuildingBlock
</#if>
${tc.include(Block,"ARTIFICIAL_ARCH_CLASS")}
<#if inBlock.isPresent()>
#inputBlock
<#if inBlock.get().isArtificial()>
${tc.include(inBlock.get(),"ARTIFICIAL_ARCH_CLASS")}
</#if>
</#if>
<#if outBlock.isPresent()>
<#if outBlock.get().isArtificial()>
#outputBlock
${tc.include(outBlock.get(),"ARTIFICIAL_ARCH_CLASS")}
</#if>
</#if>
</#if>
\ No newline at end of file
<#assign input = element.inputs[0]>
<#assign name = element.element.name>
<#assign args = element.element.arguments>
<#if mode == "ARTIFICIAL_ARCH_CLASS">
class ${name}(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(${name}, self).__init__(**kwargs)
        with self.name_scope():
${tc.include(element.element,"ARCHITECTURE_DEFINITION")}

    def hybrid_forward(self, F, ${input}):
${tc.include(element.element,"FORWARD_FUNCTION")}
        return <#list element.element.getLastAtomicElements() as el><#if el?index == 0>${tc.getName(el)}<#else>, ${tc.getName(el)}</#if></#list>
</#if>
\ No newline at end of file
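For illustration, instantiating this template for a hypothetical artificial block named customBlock with a single FullyConnected layer and input data_ would emit roughly the following (the class body that the ARCHITECTURE_DEFINITION and FORWARD_FUNCTION includes expand to is assumed here):

# Illustrative generated output only; names are hypothetical.
from mxnet import gluon

class customBlock(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(customBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.fc1_ = gluon.nn.Dense(units=10)   # from ARCHITECTURE_DEFINITION

    def hybrid_forward(self, F, data_):
        fc1_ = self.fc1_(data_)                    # from FORWARD_FUNCTION
        return fc1_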