<#-- CNNTrainer.ftl: FreeMarker template that renders the Python training script for the configured CNN instances. -->
import logging
import mxnet as mx
<#list configurations as config>
import CNNCreator_${config.instanceName}
</#list>

if __name__ == "__main__":
<#-- Root logger at DEBUG level via basicConfig, plus a handler writing to train.log (mode "w": overwritten on every run). -->
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
<#-- delay must be a bool (was the string "true"); with delay=True the log file is opened lazily on the first emitted record. -->
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

<#-- Emit one instantiate-and-train section per configuration. -->
<#-- Each train() keyword argument is emitted only when the corresponding config value is defined (?? check). -->
<#list configurations as config>
    ${config.instanceName} = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    ${config.instanceName}.train(
<#if (config.batchSize)??>
        batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
        num_epoch=${config.numEpoch},
</#if>
<#-- Booleans are rendered as Python literals True/False. -->
<#if (config.loadCheckpoint)??>
        load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
        context='${config.context}',
</#if>
<#if (config.normalize)??>
        normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
        eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.loss)??>
        loss='${config.lossName}',
<#if (config.lossParams)??>
        loss_params={
<#-- Explicit </#sep> keeps the line break outside the separator, so every entry ends its own line. -->
<#list config.lossParams?keys as param>
            '${param}': ${config.lossParams[param]}<#sep>,</#sep>
</#list>
        },
</#if>
</#if>
<#if (config.configuration.optimizer)??>
        optimizer='${config.optimizerName}',
<#-- Guarded like loss_params so an undefined optimizerParams does not abort rendering. -->
<#if (config.optimizerParams)??>
        optimizer_params={
<#list config.optimizerParams?keys as param>
            '${param}': ${config.optimizerParams[param]}<#sep>,</#sep>
</#list>
        },
</#if>
</#if>
    )
</#list>