Commit aeb44c0d authored by Julian Dierkes's avatar Julian Dierkes
Browse files

finished adding scaleCube Command

parent c1a68ce4
Pipeline #221224 failed with stages
in 2 minutes and 44 seconds
......@@ -41,25 +41,7 @@ public class MathScaleCubeCommand extends MathCommand {
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<>());
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fsum", "Double", valueListString, "FirstResult", false, 1), mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
// error if using extended syntax here
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
//todo
Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED));
}
Log.error("Not implemented for Octave Backend");
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
......@@ -79,68 +61,56 @@ public class MathScaleCubeCommand extends MathCommand {
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedScalerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, cube, axis, new_x, new_y, bluePrintCPP);
convertExtendedScalerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, bluePrintCPP);
} else {
//todo
Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\", \"sum(X,dim)\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED));
Log.error(String.format("No implementation found for scaleCube operation: \"scaleCube(%s)\".", mathExpressionSymbol.getTextualRepresentation(), SCALER_SYNTAX_EXTENDED));
}
}
/**
* Implements a sum function with syntax "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )"
* This syntax makes sum expressions easier to model.
* Implements a scaleCube function with syntax "scaleCube( CUBE , AXIS , NEW_X , NEW_Y )"
*
* @param mathMatrixNameExpressionSymbol symbol to convert
* @param cube expression from which the sum is calculates
* @param axis name of the sum variable
* @param new_x start value of the sum variable
* @param new_y end value of the sum variable
*/
private void convertExtendedScalerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x, MathMatrixAccessSymbol new_y, BluePrintCPP bluePrint) {
private void convertExtendedScalerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, BluePrint bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method
Method calcSumMethod = getScalerCalculationMethod(cube, axis, new_x, new_y, bluePrint);
Method calcScalerMethod = getScalerCalculationMethod(bluePrint);
// create code string
String code = calcSumMethod.getName() + valueString;
String code = calcScalerMethod.getName() + valueString;
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr));
// add method to bluePrint
bluePrint.addMethod(calcSumMethod);
bluePrint.addMethod(calcScalerMethod);
}
private Method getScalerCalculationMethod(MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x2, MathMatrixAccessSymbol new_y2, BluePrint bluePrint) {
private Method getScalerCalculationMethod(BluePrint bluePrint) {
// create new method
Method method = getNewEmptyScalerCalculationMethod();
// generate function code
String c = ExecuteMethodGenerator.generateExecuteCode(cube, new ArrayList<>());
String a = ExecuteMethodGenerator.generateExecuteCode(axis, new ArrayList<>());
String n_x = ExecuteMethodGenerator.generateExecuteCode(new_x2, new ArrayList<>());
String n_y = ExecuteMethodGenerator.generateExecuteCode(new_y2, new ArrayList<>());
// add loop var
// parameters
Variable img = new Variable();
img.setName("img");
img.setVariableType(new VariableType("Cube", "cube", ""));
Variable depth_axis = new Variable();
depth_axis.setName("depth_axis");
depth_axis.setVariableType(new VariableType("Integer", "int", ""));
Variable new_x = new Variable();
new_x.setName("new_x");
new_x.setVariableType(new VariableType("Integer", "int", ""));
Variable new_y = new Variable();
new_y.setName("new_y");
new_y.setVariableType(new VariableType("Integer", "int", ""));
Variable depth_axis = new Variable();
depth_axis.setName("depth_axis");
depth_axis.setVariableType(new VariableType("Integer", "int", ""));
method.addParameter(img);
method.addParameter(depth_axis);
method.addParameter(new_x);
method.addParameter(new_y);
method.addParameter(depth_axis);
// add instructions
// add instructions
method.addInstruction(methodBody());
method.addInstruction(ifClauses());
return method;
}
......@@ -152,7 +122,7 @@ public class MathScaleCubeCommand extends MathCommand {
return method;
}
private Instruction ifClauses() {
private Instruction methodBody() {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
......@@ -169,10 +139,10 @@ public class MathScaleCubeCommand extends MathCommand {
" arma::vec X = arma::regspace(1, cur_slice.n_cols);\n" +
" arma::vec Y = arma::regspace(1, cur_slice.n_rows);\n" +
"\n" +
" float scale_x = cur_slice.n_cols/new_x;\n" +
" float scale_y = cur_slice.n_rows/new_y;\n" +
" arma::vec XI = arma::regspace(1, new_x);\n" +
" arma::vec YI = arma::regspace(1, new_y);\n" +
" float scale_x = cur_slice.n_cols/float((new_x));\n" +
" float scale_y = cur_slice.n_rows/float((new_y));\n" +
" arma::vec XI = arma::regspace(1, new_x) * scale_x;\n" +
" arma::vec YI = arma::regspace(1, new_y) * scale_y;\n" +
"\n" +
" arma::mat mat_out;\n" +
"\n" +
......@@ -186,7 +156,7 @@ public class MathScaleCubeCommand extends MathCommand {
" r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);\n" +
" }\n" +
" \n" +
" return r_img;";
" return r_img;\n";
}
@Override
......
......@@ -26,15 +26,15 @@ import java.util.Optional;
*/
public class MathSumCommand extends MathCommand {
//todo
private static final String SUM_SYNTAX_EXTENDED = "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )";
private static final String CALC_SUM_METHOD_NAME = "scaleCube";
private static final String CALC_SUM_METHOD_NAME = "calcSum";
private static int sumCommandCounter = 0;
public MathSumCommand() {
setMathCommandName("scaleCube");
setMathCommandName("sum");
//setTargetCommand("LALALA");
}
@Override
......@@ -66,7 +66,6 @@ public class MathSumCommand extends MathCommand {
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
// error if using extended syntax here
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
//todo
Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED));
}
}
......@@ -77,14 +76,17 @@ public class MathSumCommand extends MathCommand {
BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint;
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
MathMatrixAccessSymbol cube = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedScalerImplementationArmadillo(mathMatrixNameExpressionSymbol, cube, axis, new_x, new_y, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 1) {
convertAccuSumImplementationArmadillo(mathMatrixNameExpressionSymbol, bluePrintCPP);
} else if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 2) {
convertSumImplementationArmadillo(mathMatrixNameExpressionSymbol, bluePrintCPP);
} else if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
MathMatrixAccessSymbol func = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol sumVar = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol sumStart = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol sumEnd = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedSumImplementationArmadillo(mathMatrixNameExpressionSymbol, func, sumVar, sumStart, sumEnd, bluePrintCPP);
} else {
//todo
Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\", \"sum(X,dim)\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED));
}
}
......@@ -121,17 +123,17 @@ public class MathSumCommand extends MathCommand {
* This syntax makes sum expressions easier to model.
*
* @param mathMatrixNameExpressionSymbol symbol to convert
* @param cube expression from which the sum is calculates
* @param axis name of the sum variable
* @param new_x start value of the sum variable
* @param func expression from which the sum is calculates
* @param sumVar name of the sum variable
* @param sumStart start value of the sum variable
* @param sumEnd end value of the sum variable
*/
private void convertExtendedScalerImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x, MathMatrixAccessSymbol new_y, BluePrintCPP bluePrint) {
private void convertExtendedSumImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol func, MathMatrixAccessSymbol sumVar, MathMatrixAccessSymbol sumStart, MathMatrixAccessSymbol sumEnd, BluePrintCPP bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method
Method calcSumMethod = getSumCalculationMethod(cube, axis, new_x, new_y, bluePrint);
Method calcSumMethod = getSumCalculationMethod(func, sumVar, sumStart, sumEnd, bluePrint);
// create code string
String code = calcSumMethod.getTargetLanguageMethodCall();
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
......
......@@ -138,6 +138,11 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest {
testMathCommand("det");
}
@Test
public void scaleCubeCommand() throws IOException {
testMathCommand("scaleCube");
}
@Test
public void testDiagCommand() throws IOException {
testMathCommand("diag");
......
#ifndef TEST_MATH_SCALECUBECOMMANDTEST
#define TEST_MATH_SCALECUBECOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
using namespace arma;
class test_math_scaleCubeCommandTest{
public:
cube img_in;
cube img_out;
void init()
{
img_in = cube(1, 28, 28);
img_out = cube(1, 64, 64);
}
cube scaleCube1(cube img, int depth_axis, int new_x, int new_y)
{
if (depth_axis == 0) {
img = arma::reshape(img, img.n_cols, img.n_slices, img.n_rows);
} else if(depth_axis == 1) {
img = arma::reshape(img, img.n_rows, img.n_slices, img.n_cols);
}
arma::cube r_img = arma::cube(64,64, img.n_slices);
for (int i = 0; i < img.n_slices; i++)
{
arma::mat cur_slice = img.slice(i);
arma::vec X = arma::regspace(1, cur_slice.n_cols);
arma::vec Y = arma::regspace(1, cur_slice.n_rows);
float scale_x = cur_slice.n_cols/float((new_x));
float scale_y = cur_slice.n_rows/float((new_y));
arma::vec XI = arma::regspace(1, new_x) * scale_x;
arma::vec YI = arma::regspace(1, new_y) * scale_y;
arma::mat mat_out;
arma::interp2(X, Y, cur_slice, XI, YI, mat_out);
r_img.slice(i) = mat_out;
}
if (depth_axis == 0) {
r_img = arma::reshape(r_img, r_img.n_slices, r_img.n_rows, r_img.n_cols);
} else if (depth_axis == 1) {
r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);
}
return r_img;
}
void execute()
{
img_out = scaleCube1(img_in, 0, 64, 64);
}
};
#endif
/* (c) https://github.com/MontiCore/monticore */
package test.math;
component ScaleCubeCommandTest{
ports in Q(-1:1)^{1,28,28} img_in,
out Q(-1:1)^{1,64,64} img_out;
implementation Math{
img_out = scaleCube(img_in, 0, 64, 64);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment