<#-- CNNTrainer.ftl: template for the Python script that builds and trains every configured CNN instance. -->
import logging
import mxnet as mx
<#-- For every entry in `configurations`, import its creator, data loader and supervised trainer modules. -->
<#list configurations as config>
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
import CNNSupervisedTrainer_${config.instanceName}
</#list>

if __name__ == "__main__":
<#-- Root logger: DEBUG level via basicConfig, plus a file handler that writes to train.log in "w" mode. -->
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
<#-- delay is a boolean flag; True defers opening train.log until the first record is emitted. -->
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

<#-- One creator / loader / trainer triple is generated per entry in `configurations`. -->
<#list configurations as config>
    ${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
    ${config.instanceName}_trainer = CNNSupervisedTrainer_${config.instanceName}.CNNSupervisedTrainer_${config.instanceName}(
        ${config.instanceName}_loader,
        ${config.instanceName}_creator
    )

<#-- Each keyword argument below is emitted only when the corresponding configuration value exists. -->
    ${config.instanceName}_trainer.train(
<#if (config.batchSize)??>
        batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
        num_epoch=${config.numEpoch},
</#if>
<#-- Booleans are rendered as the Python literals True / False. -->
<#if (config.loadCheckpoint)??>
        load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
        context='${config.context}',
</#if>
<#if (config.normalize)??>
        normalize=${config.normalize?string("True","False")},
</#if>
<#-- Evaluation metric name plus an optional 'exclude' list rendered as a Python list literal. -->
<#if (config.evalMetric)??>
        eval_metric='${config.evalMetric.metric}',
        eval_metric_params={
<#if (config.evalMetric.exclude)??>
            'exclude': [<#list config.evalMetric.exclude as value>${value}<#sep>, </#list>],
</#if>
        },
</#if>
<#-- Loss name and, when present, its parameters as a dict literal. -->
<#-- Note: the line break after the comma belongs to <#sep>, so it is not emitted after the last entry; -->
<#-- the closing brace is therefore kept at column 0 and lands directly behind the last entry. -->
<#if (config.configuration.loss)??>
        loss='${config.lossName}',
<#if (config.lossParams)??>
        loss_params={
<#list config.lossParams?keys as param>
            '${param}': ${config.lossParams[param]}<#sep>,
</#list>
},
</#if>
</#if>
<#-- Optimizer name and parameters; same <#sep> layout as loss_params. Last argument, so no trailing comma. -->
<#if (config.configuration.optimizer)??>
        optimizer='${config.optimizerName}',
        optimizer_params={
<#list config.optimizerParams?keys as param>
            '${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
}
</#if>
    )
</#list>