Aufgrund einer Wartung wird GitLab am 28.09. zwischen 10:00 und 11:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 28.09. between 10:00 and 11:00 am.

CNNTranslator.h 4.15 KB
Newer Older
Bernhard Rumpe's avatar
BR-sy  
Bernhard Rumpe committed
1
/* (c) https://github.com/MontiCore/monticore */
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
#ifndef CNNTRANSLATOR_H
#define CNNTRANSLATOR_H
#include <armadillo>
#include <cassert>

using namespace std;
using namespace arma;

class CNNTranslator{
public:
    template<typename T> static void addColToSTDVector(const Col<T> &source, vector<float> &data){
        for(size_t i = 0; i < source.n_elem; i++){
            data.push_back((float) source(i));
        }
    }

    template<typename T> static void addRowToSTDVector(const subview_row<T> &source, vector<float> &data){
        for(size_t i = 0; i < source.n_elem; i++){
            data.push_back((float) source(i));
        }
    }

    template<typename T> static void addRowToSTDVector(const Row<T> &source, vector<float> &data){
        for(size_t i = 0; i < source.n_elem; i++){
            data.push_back((float) source(i));
        }
    }

    template<typename T> static void addMatToSTDVector(const Mat<T> &source, vector<float> &data){
        for(size_t i = 0; i < source.n_rows; i++){
            addRowToSTDVector(source.row(i), data);
        }
    }


    template<typename T> static vector<float> translate(const Col<T> &source){
        size_t size = source.n_elem;
        vector<float> data;
        data.reserve(size);
        addColToSTDVector(source, data);
        return data;
    }

    template<typename T> static vector<float> translate(const Row<T> &source){
        size_t size = source.n_elem;
        vector<float> data;
        data.reserve(size);
        addRowToSTDVector(source, data);
        return data;
    }

    template<typename T> static vector<float> translate(const Mat<T> &source){
        size_t size = source.n_elem;
        vector<float> data;
        data.reserve(size);
        addMatToSTDVector(source, data);
        return data;
    }

    template<typename T> static vector<float> translate(const Cube<T> &source){
        size_t size = source.n_elem;
        vector<float> data;
        data.reserve(size);
        for(size_t i = 0; i < source.n_slices; i++){
            addMatToSTDVector(source.slice(i), data);
        }
        return data;
    }

    static vec translateToCol(const vector<float> &source, const vector<size_t> &shape){
        assert(shape.size() == 1);
        vec column(shape[0]);
        for(size_t i = 0; i < source.size(); i++){
            column(i) = (double) source[i];
        }
        return column;
    }

    static mat translateToMat(const vector<float> &source, const vector<size_t> &shape){
        assert(shape.size() == 2);
        mat matrix(shape[1], shape[0]); //create transposed version of the matrix
        int startPos = 0;
        int endPos = matrix.n_rows;
        const vector<size_t> columnShape = {matrix.n_rows};
        for(size_t i = 0; i < matrix.n_cols; i++){
            vector<float> colSource(&source[startPos], &source[endPos]);
            matrix.col(i) = translateToCol(colSource, columnShape);
            startPos = endPos;
            endPos += matrix.n_rows;
        }
        return matrix.t();
    }

    static cube translateToCube(const vector<float> &source, const vector<size_t> &shape){
        assert(shape.size() == 3);
        cube cubeMatrix(shape[1], shape[2], shape[0]);
        const int matrixSize = shape[1] * shape[2];
        const vector<size_t> matrixShape = {shape[1], shape[2]};
        int startPos = 0;
        int endPos = matrixSize;
        for(size_t i = 0; i < cubeMatrix.n_slices; i++){
            vector<float> matrixSource(&source[startPos], &source[endPos]);
            cubeMatrix.slice(i) = translateToMat(matrixSource, matrixShape);
            startPos = endPos;
            endPos += matrixSize;
        }
        return cubeMatrix;
    }

    template<typename T> static vector<size_t> getShape(const Col<T> &source){
        return {source.n_elem};
    }

    template<typename T> static vector<size_t> getShape(const Row<T> &source){
        return {source.n_elem};
    }

    template<typename T> static vector<size_t> getShape(const Mat<T> &source){
        return {source.n_rows, source.n_cols};
    }

    template<typename T> static vector<size_t> getShape(const Cube<T> &source){
        return {source.n_slices, source.n_rows, source.n_cols};
    }
};

#endif