Commit de0bca61 authored by lr119628's avatar lr119628
Browse files

[update] moved DataClass class defintion out off the ftl file

parent 0b7fb88f
Pipeline #524787 failed with stage
in 23 seconds
......@@ -182,7 +182,7 @@ class ${tc.fileNameWithoutEnding}:
<#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}()
self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}()
self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}
<#else>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
</#if>
......
......@@ -532,8 +532,7 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
elif key.startswith("labels_"):
self.label_memory.append(mem_dict[key])
<#if tc.containsAdaNet()>
# Generation of the artificial blocks for the Streams below
from mxnet.gluon import nn
# Blocks needed for AdaNet are generated below
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.containsAdaNet()>
${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")}
......@@ -552,31 +551,23 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
def hybrid_forward(self,F,x):
return self.dummy(x)
class DataClass_${networkInstruction?index}:
"""
this object holds all the necessary information for AdaNet
"""
def __init__(self, **kwargs):
FullyConnected = AdaNetConfig.DEFAULT_BLOCK.value
<#if outblock.isPresent()>
self.outBlock = ${outblock.get().name}
<#else>
self.outBlock = None
</#if>
<#if inblock.isPresent()>
self.inBlock = ${inblock.get().name}
<#else>
self.inBlock = None
</#if>
<#if block.isPresent()>
self.block = ${block.get().name}
if self.block is AdaNetConfig.DEFAULT_BLOCK.value:
self.block = CoreAdaNet.DefaultBuildingBlock
<#else>
self.block = None
</#if>
self.model_shape = ${tc.getDefinedOutputDimension()}
DataClass_${networkInstruction?index} = CoreAdaNet.DataClass(
<#if outblock.isPresent()>
outBlock = ${outblock.get().name},
<#else>
outBlock = None,
</#if>
<#if inblock.isPresent()>
inBlock = ${inblock.get().name},
<#else>
inBlock = None,
</#if>
<#if block.isPresent()>
block = ${block.get().name},
<#else>
block = None,
</#if>
model_shape = ${tc.getDefinedOutputDimension()})
</#if>
</#list>
<#else>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment