Commit e688f60b authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'dierkes' into 'master'

Dierkes

See merge request !39
parents 3aa54423 cab09138
Pipeline #224707 passed with stage
in 17 minutes and 19 seconds
# (c) https://github.com/MontiCore/monticore
stages:
- linux
#- linux
- windows
masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- apt-get update -q && apt-get install -y -q g++ libhdf5-serial-dev libhdf5-dev libopenblas-dev
- gcc --help
- mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
- master
#masterJobLinux:
# stage: linux
# image: maven:3-jdk-8
# script:
# - apt-get update -q && apt-get install -y -q g++ libhdf5-serial-dev libhdf5-dev libopenblas-dev
# - gcc --help
# - mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
# - cat target/site/jacoco/index.html
# - mvn package sonar:sonar -s settings.xml
# only:
# - master
masterJobWindows:
stage: windows
script:
- mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -U
- mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -U
only:
- master
tags:
- Windows10
BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- apt-get update -q && apt-get install -y -q g++ libhdf5-serial-dev libhdf5-dev libopenblas-dev
- mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- cat target/site/jacoco/index.html
except:
- master
#BranchJobLinux:
# stage: linux
# image: maven:3-jdk-8
# script:
# - apt-get update -q && apt-get install -y -q g++ libhdf5-serial-dev libhdf5-dev libopenblas-dev
# - mvn -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
# - cat target/site/jacoco/index.html
# except:
# - master
BranchJobWindows:
stage: windows
......
......@@ -15,8 +15,21 @@
* Example: `getCMakeConfig().addModuleDependency(new CMakeFindModule("LibName", "LibHeader.hpp", "libname", headerSearchPaths, bibrarySearchPaths, findHeaderEnabled, findLibEnabled, isRequiered));`
* Additionally any CMake command can be inserted via `getCMakeConfig().addCMakeCommand("CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-deprecated\"")`or at the end via `addCMakeCommandEnd("#some command at the end")`
### How to add a new Command to EMAM
* add your command java file in src /src/main/java/de/monticore/lang/generator/cpp/commands. You can
look at CeilMathCommand.java as an example for an command that already exists in Armadillo. For a
more complex command, you can look at scaleCube. CeilMathCommand.java will be find in adi-dev branch, after merging this branch into master the
ceil command will be available also on the master.
* register your command in MathCommandRegisterCPP.java
* add yourCommandTest.emam in EMAM2Cpp/src/test/resources/test/math
* add a cpp file for your command in EMAM2Cpp/src/test/resources/results/armadillo/testMath/l0. This
file will be compared to the generated one from yourCommandTest.emam.
* add your test in
EMAM2Cpp/src/test/java/de/monticore/lang/monticar/generator/cpp/armadillo/ArmadilloFunctionTest.java
### Note on find_package
If no search directory is specified CMake will search on default locations. For linux this is _/usr/lib_ , _usr/local/lib_ , _usr/include_ etc. Windows systems does not have a default library path. The generated CMake files also are using environment variables as hint. If a package could not be found but it is installed somewhere on the system please create an environment variable **PackageName_HOME**.
Here an example for Armadillo:
Create a environment variable called _Armadillo_Home_ with the path to the base directory of your Armadillo installation.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -15,6 +15,8 @@ public class MathCommandRegisterCPP extends MathCommandRegister {
protected void init() {
//registerMathCommand("size", "testo");
registerMathCommand(new MathScaleCubeCommand());
registerMathCommand(new MathJoinCubeDimCommand());
registerMathCommand(new MathAtan2Command());
registerMathCommand(new MathLog2Command());
registerMathCommand(new MathSizeCommand());
......
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math._symboltable.expression.MathValueType;
import de.monticore.lang.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.*;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
public class MathJoinCubeDimCommand extends MathCommand {
//todo
private static final String JOINER_SYNTAX_EXTENDED = "joinCubeDim( CUBE , CUBE , DIM )";
private static final String JOINER_METHOD_NAME = "joinCubeDim";
private static int scalerCommandCounter = 0;
public MathJoinCubeDimCommand() {
setMathCommandName("joinCubeDim");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
Log.error("Not implemented for Octave Backend");
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint;
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 3) {
MathMatrixAccessSymbol cube1 = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol cube2 = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
convertExtendedJoinerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, bluePrintCPP);
} else {
//todo
Log.error(String.format("No implementation found for joinCubeDim operation: \"joinCubeDim(%s)\".", mathExpressionSymbol.getTextualRepresentation(), JOINER_SYNTAX_EXTENDED));
}
}
/**
* Implements a scaleCube function with syntax "scaleCube( CUBE , AXIS , NEW_X , NEW_Y )"
*
*/
private void convertExtendedJoinerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, BluePrint bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method
Method calcJoinerMethod = getJoinerCalculationMethod(bluePrint);
// create code string
String code = calcJoinerMethod.getName() + valueString;
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr));
// add method to bluePrint
bluePrint.addMethod(calcJoinerMethod);
}
private Method getJoinerCalculationMethod(BluePrint bluePrint) {
// create new method
Method method = getNewEmptyScalerCalculationMethod();
// parameters
Variable c1 = new Variable();
c1.setName("c1");
c1.setVariableType(new VariableType("Cube", "cube", ""));
Variable c2 = new Variable();
c2.setName("c2");
c2.setVariableType(new VariableType("Cube", "cube", ""));
Variable dim = new Variable();
dim.setName("dim");
dim.setVariableType(new VariableType("Integer", "int", ""));
method.addParameter(c1);
method.addParameter(c2);
method.addParameter(dim);
// add instructions
method.addInstruction(methodBody());
return method;
}
private Method getNewEmptyScalerCalculationMethod() {
scalerCommandCounter++;
Method method = new Method();
method.setName(JOINER_METHOD_NAME + scalerCommandCounter);
method.setReturnTypeName("cube");
return method;
}
private Instruction methodBody() {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return " if (dim == 0) {\n" +
" c1.insert_rows(c1.n_rows, c2);\n" +
" return c1;\n" +
" } else if(dim == 1) {\n" +
" c1.insert_cols(c1.n_cols, c2);\n" +
" return c1;\n" +
" }\n" +
"\n" +
" c1 = arma::join_slices(c1, c2);\n" +
" return c1;";
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
}
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math._symboltable.expression.MathValueType;
import de.monticore.lang.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.*;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class MathScaleCubeCommand extends MathCommand {
private static final String SCALER_SYNTAX_EXTENDED = "scaleCube( EXPRESSION , AXIS , NEW_X , NEW_Y )";
private static final String SCALER_METHOD_NAME = "scaleCube";
private static int scalerCommandCounter = 0;
public MathScaleCubeCommand() {
setMathCommandName("scaleCube");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
Log.error("Not implemented for Octave Backend");
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint;
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
MathMatrixAccessSymbol cube = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedScalerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, bluePrintCPP);
} else {
Log.error(String.format("No implementation found for scaleCube operation: \"scaleCube(%s)\".", mathExpressionSymbol.getTextualRepresentation(), SCALER_SYNTAX_EXTENDED));
}
}
/**
* Implements a scaleCube function with syntax "scaleCube( CUBE , AXIS , NEW_X , NEW_Y )"
*
*/
private void convertExtendedScalerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, BluePrint bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method
Method calcScalerMethod = getScalerCalculationMethod(bluePrint);
// create code string
String code = calcScalerMethod.getName() + valueString;
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr));
// add method to bluePrint
bluePrint.addMethod(calcScalerMethod);
}
private Method getScalerCalculationMethod(BluePrint bluePrint) {
// create new method
Method method = getNewEmptyScalerCalculationMethod();
// parameters
Variable img = new Variable();
img.setName("img");
img.setVariableType(new VariableType("Cube", "cube", ""));
Variable depth_axis = new Variable();
depth_axis.setName("depth_axis");
depth_axis.setVariableType(new VariableType("Integer", "int", ""));
Variable new_x = new Variable();
new_x.setName("new_x");
new_x.setVariableType(new VariableType("Integer", "int", ""));
Variable new_y = new Variable();
new_y.setName("new_y");
new_y.setVariableType(new VariableType("Integer", "int", ""));
method.addParameter(img);
method.addParameter(depth_axis);
method.addParameter(new_x);
method.addParameter(new_y);
// add instructions
method.addInstruction(methodBody());
return method;
}
private Method getNewEmptyScalerCalculationMethod() {
scalerCommandCounter++;
Method method = new Method();
method.setName(SCALER_METHOD_NAME + scalerCommandCounter);
method.setReturnTypeName("cube");
return method;
}
private Instruction methodBody() {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return " if (depth_axis == 0) { \n "+
" img = arma::reshape(img, img.n_cols, img.n_slices, img.n_rows);\n" +
" } else if(depth_axis == 1) {\n" +
" img = arma::reshape(img, img.n_rows, img.n_slices, img.n_cols);\n" +
" }\n" +
" \n" +
" arma::cube r_img = arma::cube(64,64, img.n_slices);\n" +
" for (int i = 0; i < img.n_slices; i++) \n" +
" {\n" +
" arma::mat cur_slice = img.slice(i);\n" +
" arma::vec X = arma::regspace(0, cur_slice.n_cols-1);\n" +
" arma::vec Y = arma::regspace(0, cur_slice.n_rows-1);\n" +
"\n" +
" float scale_x = (cur_slice.n_cols-1)/float((new_x));\n" +
" float scale_y = (cur_slice.n_rows-1)/float((new_y));\n" +
" arma::vec XI = arma::regspace(0, new_x-1) * scale_x;\n" +
" arma::vec YI = arma::regspace(0, new_y-1) * scale_y;\n" +
"\n" +
" arma::mat mat_out;\n" +
"\n" +
" arma::interp2(X, Y, cur_slice, XI, YI, mat_out);\n" +
" r_img.slice(i) = mat_out;\n" +
" }\n" +
"\n" +
" if (depth_axis == 0) {\n" +
" r_img = arma::reshape(r_img, r_img.n_slices, r_img.n_rows, r_img.n_cols);\n" +
" } else if (depth_axis == 1) {\n" +
" r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);\n" +
" }\n" +
" \n" +
" return r_img;\n";
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
}
......@@ -138,6 +138,16 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest {
testMathCommand("det");
}
@Test
public void scaleCubeCommand() throws IOException {
testMathCommand("scaleCube");
}
@Test
public void joinCubeDimCommand() throws IOException {
testMathCommand("joinCubeDim");
}
@Test
public void testDiagCommand() throws IOException {
testMathCommand("diag");
......
#ifndef TEST_MATH_JOINCUBEDIMCOMMANDTEST
#define TEST_MATH_JOINCUBEDIMCOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
using namespace arma;
class test_math_joinCubeDimCommandTest{
public:
cube img_in;
cube img_out;
void init()
{
img_in = cube(1, 28, 28);
img_out = cube(2, 28, 28);
}
cube joinCubeDim1(cube c1, cube c2, int dim)
{
if (dim == 0) {
c1.insert_rows(c1.n_rows, c2);
return c1;
} else if(dim == 1) {
c1.insert_cols(c1.n_cols, c2);
return c1;
}
c1 = arma::join_slices(c1, c2);
return c1;}
void execute()
{
img_out = joinCubeDim1(img_in, img_in, 0);
}
};
#endif
#ifndef TEST_MATH_SCALECUBECOMMANDTEST
#define TEST_MATH_SCALECUBECOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
using namespace arma;
class test_math_scaleCubeCommandTest{
public:
cube img_in;
cube img_out;
void init()
{
img_in = cube(1, 28, 28);
img_out = cube(1, 64, 64);
}
cube scaleCube1(cube img, int depth_axis, int new_x, int new_y)
{
if (depth_axis == 0) {
img = arma::reshape(img, img.n_cols, img.n_slices, img.n_rows);
} else if(depth_axis == 1) {
img = arma::reshape(img, img.n_rows, img.n_slices, img.n_cols);
}
arma::cube r_img = arma::cube(64,64, img.n_slices);
for (int i = 0; i < img.n_slices; i++)
{
arma::mat cur_slice = img.slice(i);
arma::vec X = arma::regspace(0, cur_slice.n_cols-1);
arma::vec Y = arma::regspace(0, cur_slice.n_rows-1);
float scale_x = (cur_slice.n_cols-1)/float((new_x));
float scale_y = (cur_slice.n_rows-1)/float((new_y));
arma::vec XI = arma::regspace(0, new_x-1) * scale_x;
arma::vec YI = arma::regspace(0, new_y-1) * scale_y;
arma::mat mat_out;
arma::interp2(X, Y, cur_slice, XI, YI, mat_out);
r_img.slice(i) = mat_out;
}
if (depth_axis == 0) {
r_img = arma::reshape(r_img, r_img.n_slices, r_img.n_rows, r_img.n_cols);
} else if (depth_axis == 1) {
r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);
}
return r_img;
}
void execute()
{
img_out = scaleCube1(img_in, 0, 64, 64);
}
};
#endif
/* (c) https://github.com/MontiCore/monticore */
package test.math;
component JoinCubeDimCommandTest{
ports in Q(-1:1)^{1,28,28} img_in,
out Q(-1:1)^{2,28,28} img_out;
implementation Math{
img_out = joinCubeDim(img_in, img_in, 0);
}
}
/* (c) https://github.com/MontiCore/monticore */
package test.math;
component ScaleCubeCommandTest{
ports in Q(-1:1)^{1,28,28} img_in,
out Q(-1:1)^{1,64,64} img_out;
implementation Math{
img_out = scaleCube(img_in, 0, 64, 64);
}
}
#ifndef TEST_MATH_SCALECUBECOMMANDTEST
#define TEST_MATH_SCALECUBECOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
using namespace arma;
class test_math_scaleCubeCommandTest{
public:
cube img_in;
cube img_out;
void init()
{
img_in = cube(1, 28, 28);
img_out = cube(1, 64, 64);
}
cube scaleCube1(cube img, int depth_axis, int new_x, int new_y)
{
if (depth_axis == 0) {
img = arma::reshape(img, img.n_cols, img.n_slices, img.n_rows);
} else if(depth_axis == 1) {
img = arma::reshape(img, img.n_rows, img.n_slices, img.n_cols);
}
arma::cube r_img = arma::cube(64,64, img.n_slices);
for (int i = 0; i < img.n_slices; i++)
{
arma::mat cur_slice = img.slice(i);
arma::vec X = arma::regspace(1, cur_slice.n_cols);
arma::vec Y = arma::regspace(1, cur_slice.n_rows);
float scale_x = cur_slice.n_cols/float((new_x));
float scale_y = cur_slice.n_rows/float((new_y));
arma::vec XI = arma::regspace(1, new_x) * scale_x;
arma::vec YI = arma::regspace(1, new_y) * scale_y;
arma::mat mat_out;
arma::interp2(X, Y, cur_slice, XI, YI, mat_out);
r_img.slice(i) = mat_out;
}
if (depth_axis == 0) {
r_img = arma::reshape(r_img, r_img.n_slices, r_img.n_rows, r_img.n_cols);
} else if (depth_axis == 1) {
r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);