Commit 97af19e9 authored by Christian Fuß's avatar Christian Fuß

added option for teacher forcing. Fixed SoftmaxCrossEntropyIgnoreIndices loss....

added option for teacher forcing. Fixed SoftmaxCrossEntropyIgnoreIndices loss. Added test for Show, attend and tell architecture
parent ae55fd73
Pipeline #207388 failed with stages
in 26 seconds
......@@ -190,6 +190,7 @@ class ${tc.fileNameWithoutEnding}:
context='gpu',
checkpoint_period=5,
save_attention_image=False,
use_teacher_forcing=False,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -242,12 +243,12 @@ class ${tc.fileNameWithoutEnding}:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
#if loss == 'softmax_cross_entropy':
# fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
# loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
ignore_indices = [2]
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
if loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
......
......@@ -37,6 +37,9 @@ if __name__ == "__main__":
<#if (config.normalize)??>
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.useTeacherForcing)??>
use_teacher_forcing='${config.useTeacherForcing?string("True","False")}',
</#if>
<#if (config.saveAttentionImage)??>
save_attention_image='${config.saveAttentionImage?string("True","False")}',
</#if>
......
......@@ -38,6 +38,8 @@
<#if tc.endsWithArgmax(networkInstruction.body)>
${outputName} = mx.nd.argmax(${outputName}, axis=1).expand_dims(1)
</#if>
if use_teacher_forcing == "True":
${outputName} = mx.nd.expand_dims(labels[${tc.getIndex(outputName, true)}], axis=1)
</#if>
</#list>
<#else>
......
......@@ -148,6 +148,14 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testShow_attend_tell() throws IOException, TemplateException {
    // Generator smoke test: run the CLI against the Show_attend_tell model
    // and require that the run produces no findings (warnings/errors).
    Log.getFindings().clear();
    final String[] cliArgs = new String[] {
            "-m", "src/test/resources/valid_tests",
            "-r", "Show_attend_tell",
            "-o", "./target/generated-sources-cnnarch/"
    };
    CNNArch2GluonCli.main(cliArgs);
    assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testFullCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
......
......@@ -189,6 +189,7 @@ class CNNSupervisedTrainer_Alexnet:
context='gpu',
checkpoint_period=5,
save_attention_image=False,
use_teacher_forcing=False,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -241,12 +242,12 @@ class CNNSupervisedTrainer_Alexnet:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
#if loss == 'softmax_cross_entropy':
# fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
# loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
ignore_indices = [2]
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
if loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
......
......@@ -189,6 +189,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
context='gpu',
checkpoint_period=5,
save_attention_image=False,
use_teacher_forcing=False,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -241,12 +242,12 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
#if loss == 'softmax_cross_entropy':
# fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
# loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
ignore_indices = [2]
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
if loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
......
......@@ -189,6 +189,7 @@ class CNNSupervisedTrainer_VGG16:
context='gpu',
checkpoint_period=5,
save_attention_image=False,
use_teacher_forcing=False,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
......@@ -241,12 +242,12 @@ class CNNSupervisedTrainer_VGG16:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
#if loss == 'softmax_cross_entropy':
# fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
# loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
ignore_indices = [2]
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
if loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
......
// Image-captioning architecture in the style of "Show, Attend and Tell"
// (CNN encoder + soft attention + LSTM decoder, greedily decoded).
// NOTE(review): comments here assume the usual MontiCore Java-style
// comment syntax ('//') is accepted by the CNNArch grammar — confirm.
architecture Show_attend_tell{
    // Input: RGB image, 3x224x224, pixel values 0..255.
    def input Z(0:255)^{3,224,224} images
    // Output: a caption of 25 token ids drawn from a 37759-word vocabulary.
    def output Z(0:37758)^{1} target[25]
    // Shared (stateful) layers reused across decoding steps.
    layer LSTM(units=512) decoder;
    layer FullyConnected(units = 256) features;
    layer FullyConnected(units = 1, flatten=false) attention;
    // Seed the sequence: token id 0 (presumably <start> — verify vocab) at position 0.
    0 -> target[0];
    // Encoder: two strided convolutions + global max pooling produce the
    // image feature vector stored in 'features'.
    images ->
    Convolution(kernel=(7,7), channels=64, stride=(7,7), padding="valid") ->
    Convolution(kernel=(4,4), channels=64, stride=(4,4), padding="valid") ->
    GlobalPooling(pool_type="max") ->
    features;
    // Decoder loop: unrolled up to 25 steps with greedy (argmax) search.
    timed <t> GreedySearch(max_length=25){
        (
            (
                // Attention scores: combine image features with the previous
                // decoder state, squash, and normalize to weights.
                (
                    features.output ->
                    FullyConnected(units=512, flatten=false)
                |
                    decoder.state[0] ->
                    FullyConnected(units=512, flatten=false)
                ) ->
                BroadcastAdd() ->
                Tanh() ->
                FullyConnected(units=1, flatten=false) ->
                Softmax(axis=0) ->
                Dropout(p=0.25) ->
                attention
            |
                features.output
            )->
            // Context vector: attention-weighted sum of the image features.
            BroadcastMultiply() ->
            ReduceSum(axis=0) ->
            ExpandDims(axis=0)
        |
            // Embed the previously produced token as the second decoder input.
            target[t-1] ->
            Embedding(output_dim=256)
        ) ->
        Concatenate(axis=1) ->
        decoder ->
        // Project to vocabulary size and pick the most likely token.
        FullyConnected(units=37758) ->
        Tanh() ->
        Dropout(p=0.25) ->
        Softmax() ->
        ArgMax() ->
        target[t]
    };
}
......@@ -8,3 +8,4 @@ Invariant data/Invariant
RNNencdec data/RNNencdec
RNNtest data/RNNtest
RNNsearch data/RNNsearch
Show_attend_tell data/Show_attend_tell
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment