// mnist_mnistClassifier_net.h
#ifndef MNIST_MNISTCLASSIFIER_NET
#define MNIST_MNISTCLASSIFIER_NET
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
#include "CNNPredictor_mnist_mnistClassifier_net.h"
#include "CNNTranslator.h"
using namespace arma;
class mnist_mnistClassifier_net{
const int classes = 10;
public:
Sebastian N.'s avatar
Sebastian N. committed
13
CNNPredictor_mnist_mnistClassifier_net_0 _predictor_0_;
14 15 16 17 18 19 20 21
icube image;
colvec predictions;
void init()
{
image = icube(1, 28, 28);
predictions=colvec(classes);
}
void execute(){
22
    vector<float> CNN_predictions_(10);
23

Sebastian N.'s avatar
Sebastian N. committed
24
    _predictor_0_.predict(CNNTranslator::translate(image),
25
                CNN_predictions_);
26

27
    predictions = CNNTranslator::translateToCol(CNN_predictions_, std::vector<size_t> {10});
28 29 30 31 32

}

};
#endif