Skip to content
Snippets Groups Projects
Commit d6185b6e authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Merge

parents 039f6e64 d04e1188
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #200183 failed
......@@ -33,6 +33,10 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.EXPAND_DIMS_NAME);
supportedLayerList.add(AllPredefinedLayers.SQUEEZE_NAME);
supportedLayerList.add(AllPredefinedLayers.SWAPAXES_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_MULTIPLY_NAME);
supportedLayerList.add(AllPredefinedLayers.REDUCE_SUM_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
}
}
......@@ -40,6 +40,16 @@ class NoNormalization(gluon.HybridBlock):
return x
class Reshape(gluon.HybridBlock):
def __init__(self, shape, **kwargs):
super(Reshape, self).__init__(**kwargs)
with self.name_scope():
self.shape = shape
def hybrid_forward(self, F, x):
return F.reshape(data=x, shape=self.shape)
class CustomRNN(gluon.HybridBlock):
def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
super(CustomRNN, self).__init__(**kwargs)
......
......@@ -34,6 +34,30 @@ class LogCoshLoss(gluon.loss.Loss):
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._ignore_indices = ignore_indices
self._sparse_label = sparse_label
self._from_logits = from_logits
def hybrid_forward(self, F, pred, label, sample_weight=None):
log_softmax = F.log_softmax
pick = F.pick
if not self._from_logits:
pred = log_softmax(pred, self._axis)
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
#loss = _apply_weighting(F, loss, self._weight, sample_weight)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
return loss.mean(axis=self._batch_axis, exclude=True)
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -144,6 +168,8 @@ class BLEU(mx.metric.EvalMetric):
return new_list
class ${tc.fileNameWithoutEnding}:
def applyBeamSearch(input, length, width, maxLength, currProb, netIndex, bestOutput):
bestProb = 0.0
......@@ -336,12 +362,43 @@ class ${tc.fileNameWithoutEnding}:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(output_name).size > 1:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
'''
#Compute BLEU and NIST Score if data folder contains a dictionary -> NLP dataset
if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
dict = pickle.load(f)
import nltk.translate.bleu_score
import nltk.translate.nist_score
prediction = []
for index in range(batch_size):
sentence = ''
for entry in predictions:
sentence += dict[int(entry[index].asscalar())] + ' '
prediction.append(sentence)
for index in range(batch_size):
sentence = ''
for batchEntry in batch.label:
sentence += dict[int(batchEntry[index].asscalar())] + ' '
print("############################")
print("label1: ", sentence)
print("prediction1: ", prediction[index])
BLEUscore = nltk.translate.bleu_score.sentence_bleu([sentence], prediction[index])
NISTscore = nltk.translate.nist_score.sentence_nist([sentence], prediction[index])
print("BLEU: ", BLEUscore)
print("NIST: ", NISTscore)
print("############################")
'''
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
......@@ -366,7 +423,7 @@ class ${tc.fileNameWithoutEnding}:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(output_name).size > 1:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
......
<#if mode == "FORWARD_FUNCTION">
${element.name} = F.broadcast_add(${tc.join(element.inputs, ",")})
<#elseif mode == "PYTHON_INLINE">
self.${element.name} = mx.nd.broadcast_add(${tc.join(element.inputs, ",")})
</#if>
\ No newline at end of file
<#if mode == "FORWARD_FUNCTION">
${element.name} = F.broadcast_mul(${tc.join(element.inputs, ", ")})
</#if>
\ No newline at end of file
<#if mode == "FORWARD_FUNCTION">
${element.name} = ${tc.join(element.inputs, " * ")}
<#elseif mode == "PYTHON_INLINE">
${element.name} = ${tc.join(element.inputs, " * ")}
<#elseif mode == "CPP_INLINE">
vector<float> ${element.name}(${element.inputs[0]}.size());
for (size_t i = 0; i != ${element.name}.size(); ++i) {
${element.name}[i] = ${tc.join(element.inputs, " * ", "", "[i]")};
}
</#if>
\ No newline at end of file
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = Reshape(shape=(${tc.join(element.shape, ",")}))
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
<#elseif mode == "PYTHON_INLINE">
self.${element.name} = Reshape(shape=${shape})
</#if>
\ No newline at end of file
<#-- This template is not used if the following architecture element is an output. See Output.ftl -->
<#assign axis = element.axis?c>
<#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION">
${element.name} = F.softmax(${input})
${element.name} = F.softmax(${input}, axis=${axis})
</#if>
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment