Skip to content
Snippets Groups Projects
Commit 9f516b40 authored by Christian Fuß's avatar Christian Fuß
Browse files

adjusted network loss to be computed before applying ArgMax layer

parent dec0229d
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #181931 failed
......@@ -140,16 +140,16 @@ class ${tc.fileNameWithoutEnding}:
</#list>
with autograd.record():
<#include "pythonExecute.ftl">
<#include "pythonExecuteArgmax.ftl">
loss = \
<#list tc.architectureOutputs as output_name>
loss_function(${output_name}, ${output_name}label)<#sep> + \
</#list>
loss = 0
for element in lossList:
loss = loss + element
loss.backward()
for trainer in trainers:
trainer.step(batch_size)
......
<#list tc.getLayerVariableMembers("batch_size")?keys as member>
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
</#list>
<#list tc.architectureOutputSymbols as output>
${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#list>
lossList = []
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#list resolvedBody.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1)
</#if>
</#list>
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#list networkInstruction.body.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1)
</#if>
</#list>
<#else>
${tc.include(networkInstruction.body, "PYTHON_INLINE")}
</#if>
</#if>
</#list>
\ No newline at end of file
......@@ -138,13 +138,18 @@ class CNNSupervisedTrainer_Alexnet:
with autograd.record():
predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)
lossList = []
predictions_ = self._networks[0](data_)
lossList.append(loss_function(predictions_, predictions_label))
loss = 0
for element in lossList:
loss = loss + element
loss = \
loss_function(predictions_, predictions_label)
loss.backward()
for trainer in trainers:
trainer.step(batch_size)
......
......@@ -138,13 +138,18 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
with autograd.record():
softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)
lossList = []
softmax_ = self._networks[0](data_)
lossList.append(loss_function(softmax_, softmax_label))
loss = 0
for element in lossList:
loss = loss + element
loss = \
loss_function(softmax_, softmax_label)
loss.backward()
for trainer in trainers:
trainer.step(batch_size)
......
......@@ -138,13 +138,18 @@ class CNNSupervisedTrainer_VGG16:
with autograd.record():
predictions_ = mx.nd.zeros((batch_size, 1000,), ctx=mx_context)
lossList = []
predictions_ = self._networks[0](data_)
lossList.append(loss_function(predictions_, predictions_label))
loss = 0
for element in lossList:
loss = loss + element
loss = \
loss_function(predictions_, predictions_label)
loss.backward()
for trainer in trainers:
trainer.step(batch_size)
......
0% — Loading, or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment