Commit de0bca61 authored by lr119628's avatar lr119628
Browse files

[update] moved DataClass class defintion out off the ftl file

parent 0b7fb88f
Pipeline #524787 failed with stage
in 23 seconds
...@@ -182,7 +182,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -182,7 +182,7 @@ class ${tc.fileNameWithoutEnding}:
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()> <#if tc.containsAdaNet()>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}() self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}()
self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}() self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}
<#else> <#else>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="") self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
</#if> </#if>
......
...@@ -532,8 +532,7 @@ class EpisodicMemory(EpisodicReplayMemoryInterface): ...@@ -532,8 +532,7 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
elif key.startswith("labels_"): elif key.startswith("labels_"):
self.label_memory.append(mem_dict[key]) self.label_memory.append(mem_dict[key])
<#if tc.containsAdaNet()> <#if tc.containsAdaNet()>
# Generation of the artificial blocks for the Streams below # Blocks needed for AdaNet are generated below
from mxnet.gluon import nn
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.containsAdaNet()> <#if networkInstruction.body.containsAdaNet()>
${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")} ${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")}
...@@ -552,31 +551,23 @@ class Net_${networkInstruction?index}(gluon.HybridBlock): ...@@ -552,31 +551,23 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
def hybrid_forward(self,F,x): def hybrid_forward(self,F,x):
return self.dummy(x) return self.dummy(x)
class DataClass_${networkInstruction?index}: DataClass_${networkInstruction?index} = CoreAdaNet.DataClass(
<#if outblock.isPresent()>
""" outBlock = ${outblock.get().name},
this object holds all the necessary information for AdaNet <#else>
""" outBlock = None,
def __init__(self, **kwargs): </#if>
FullyConnected = AdaNetConfig.DEFAULT_BLOCK.value <#if inblock.isPresent()>
<#if outblock.isPresent()> inBlock = ${inblock.get().name},
self.outBlock = ${outblock.get().name} <#else>
<#else> inBlock = None,
self.outBlock = None </#if>
</#if> <#if block.isPresent()>
<#if inblock.isPresent()> block = ${block.get().name},
self.inBlock = ${inblock.get().name} <#else>
<#else> block = None,
self.inBlock = None </#if>
</#if> model_shape = ${tc.getDefinedOutputDimension()})
<#if block.isPresent()>
self.block = ${block.get().name}
if self.block is AdaNetConfig.DEFAULT_BLOCK.value:
self.block = CoreAdaNet.DefaultBuildingBlock
<#else>
self.block = None
</#if>
self.model_shape = ${tc.getDefinedOutputDimension()}
</#if> </#if>
</#list> </#list>
<#else> <#else>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment