<#-- CNNTrainer.ftl -->
import logging
import mxnet as mx
import supervised_trainer
<#-- Import the per-instance creator and data-loader modules for every configuration. -->
<#list configurations as config>
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
</#list>

<#-- Script entry point: log at DEBUG level to the console and to train.log, then train one model per configuration. -->
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    <#-- delay must be a real bool; the former string "true" only worked because any non-empty string is truthy. -->
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

<#list configurations as config>
    ${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.${config.instanceName}DataLoader()
    ${config.instanceName}_trainer = supervised_trainer.CNNSupervisedTrainer(${config.instanceName}_loader,
        ${config.instanceName}_creator)

    <#-- Only options present in the configuration are emitted; each ends with a comma, which Python accepts before ")". -->
    <#-- ?c prints numbers in computer format so locale grouping (e.g. "1,000") cannot leak into the Python source. -->
    <#-- NOTE(review): assumes batchSize/numEpoch are numeric in the data model — confirm. -->
    ${config.instanceName}_trainer.train(
<#if (config.batchSize)??>
        batch_size=${config.batchSize?c},
</#if>
<#if (config.numEpoch)??>
        num_epoch=${config.numEpoch?c},
</#if>
<#if (config.loadCheckpoint)??>
        load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
        context='${config.context}',
</#if>
<#if (config.normalize)??>
        normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
        eval_metric='${config.evalMetric}',
</#if>
<#-- NOTE(review): guard tests config.configuration.optimizer but emits config.optimizerName — confirm both are always set together. -->
<#if (config.configuration.optimizer)??>
        optimizer='${config.optimizerName}',
        optimizer_params={
<#list config.optimizerParams?keys as param>
            '${param}': ${config.optimizerParams[param]}<#sep>,</#sep>
</#list>
        }
</#if>
    )
</#list>