<#-- (c) https://github.com/MontiCore/monticore -->
<#--
  Training driver template.

  For every element of `configurations` this emits:
    1. imports of that instance's generated CNNCreator_, CNNDataLoader_ and
       CNNSupervisedTrainer_ modules, and
    2. a section of __main__ that builds the creator and the data loader,
       hands both to the supervised trainer, and calls train() with only the
       hyper-parameters the configuration actually defines (every keyword
       argument is guarded by a `??` existence check, so unset ones fall back
       to the trainer's own defaults).

  Numeric hyper-parameters are printed with ?c ("computer" format) so that
  FreeMarker's locale-dependent number formatting (e.g. 1000 -> "1,000")
  cannot leak into the generated Python, where the grouping comma would split
  the keyword argument. NOTE(review): this assumes batchSize, numEpoch,
  checkpointPeriod and logPeriod are numbers in the data model - confirm
  against the configuration data class.
-->
import logging
import mxnet as mx
<#list configurations as config>
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
import CNNSupervisedTrainer_${config.instanceName}
</#list>

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
<#-- delay=True (a real bool, not the merely-truthy string "true") defers opening train.log until the first record is emitted. -->
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

<#list configurations as config>
    ${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
    ${config.instanceName}_trainer = CNNSupervisedTrainer_${config.instanceName}.CNNSupervisedTrainer_${config.instanceName}(
        ${config.instanceName}_loader,
        ${config.instanceName}_creator
    )

    ${config.instanceName}_trainer.train(
<#if (config.batchSize)??>
        batch_size=${config.batchSize?c},
</#if>
<#if (config.numEpoch)??>
        num_epoch=${config.numEpoch?c},
</#if>
<#if (config.loadCheckpoint)??>
        load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.checkpointPeriod)??>
        checkpoint_period=${config.checkpointPeriod?c},
</#if>
<#if (config.logPeriod)??>
        log_period=${config.logPeriod?c},
</#if>
<#if (config.context)??>
        context='${config.context}',
</#if>
<#if (config.normalize)??>
        normalize=${config.normalize?string("True","False")},
</#if>
<#--
  NOTE(review): unlike load_checkpoint / normalize / eval_train, the next two
  flags are emitted as quoted strings ('True' / 'False'). Both strings are
  truthy in Python, so this is only correct if the trainer compares against
  the string value rather than testing truthiness - confirm in the
  CNNSupervisedTrainer template before changing the quoting.
-->
<#if (config.useTeacherForcing)??>
        use_teacher_forcing='${config.useTeacherForcing?string("True","False")}',
</#if>
<#if (config.saveAttentionImage)??>
        save_attention_image='${config.saveAttentionImage?string("True","False")}',
</#if>
<#if (config.evalMetric)??>
        eval_metric='${config.evalMetric.name}',
        eval_metric_params={
<#if (config.evalMetric.exclude)??>
            'exclude': [<#list config.evalMetric.exclude as value>${value}<#sep>, </#list>],
</#if>
        },
</#if>
<#if (config.evalTrain)??>
        eval_train=${config.evalTrain?string("True","False")},
</#if>
<#if (config.configuration.loss)??>
        loss='${config.lossName}',
<#if (config.lossParams)??>
        loss_params={
<#-- Explicit </#sep> keeps the line break outside the separator, so the last entry still ends its line. -->
<#list config.lossParams?keys as param>
            '${param}': ${config.lossParams[param]}<#sep>,</#sep>
</#list>
        },
</#if>
</#if>
<#if (config.configuration.optimizer)??>
        optimizer='${config.optimizerName}',
<#-- Guarded like loss_params: ?keys on a missing map would abort generation. -->
<#if (config.optimizerParams)??>
        optimizer_params={
<#list config.optimizerParams?keys as param>
            '${param}': ${config.optimizerParams[param]}<#sep>,</#sep>
</#list>
        },
</#if>
</#if>
    )
</#list>