Commit 23631e61 authored by Sebastian N.'s avatar Sebastian N.
Browse files

Fixed a bug

parent 3ce335f0
Pipeline #200243 failed with stages
in 18 seconds
......@@ -266,7 +266,7 @@ class ${tc.fileNameWithoutEnding}:
if not os.path.isdir(self._net_creator._model_dir_):
raise
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) if len(net.collect_params().values()) != 0 for network in self._networks.values()]
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
......
......@@ -23,7 +23,7 @@
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_")) && !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_state_"))>
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body)[0]}, ${tc.getStreamOutputNames(networkInstruction.body)[0]}label))
</#if>
<#list networkInstruction.body.elements as element>
......
......@@ -266,7 +266,7 @@ class CNNSupervisedTrainer_Alexnet:
if not os.path.isdir(self._net_creator._model_dir_):
raise
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) if len(net.collect_params().values()) != 0 for network in self._networks.values()]
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
......
......@@ -266,7 +266,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
if not os.path.isdir(self._net_creator._model_dir_):
raise
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) if len(net.collect_params().values()) != 0 for network in self._networks.values()]
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
......
......@@ -266,7 +266,7 @@ class CNNSupervisedTrainer_VGG16:
if not os.path.isdir(self._net_creator._model_dir_):
raise
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) if len(net.collect_params().values()) != 0 for network in self._networks.values()]
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment