Commit de10570c authored by Julian Dierkes

constraint losses can now be specified in cnnt

parent 6ff123d1
Pipeline #226090 failed with stages in 1 minute and 6 seconds
......@@ -173,6 +173,10 @@ public class GluonConfigurationData extends ConfigurationData {
    return getMultiParamMapEntry(CONSTRAINT_DISTRIBUTION, "name");
}

public Map<String, Map<String, Object>> getConstraintLosses() {
    return getMultiParamMapEntry(CONSTRAINT_LOSS, "name");
}
public Map<String, Object> getStrategy() {
    assert isReinforcementLearning(): "Strategy parameter is only available for reinforcement learning, " +
            "but was called in a non-reinforcement learning context";
......
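For context, getMultiParamMapEntry keys the outer map by the constraint name and stores the loss identifier under "name". In Python terms, the returned Map<String, Map<String, Object>> corresponds to something like the minimal sketch below; the constraint name 'discriminator' and the parameter values are assumptions for illustration, not part of the commit:

# hypothetical Python-side view of the map returned by getConstraintLosses()
constraint_losses = {
    'discriminator': {                                   # constraint name (assumed)
        'name': 'softmax_cross_entropy_ignore_indices',  # which loss to construct
        'ignore_indices': 2,                             # additional loss parameters
        'margin': 1.0,
    },
}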
......@@ -6,6 +6,54 @@ import os
import shutil
from mxnet import gluon, autograd, nd
class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        # pred is expected to already contain probabilities, not logits
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
    def __init__(self, axis=-1, ignore_indices=None, sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices if ignore_indices is not None else []
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for the loss, e.g. <pad> tokens in NLP applications:
        # mask out samples where the prediction hits an ignored index and matches the label
        for i in self._ignore_indices:
            pred_index = F.argmax(pred, axis=self._axis)
            mask = F.logical_not(F.broadcast_equal(pred_index, F.ones_like(pred_index) * i)
                                 * F.broadcast_equal(pred_index, label))
            loss = F.broadcast_mul(loss, F.expand_dims(mask, axis=self._axis))
        return loss.mean(axis=self._batch_axis, exclude=True)
# TODO: hardcoded plotting imports; should be made configurable
import matplotlib as mpl
from matplotlib import pyplot as plt
......
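A quick smoke test for the three loss classes above, as a minimal sketch assuming the classes are in scope; the shapes and label values are illustrative only:

from mxnet import nd

probs = nd.softmax(nd.random.normal(shape=(4, 3)))   # CrossEntropyLoss expects probabilities
logits = nd.random.normal(shape=(4, 3))              # the softmax variants expect raw scores
labels = nd.array([0, 2, 1, 2])

print(CrossEntropyLoss(sparse_label=True)(probs, labels))                        # shape (4,), one value per sample
print(LogCoshLoss()(nd.random.normal(shape=(4, 1)), nd.zeros((4, 1))))           # shape (4,)
print(SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=[2])(logits, labels))  # index 2 masked when predicted correctly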
......@@ -27,8 +27,30 @@
high=int(max)+1, shape=(batch_size,)+domain[3],
ctx=mx_context), dtype="float32")
if name[:-1] in constraint_losses:
    loss_dict = constraint_losses[name[:-1]]
    loss = loss_dict['name']
    margin = loss_dict.get('margin', 1.0)  # parsed but not used by the losses below
    sparseLabel = loss_dict.get('sparse_label', True)
    ignore_indices = [loss_dict['ignore_indices']] if 'ignore_indices' in loss_dict else []
    fromLogits = loss_dict.get('from_logits', False)

    if loss == 'softmax_cross_entropy':
        qnet_losses += [mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)]
    elif loss == 'softmax_cross_entropy_ignore_indices':
        qnet_losses += [SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)]
    elif loss == 'sigmoid_binary_cross_entropy':
        qnet_losses += [mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()]
    elif loss == 'cross_entropy':
        qnet_losses += [CrossEntropyLoss(sparse_label=sparseLabel)]
    elif loss == 'l2':
        qnet_losses += [mx.gluon.loss.L2Loss()]
    elif loss == 'l1':
        qnet_losses += [mx.gluon.loss.L1Loss()]
    elif loss == 'log_cosh':
        qnet_losses += [LogCoshLoss()]
    else:
        logging.error("Invalid loss parameter for constraint: " + name[:-1] + ".")
else:
    if domain[0] == float:
        qnet_losses += [mx.gluon.loss.L2Loss()]
......
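Design note: the if/elif chain above grows with every new loss. A table-driven variant would keep the name-to-constructor mapping declarative; the sketch below is an illustrative alternative under that assumption, and make_constraint_loss is a hypothetical helper name, not part of the commit:

import logging
import mxnet as mx

# hypothetical: map loss names to factory callables over the parsed parameter dict
def make_constraint_loss(loss_dict):
    factories = {
        'softmax_cross_entropy': lambda: mx.gluon.loss.SoftmaxCrossEntropyLoss(
            from_logits=loss_dict.get('from_logits', False),
            sparse_label=loss_dict.get('sparse_label', True)),
        'sigmoid_binary_cross_entropy': lambda: mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(),
        'l2': lambda: mx.gluon.loss.L2Loss(),
        'l1': lambda: mx.gluon.loss.L1Loss(),
    }
    if loss_dict['name'] not in factories:
        logging.error("Invalid loss parameter for constraint.")
        return None
    return factories[loss_dict['name']]()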
......@@ -97,6 +97,21 @@ if __name__ == "__main__":
</#list>
},
</#if>
<#if (config.constraintLosses)??>
<#assign map = (config.constraintLosses)>
        constraint_losses = {
<#list map?keys as nameKey>
            '${nameKey}': {
                'name': '${map[nameKey].name}',
<#list map[nameKey]?keys as param>
<#if (param != "name")>
                '${param}': ${map[nameKey][param]}<#sep>,
</#if>
</#list>
            },
</#list>
        },
</#if>
<#if (config.noiseDistribution)??>
noise_distribution = '${config.noiseDistribution.name}',
noise_distribution_params = {
......
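For illustration, given the hypothetical constraint-loss map shown earlier, the template block above would render roughly the following keyword argument into the generated trainer script (whitespace and parameter values approximate, assumed for this example):

constraint_losses = {
    'discriminator': {
        'name': 'softmax_cross_entropy_ignore_indices',
        'ignore_indices': 2,
        'margin': 1.0
    },
},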