Commit cb256538 authored by lr119628's avatar lr119628
Browse files

[update] test pass fix, CNNSupervisedTrainer.ftl

parent 3dce2377
......@@ -555,9 +555,12 @@ class ${tc.fileNameWithoutEnding}:
def __init__(self, data_loader, net_constructor):
self._data_loader = data_loader
self._net_creator = net_constructor
<#if tc.containsAdaNet()>
self._dataClass = net_constructor.dataClass
self._networks = {}
self.AdaNet = ${tc.containsAdaNet()?string('True','False')}
</#if>
self._networks = {}
def train(self, batch_size=64,
num_epoch=10,
......@@ -687,8 +690,9 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Invalid loss parameter.")
loss_function.hybridize()
<#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()>
<#list tc.architecture.networkInstructions as networkInstruction>
assert self._networks[${networkInstruction?index}].AdaNet, "passed model is not an AdaNet model"
self._networks[${networkInstruction?index}] = fit(model= self._networks[${networkInstruction?index}],
loss=loss_function,
......@@ -706,8 +710,8 @@ class ${tc.fileNameWithoutEnding}:
)
logging.info(self._networks[0])
#put here the AdaNet logic
</#if>
</#list>
</#if>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.episodicSubNetworks?has_content>
<#assign episodicReplayVisited = true>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment