/* CifarNetwork.emadl */
package simpleCifar10;

/*
 * Small residual convolutional classifier (simpleCifar10 example).
 *
 * Generic parameter:
 *   classes - number of output classes; at least 2, default 10.
 * Ports:
 *   data    - integer tensor of shape {3, 32, 32} with values in 0..255
 *             (presumably channels x height x width - confirm against the data source).
 *   softmax - vector of `classes` values in [0, 1], the output of the final Softmax layer.
 */
component CifarNetwork<Z(2:oo) classes = 10>{
    ports in Z(0:255)^{3, 32, 32} data,
         out Q(0:1)^{classes} softmax;

    implementation CNN {
        /*
         * Square convolution (kernel x kernel) with `channels` output filters and
         * the same stride in both spatial directions, followed by a ReLU that is
         * only created when `act` is true (CNNArch `?` argument).
         * resLayer passes act=false for convolutions whose output goes straight
         * into Add(), so the activation is applied after the residual sum instead.
         */
        def conv(kernel, channels, stride=1, act=true){
            Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
            Relu(?=act)
        }

        /*
         * Residual block: two parallel branches (joined by `|`) whose outputs are
         * summed by Add() and then passed through a ReLU.
         *   main branch: 3x3 conv with `stride` (ReLU) -> 3x3 conv (no activation);
         *   skip branch: 1x1 conv with the same `stride`, created only when
         *                addSkipConv is true; otherwise the branch presumably
         *                forwards its input unchanged - confirm in the CNNArch docs.
         * The network below sets addSkipConv=true exactly where the channel count or
         * the stride changes, presumably so both branches reach matching shapes for Add().
         */
        def resLayer(channels, stride=1, addSkipConv=false){
            (
                conv(kernel=3, channels=channels, stride=stride) ->
                conv(kernel=3, channels=channels, act=false)
            |
                conv(kernel=1, channels=channels, stride=stride, act=false, ?=addSkipConv)
            ) ->
            Add() ->
            Relu()
        }

        data ->
        /* first block: 3 input channels -> 8 channels, no downsampling */
        resLayer(channels=8, addSkipConv=true) ->
        /* three stages: a stride-2 block that widens the channels, then the same block repeated twice in series (->=2) */
        resLayer(channels=16, stride=2, addSkipConv=true) ->
        resLayer(channels=16, ->=2) ->
        resLayer(channels=32, stride=2, addSkipConv=true) ->
        resLayer(channels=32, ->=2) ->
        resLayer(channels=64, stride=2, addSkipConv=true) ->
        resLayer(channels=64, ->=2) ->
        /* classifier head: average-pool each channel, then two fully connected layers */
        GlobalPooling(pool_type="avg") ->
        FullyConnected(units=128) ->
        Dropout() ->
        FullyConnected(units=classes) ->
        Softmax() ->
        softmax

    }
}