
Commit b0d04271 authored by Christian Fuß

adjusted Show, Attend and Tell model


Former-commit-id: a8440765
parent 7917840b
@@ -3,7 +3,7 @@ package showAttendTell;
 component Main{
     ports in Z(0:255)^{3, 224, 224} images,
         in Z(-oo:oo)^{64,2048} data,
-        out Z(0:25316)^{1} target[25];
+        out Z(0:37758)^{1} target[25];
     instance Show_attend_tell net;
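The only functional change in Main is the upper bound of the target range (25316 to 37758); presumably the caption vocabulary grew after retokenizing the dataset, so the output indices need a larger domain.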
@@ -4,7 +4,6 @@ configuration Show_attend_tell{
     context:cpu
     eval_metric:bleu
     loss:softmax_cross_entropy_ignore_indices{
-        sparse_label:true
        ignore_indices:2
    }
     use_teacher_forcing:true
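ignore_indices:2 excludes positions whose token has index 2 (presumably the <pad> token, judging by the comment in the loss implementation further down) from the cross-entropy loss.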
@@ -9,7 +9,7 @@ component Show_attend_tell{
     layer LSTM(units=512) decoder;
-    layer FullyConnected(units = 256) features;
+    layer FullyConnected(units = 256, flatten=false) features;
     layer FullyConnected(units = 1, flatten=false) attention;
     0 -> target[0];
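flatten=false on the features layer keeps the 64 annotation vectors separate (a 64 x 256 tensor) rather than flattening them into a single vector, which the per-location attention weighting in the next hunk relies on.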
@@ -30,7 +30,6 @@ component Show_attend_tell{
     Tanh() ->
     FullyConnected(units=1, flatten=false) ->
     Softmax(axis=0) ->
-    Dropout(p=0.25) ->
     attention
     |
     features.output
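For orientation, here is a minimal Gluon sketch of the soft attention this stream describes: each annotation vector is combined with the decoder state, scored through Tanh and a 1-unit FullyConnected layer, and normalized with a Softmax over the location axis (axis=0). This is an illustration under the shapes assumed in the comments, not the generated code.

import mxnet as mx
from mxnet.gluon import nn

class SoftAttention(nn.Block):
    # Illustrative soft attention: tanh -> 1-unit dense -> softmax over locations.
    def __init__(self, **kwargs):
        super(SoftAttention, self).__init__(**kwargs)
        self.score = nn.Dense(1, flatten=False)  # one scalar score per location

    def forward(self, features, hidden):
        # features: (64, 256) annotation vectors; hidden: (256,) decoder state
        joined = mx.nd.tanh(mx.nd.broadcast_add(features, hidden.reshape((1, -1))))
        alpha = mx.nd.softmax(self.score(joined), axis=0)   # weights over the 64 locations
        return mx.nd.broadcast_mul(alpha, features).sum(axis=0)  # weighted context vector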
@@ -11,15 +11,15 @@ component Show_attend_tell_images_as_input{
     layer LSTM(units=512) decoder;
-    layer FullyConnected(units = 256) features;
+    layer FullyConnected(units = 256, flatten=false) features;
     layer FullyConnected(units = 1, flatten=false) attention;
     0 -> target[0];
     images ->
-    Convolution(kernel=(7,7), channels=64, stride=(7,7), padding="valid") ->
+    Convolution(kernel=(4,4), channels=64, stride=(4,4), padding="valid") ->
     GlobalPooling(pool_type="max") ->
-    Convolution(kernel=(7,7), channels=128, stride=(7,7), padding="valid") ->
+    Convolution(kernel=(4,4), channels=128, stride=(4,4), padding="valid") ->
     Reshape(shape=(64, 128)) ->
     features;
     timed <t> GreedySearch(max_length=25){
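The encoder convolutions switch from 7x7/stride-7 to 4x4/stride-4 kernels. With "valid" padding the output side length is floor((in - kernel)/stride) + 1, so on the 224 x 224 inputs the first layer now produces 56 x 56 maps instead of 32 x 32. A small sanity check (plain Python, an illustration rather than part of the generated code):

def conv_out(size, kernel, stride):
    # output side length of a "valid" (unpadded) convolution
    return (size - kernel) // stride + 1

assert conv_out(224, 7, 7) == 32   # old first layer
assert conv_out(224, 4, 4) == 56   # new first layer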
@@ -36,7 +36,6 @@ component Show_attend_tell_images_as_input{
     Tanh() ->
     FullyConnected(units=1, flatten=false) ->
     Softmax(axis=0) ->
-    Dropout(p=0.25) ->
     attention
     |
     features.output
@@ -52,7 +52,6 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
         else:
             label = _reshape_like(F, label, pred)
         loss = -(pred * label).sum(axis=self._axis, keepdims=True)
-        #loss = _apply_weighting(F, loss, self._weight, sample_weight)
         # ignore some indices for loss, e.g. <pad> tokens in NLP applications
         for i in self._ignore_indices:
             loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
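The class zeroes the loss at ignored positions by multiplying with a 0/1 mask; note the committed code builds the mask from argmax(pred), i.e. from the predicted index. A minimal self-contained sketch of the same idea, masking on the label instead (an assumption about the intent, not the committed behaviour):

import mxnet as mx

def ce_ignore_index(pred, label, ignore_index=2):
    # pred: (batch, classes) raw scores; label: (batch,) integer class indices
    log_prob = mx.nd.log_softmax(pred, axis=1)
    loss = -mx.nd.pick(log_prob, label, axis=1)  # per-sample cross-entropy
    mask = label != ignore_index                 # 0 wherever the label is ignored
    return loss * mask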
@@ -246,7 +245,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         if loss == 'softmax_cross_entropy':
             fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
             loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
-        if loss == 'softmax_cross_entropy_ignore_indices':
+        elif loss == 'softmax_cross_entropy_ignore_indices':
             fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
             loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
         elif loss == 'sigmoid_binary_cross_entropy':
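Turning the second branch into an elif makes the loss dispatch a single if/elif chain, consistent with the sigmoid_binary_cross_entropy branch that follows.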
@@ -324,7 +323,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         train_test_iter.reset()
         metric = mx.metric.create(eval_metric, **eval_metric_params)
         for batch_i, batch in enumerate(train_test_iter):
             if True:
                 labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
                 image_ = batch.data[0].as_in_context(mx_context)
@@ -363,7 +362,12 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                     attention = mx.nd.squeeze(attention)
                     attention_resized = np.resize(attention.asnumpy(), (8, 8))
                     ax = fig.add_subplot(max_length//3, max_length//4, l+2)
-                    if dict[int(labels[l+1][0].asscalar())] == "<end>":
+                    if int(labels[l+1][0].asscalar()) > len(dict):
+                        ax.set_title("<unk>")
+                        img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                        ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                        break
+                    elif dict[int(labels[l+1][0].asscalar())] == "<end>":
                         ax.set_title(".")
                         img = ax.imshow(train_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
                         ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
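The new branch falls back to an <unk> title whenever the label index lies outside the dictionary. The same guard can be written as a small lookup helper (a hypothetical helper for illustration, not part of the generated trainer):

def index_to_token(idx, vocab):
    # vocab: dict mapping integer ids to token strings
    return vocab.get(int(idx), "<unk>")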
@@ -394,7 +398,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         test_iter.reset()
         metric = mx.metric.create(eval_metric, **eval_metric_params)
         for batch_i, batch in enumerate(test_iter):
             if True:
                 labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
                 image_ = batch.data[0].as_in_context(mx_context)
@@ -426,13 +430,18 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
                     attention = mx.nd.squeeze(attention)
                     attention_resized = np.resize(attention.asnumpy(), (8, 8))
                     ax = fig.add_subplot(max_length//3, max_length//4, l+2)
-                    if dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())] == "<end>":
+                    if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
+                        ax.set_title("<unk>")
+                        img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
+                        ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
+                        break
+                    elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                         ax.set_title(".")
                         img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
                         ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                         break
                     else:
-                        ax.set_title(dict[int(mx.nd.slice_axis(mx.nd.argmax(outputs[l+1], axis=1), axis=0, begin=0, end=1).asscalar())])
+                        ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
                         img = ax.imshow(test_images[0+test_batch_size*(batch_i)].transpose(1,2,0))
                         ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
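For reference, the overlay pattern repeated in both loops boils down to drawing the image first and then the resized 8 x 8 attention grid on top with partial transparency. A condensed sketch (matplotlib/numpy, assuming an image in HWC layout; an illustration, not the generated code):

import numpy as np
import matplotlib.pyplot as plt

def show_attention(ax, image_hwc, attention, title):
    ax.set_title(title)
    img = ax.imshow(image_hwc)               # base image
    ax.imshow(np.resize(attention, (8, 8)),  # coarse attention weights
              cmap='gray', alpha=0.6, extent=img.get_extent())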