"""Training entry point for the ``simpleConfig`` CNN.

When run as a script, this module builds the network creator, the data
loader and the supervised trainer for ``simpleConfig``, then starts training
with fixed hyperparameters (batch size 100, 50 epochs, cross-entropy loss,
Adam with learning rate 0.001).

Log records at DEBUG level and above go to the handler installed by
``logging.basicConfig`` and are also written to ``train.log``.
"""
import logging

# NOTE(review): ``mx`` is not referenced in this script. The import is kept
# as in the original; it may be relied on for side effects. Confirm before
# removing it.
import mxnet as mx

import CNNCreator_simpleConfig
import CNNDataLoader_simpleConfig
import CNNSupervisedTrainer_simpleConfig


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    # Also write every record to train.log. Mode "w" truncates the file on
    # each run. delay=True defers opening the file until the first record is
    # emitted. It must be a real bool: the original passed the string "true",
    # which only worked because any non-empty string is truthy.
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

    simpleConfig_creator = CNNCreator_simpleConfig.CNNCreator_simpleConfig()
    simpleConfig_loader = CNNDataLoader_simpleConfig.CNNDataLoader_simpleConfig()
    simpleConfig_trainer = CNNSupervisedTrainer_simpleConfig.CNNSupervisedTrainer_simpleConfig(
        simpleConfig_loader,
        simpleConfig_creator
    )

    simpleConfig_trainer.train(
        batch_size=100,
        num_epoch=50,
        loss='cross_entropy',
        optimizer='adam',
        optimizer_params={
            'learning_rate': 0.001}
    )