<#-- (c) https://github.com/MontiCore/monticore -->
import logging
import mxnet as mx
<#-- Import the creator, data-loader and supervised-trainer modules of every configured network instance. -->
<#list configurations as config>
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
import CNNSupervisedTrainer_${config.instanceName}
</#list>

<#-- Entry point of the generated script: sets up DEBUG logging (also written to
     train.log), then, for every configuration in turn, builds a creator, a data
     loader and a supervised trainer and calls train(). Only hyperparameters that
     are present in the configuration are emitted as keyword arguments; absent
     ones are simply omitted from the call. -->
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    <#-- delay=True: FileHandler opens train.log lazily, on the first emitted record. -->
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

<#list configurations as config>
    ${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
    ${config.instanceName}_trainer = CNNSupervisedTrainer_${config.instanceName}.CNNSupervisedTrainer_${config.instanceName}(
        ${config.instanceName}_loader,
        ${config.instanceName}_creator
    )

    ${config.instanceName}_trainer.train(
<#-- Numbers are rendered with ?c (computer format): FreeMarker's default,
     locale-dependent number format may insert grouping separators
     (e.g. "1,000"), which would corrupt the generated keyword argument. -->
<#if (config.batchSize)??>
        batch_size=${config.batchSize?c},
</#if>
<#if (config.numEpoch)??>
        num_epoch=${config.numEpoch?c},
</#if>
<#if (config.loadCheckpoint)??>
        load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
        context='${config.context}',
</#if>
<#if (config.normalize)??>
        normalize=${config.normalize?string("True","False")},
</#if>
<#-- NOTE(review): the next two flags are emitted as the *strings* 'True'/'False'
     rather than Python bools. A non-empty string is always truthy, so this is only
     correct if the trainer compares against the string value - confirm against
     CNNSupervisedTrainer before changing them to bare booleans. -->
<#if (config.useTeacherForcing)??>
        use_teacher_forcing='${config.useTeacherForcing?string("True","False")}',
</#if>
<#if (config.saveAttentionImage)??>
        save_attention_image='${config.saveAttentionImage?string("True","False")}',
</#if>
<#if (config.evalMetric)??>
        eval_metric='${config.evalMetric.name}',
        eval_metric_params={
<#if (config.evalMetric.exclude)??>
            'exclude': [<#list config.evalMetric.exclude as value>${value}<#sep>, </#list>],
</#if>
        },
</#if>
<#if (config.configuration.loss)??>
        loss='${config.lossName}',
<#if (config.lossParams)??>
        loss_params={
<#list config.lossParams?keys as param>
            '${param}': ${config.lossParams[param]}<#sep>,
</#list>
        },
</#if>
</#if>
<#if (config.configuration.optimizer)??>
        optimizer='${config.optimizerName}',
        optimizer_params={
<#list config.optimizerParams?keys as param>
            '${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
        },
</#if>
    )
</#list>