Commit fb5b86ba authored by Christian Fuß

adjusted tests

parent 7c896581
Pipeline #220735 passed with stages in 7 minutes and 13 seconds
@@ -46,7 +46,6 @@ component Show_attend_tell{
         FullyConnected(units=37758) ->
         Tanh() ->
         Dropout(p=0.25) ->
-        Softmax() ->
         ArgMax() ->
         target[t]
     };
...
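Note on the Softmax() removal: softmax is strictly monotonic, so placing it in front of ArgMax() never changes which index is selected; the normalization is only needed inside the cross-entropy loss, which applies it on its own. A minimal sketch (illustration only, not code from this commit):

    import mxnet as mx

    scores = mx.nd.array([[2.0, -1.0, 0.5, 4.0]])   # raw network outputs
    probs = mx.nd.softmax(scores, axis=1)           # order-preserving rescaling

    # both pick index 3; a Softmax() layer is redundant in front of ArgMax()
    assert mx.nd.argmax(scores, axis=1).asscalar() == \
           mx.nd.argmax(probs, axis=1).asscalar()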
@@ -17,8 +17,8 @@ component Show_attend_tell_images_as_input{
         0 -> target[0];
         images ->
-        Convolution(kernel=(7,7), channels=128, stride=(7,7), padding="valid") ->
-        Convolution(kernel=(4,4), channels=128, stride=(4,4), padding="valid") ->
+        Convolution(kernel=(7,7), channels=128, stride=(7,7), padding="same") ->
+        Convolution(kernel=(4,4), channels=128, stride=(4,4), padding="same") ->
         Reshape(shape=(64, 128)) ->
         features;
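Note on the padding change: with kernel size equal to stride, "valid" and "same" yield the same output grid whenever the input extent is an exact multiple of the stride; "same" additionally pads inputs that are not exact multiples, so the 8x8 = 64 feature positions expected by Reshape(shape=(64, 128)) are preserved. A rough sketch of the output arithmetic, assuming a 224x224 input (the actual input size is not visible in this diff):

    import math

    def out_size(n, kernel, stride, padding):
        if padding == "same":                  # pad so that out = ceil(n / stride)
            return math.ceil(n / stride)
        return (n - kernel) // stride + 1      # "valid": no padding

    n = 224
    for k, s in [(7, 7), (4, 4)]:              # the two Convolution layers above
        print(k, out_size(n, k, s, "valid"), out_size(n, k, s, "same"))
        n = out_size(n, k, s, "same")
    # 224 -> 32 -> 8 under both modes here; 8 * 8 = 64 positions,
    # matching Reshape(shape=(64, 128)). The modes differ only for
    # inputs that are not exact multiples of the stride.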
@@ -52,7 +52,6 @@ component Show_attend_tell_images_as_input{
         FullyConnected(units=37758) ->
         Tanh() ->
         Dropout(p=0.25) ->
-        Softmax() ->
         ArgMax() ->
         target[t]
     };
...
@@ -54,7 +54,7 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
         loss = -(pred * label).sum(axis=self._axis, keepdims=True)
         # ignore some indices for loss, e.g. <pad> tokens in NLP applications
         for i in self._ignore_indices:
-            loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
+            loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
         return loss.mean(axis=self._batch_axis, exclude=True)

 @mx.metric.register
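Note on the loss change: the old mask silenced the loss at every position where the model merely predicted an ignored index; the new mask multiplies in mx.nd.equal(argmax(pred), label), so only positions where the ignored index is predicted and also matches the label (i.e. correctly predicted <pad> tokens) are excluded, while predicting <pad> over a real token is still penalized. A toy illustration (the values are made up):

    import mxnet as mx

    ignore_index = 0                           # e.g. the <pad> token id
    pred_idx = mx.nd.array([0, 0, 2])          # argmax of the predictions
    label    = mx.nd.array([0, 3, 2])          # ground-truth token ids

    old_mask = mx.nd.logical_not(mx.nd.equal(pred_idx, ignore_index))
    new_mask = mx.nd.logical_not(mx.nd.equal(pred_idx, ignore_index) *
                                 mx.nd.equal(pred_idx, label))
    print(old_mask.asnumpy())  # [0. 0. 1.] -> also silences the wrong prediction
    print(new_mask.asnumpy())  # [0. 1. 1.] -> keeps the loss where <pad> was wrong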
@@ -278,6 +278,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         tic = None
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
                 with autograd.record():
@@ -320,6 +321,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         tic = None
         train_test_iter.reset()
         metric = mx.metric.create(eval_metric, **eval_metric_params)
         for batch_i, batch in enumerate(train_test_iter):
...
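The two trainer hunks sit inside the standard Gluon training loop: the forward pass runs under autograd.record(), followed by backward() and an optimizer step. A condensed, self-contained sketch of that pattern (net, loss_fn and trainer are placeholder names, not identifiers from the generated trainer):

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'adam')

    data = mx.nd.random.uniform(shape=(8, 20))
    label = mx.nd.random.randint(0, 10, shape=(8,)).astype('float32')

    with autograd.record():                  # record the forward pass
        loss = loss_fn(net(data), label)
    loss.backward()                          # backpropagate
    trainer.step(batch_size=data.shape[0])   # update parameters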