# CNNTrainer_fullConfig.py
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2

import numpy as np
import cv2
import logging
import CNNCreator_fullConfig

if __name__ == "__main__":
    # Train the fullConfig network, then reload the exported deploy model and
    # run a single forward pass on the test image "3.jpg".

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    # Also write log records to train.log (overwritten each run). With
    # delay=True the file is only opened when the first record is emitted.
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

    # Single source of truth for the execution device, used both for
    # training and for selecting the device of the deploy net below.
    context = 'gpu'

    fullConfig = CNNCreator_fullConfig.CNNCreator_fullConfig()
    fullConfig.train(
        num_epoch=5,
        batch_size=100,
        context=context,
        eval_metric='mse',
        opt_type='rmsprop',
        epsilon=1.0E-6,
        weight_decay=0.01,
        gamma=0.9,
        policy='step',
        base_learning_rate=0.001,
        stepsize=1000
    )

    print('\n********************************************')
    print("Loading Deploy model")

    if context == 'cpu':
        device_opts = core.DeviceOption(caffe2_pb2.CPU, 0)
        print("CPU mode selected")
    elif context == 'gpu':
        device_opts = core.DeviceOption(caffe2_pb2.CUDA, 0)
        print("GPU mode selected")
    else:
        # Without this, device_opts would be unbound and fail later with a
        # confusing NameError.
        raise ValueError(
            "Unsupported context {!r}; expected 'cpu' or 'gpu'".format(context))

    # The original code referenced an undefined name `LeNet`; the trained
    # creator instance is used instead.
    # NOTE(review): assumes CNNCreator_fullConfig provides load_net() and the
    # INIT_NET / PREDICT_NET path attributes — confirm against that module.
    fullConfig.load_net(fullConfig.INIT_NET, fullConfig.PREDICT_NET,
                        device_opts=device_opts)

    img = cv2.imread("3.jpg")                                   # Load test image
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise IOError("Could not read test image '3.jpg'")
    img = cv2.resize(img, (28, 28))                             # Resize to 28x28
    # cv2.imread returns channels in BGR order, so use the BGR conversion.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)                 # Convert to grayscale
    img = img.reshape((1, 1, 28, 28)).astype('float32')         # Reshape to (1,1,28,28)
    workspace.FeedBlob("data", img, device_option=device_opts)  # FeedBlob
    # NOTE(review): assumes load_net() created a net named 'deploy_net' with
    # input blob "data" and output blob "predictions" — confirm.
    workspace.RunNet('deploy_net', num_iter=1)                  # Forward

    print("\nInput: {}".format(img.shape))
    pred = workspace.FetchBlob("predictions")
    print("Output: {}".format(pred))
    print("Output class: {}".format(np.argmax(pred)))