Commit 7c76d83e authored by Eyüp Harputlu's avatar Eyüp Harputlu

added log cosh loss

parent 75d6c0dc
......@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.1-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -168,6 +168,9 @@ class ${tc.fileNameWithoutEnding}:
prediction = mx.symbol.log_softmax(prediction, axis=1)
loss_func = mx.symbol.mean(label * (mx.symbol.log(label) - prediction), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="kullback_leibler")
elif loss == 'log_cosh':
loss_func = mx.symbol.mean(mx.symbol.log(mx.symbol.cosh(prediction - label)), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="log_cosh")
else:
logging.error("Invalid loss parameter.")
......
......@@ -168,6 +168,9 @@ class CNNCreator_Alexnet:
prediction = mx.symbol.log_softmax(prediction, axis=1)
loss_func = mx.symbol.mean(label * (mx.symbol.log(label) - prediction), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="kullback_leibler")
elif loss == 'log_cosh':
loss_func = mx.symbol.mean(mx.symbol.log(mx.symbol.cosh(prediction - label)), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="log_cosh")
else:
logging.error("Invalid loss parameter.")
......
......@@ -168,6 +168,9 @@ class CNNCreator_CifarClassifierNetwork:
prediction = mx.symbol.log_softmax(prediction, axis=1)
loss_func = mx.symbol.mean(label * (mx.symbol.log(label) - prediction), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="kullback_leibler")
elif loss == 'log_cosh':
loss_func = mx.symbol.mean(mx.symbol.log(mx.symbol.cosh(prediction - label)), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="log_cosh")
else:
logging.error("Invalid loss parameter.")
......
......@@ -168,6 +168,9 @@ class CNNCreator_VGG16:
prediction = mx.symbol.log_softmax(prediction, axis=1)
loss_func = mx.symbol.mean(label * (mx.symbol.log(label) - prediction), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="kullback_leibler")
elif loss == 'log_cosh':
loss_func = mx.symbol.mean(mx.symbol.log(mx.symbol.cosh(prediction - label)), axis=0, exclude=True)
loss_func = mx.symbol.MakeLoss(loss_func, name="log_cosh")
else:
logging.error("Invalid loss parameter.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment