CifarNetwork.emadl 1.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package simpleCifar10;

// Residual-style convolutional classifier component.
//
// Component parameter:
//   classes - number of output classes; integer >= 2, defaults to 10.
// Ports:
//   data    - input with dimensions {3, 32, 32} and integer values in 0..255.
//   softmax - output vector of length `classes` with rational values in 0..1.
component CifarNetwork<Z(2:oo) classes = 10>{
    ports in Z(0:255)^{3, 32, 32} data,
         out Q(0:1)^{classes} softmax;

    implementation CNN {
        // Square convolution followed by an optional ReLU.
        //   kernel   - edge length of the square kernel
        //   channels - number of output channels
        //   stride   - stride applied along both spatial axes (default 1)
        //   act      - when false, the trailing Relu is not constructed
        // NOTE(review): the reference ResNet example in the CNNArch docs also
        // places BatchNorm() between Convolution and Relu; confirm whether it
        // was left out here on purpose before re-adding it.
        def conv(kernel, channels, stride=1, act=true){
            Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
            Relu(?=act)
        }

        // Residual block: two 3x3 convolutions in the main branch, merged by
        // Add() with a skip branch, then a ReLU on the sum.
        //   channels    - output channels of every convolution in the block
        //   stride      - stride of the first main-branch conv and of the skip conv
        //   addSkipConv - when true, the skip branch contains a 1x1 convolution;
        //                 when false, that convolution is not constructed
        //                 (presumably leaving an identity skip - confirm against
        //                 the CNNArch conditional-layer semantics)
        def resLayer(channels, stride=1, addSkipConv=false){
            (
                // Main branch; the second conv has no activation because the
                // ReLU is applied after the addition below.
                conv(kernel=3, channels=channels, stride=stride) ->
                conv(kernel=3, channels=channels, act=false)
            |
                // Skip branch; `? = addSkipConv` makes this layer conditional.
                conv(kernel=1, channels=channels, stride=stride, act=false, ? = addSkipConv)
            ) ->
            Add() ->
            Relu()
        }

        // Network body. Blocks called with stride=2 also request the skip
        // convolution; `->=2` repeats a block twice in series.
        data ->
        resLayer(channels=8, addSkipConv=true) ->
        resLayer(channels=16, stride=2, addSkipConv=true) ->
        resLayer(channels=16, ->=2) ->
        resLayer(channels=32, stride=2, addSkipConv=true) ->
        resLayer(channels=32, ->=2) ->
        resLayer(channels=64, stride=2, addSkipConv=true) ->
        resLayer(channels=64, ->=2) ->
        GlobalPooling(pool_type="avg") ->
        FullyConnected(units=128) ->
        Dropout()->
        FullyConnected(units=classes) ->
        Softmax() ->
        softmax

    }
}