"""Training launcher for the ``cifar10_cifar10Classifier_net`` network.

Run as a script. ``logging.basicConfig`` sets up console logging at DEBUG
level on the root logger. A second handler additionally writes every record
to ``train.log`` in the current working directory. That file is opened in
mode ``"w"``, so it is overwritten on each run.
"""
import logging
import mxnet as mx
import CNNCreator_cifar10_cifar10Classifier_net


def main():
    """Configure logging, build the network creator and start training."""
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()

    # ``delay`` is a bool: the file is not opened until the first record is
    # emitted.
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

    cifar10_cifar10Classifier_net = (
        CNNCreator_cifar10_cifar10Classifier_net
        .CNNCreator_cifar10_cifar10Classifier_net()
    )
    # The hyper-parameters below are interpreted by the creator's train()
    # method, which is defined in the CNNCreator module and not visible here.
    # NOTE(review): 'learning_rate_decay' / 'step_size' presumably describe a
    # step learning-rate schedule; confirm against CNNCreator's train().
    cifar10_cifar10Classifier_net.train(
        batch_size=5,
        num_epoch=10,
        load_checkpoint=False,
        context='cpu',
        normalize=True,
        optimizer='adam',
        optimizer_params={
            'weight_decay': 1.0E-4,
            'learning_rate': 0.01,
            'learning_rate_decay': 0.8,
            'step_size': 1000,
        },
    )


if __name__ == "__main__":
    main()