Commit aeb44c0d authored by Julian Dierkes's avatar Julian Dierkes
Browse files

finished adding scaleCube Command

parent c1a68ce4
Pipeline #221224 failed with stages
in 2 minutes and 44 seconds
...@@ -41,25 +41,7 @@ public class MathScaleCubeCommand extends MathCommand { ...@@ -41,25 +41,7 @@ public class MathScaleCubeCommand extends MathCommand {
} }
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) { public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol; Log.error("Not implemented for Octave Backend");
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<>());
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fsum", "Double", valueListString, "FirstResult", false, 1), mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
// error if using extended syntax here
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
//todo
Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED));
}
} }
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) { public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
...@@ -79,68 +61,56 @@ public class MathScaleCubeCommand extends MathCommand { ...@@ -79,68 +61,56 @@ public class MathScaleCubeCommand extends MathCommand {
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1); MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2); MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3); MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedScalerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, cube, axis, new_x, new_y, bluePrintCPP); convertExtendedScalerImplementationArmadillo(valueListString, mathMatrixNameExpressionSymbol, bluePrintCPP);
} else { } else {
//todo //todo
Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\", \"sum(X,dim)\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED)); Log.error(String.format("No implementation found for scaleCube operation: \"scaleCube(%s)\".", mathExpressionSymbol.getTextualRepresentation(), SCALER_SYNTAX_EXTENDED));
} }
} }
/** /**
* Implements a sum function with syntax "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )" * Implements a scaleCube function with syntax "scaleCube( CUBE , AXIS , NEW_X , NEW_Y )"
* This syntax makes sum expressions easier to model.
* *
* @param mathMatrixNameExpressionSymbol symbol to convert
* @param cube expression from which the sum is calculates
* @param axis name of the sum variable
* @param new_x start value of the sum variable
* @param new_y end value of the sum variable
*/ */
private void convertExtendedScalerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x, MathMatrixAccessSymbol new_y, BluePrintCPP bluePrint) { private void convertExtendedScalerImplementationArmadillo(String valueString, MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, BluePrint bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol(""); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol(""); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear(); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method // create method
Method calcSumMethod = getScalerCalculationMethod(cube, axis, new_x, new_y, bluePrint); Method calcScalerMethod = getScalerCalculationMethod(bluePrint);
// create code string // create code string
String code = calcSumMethod.getName() + valueString; String code = calcScalerMethod.getName() + valueString;
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()); MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr)); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr));
// add method to bluePrint // add method to bluePrint
bluePrint.addMethod(calcSumMethod); bluePrint.addMethod(calcScalerMethod);
} }
private Method getScalerCalculationMethod(MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x2, MathMatrixAccessSymbol new_y2, BluePrint bluePrint) { private Method getScalerCalculationMethod(BluePrint bluePrint) {
// create new method // create new method
Method method = getNewEmptyScalerCalculationMethod(); Method method = getNewEmptyScalerCalculationMethod();
// generate function code
String c = ExecuteMethodGenerator.generateExecuteCode(cube, new ArrayList<>());
String a = ExecuteMethodGenerator.generateExecuteCode(axis, new ArrayList<>());
String n_x = ExecuteMethodGenerator.generateExecuteCode(new_x2, new ArrayList<>());
String n_y = ExecuteMethodGenerator.generateExecuteCode(new_y2, new ArrayList<>());
// add loop var
// parameters // parameters
Variable img = new Variable(); Variable img = new Variable();
img.setName("img"); img.setName("img");
img.setVariableType(new VariableType("Cube", "cube", "")); img.setVariableType(new VariableType("Cube", "cube", ""));
Variable depth_axis = new Variable();
depth_axis.setName("depth_axis");
depth_axis.setVariableType(new VariableType("Integer", "int", ""));
Variable new_x = new Variable(); Variable new_x = new Variable();
new_x.setName("new_x"); new_x.setName("new_x");
new_x.setVariableType(new VariableType("Integer", "int", "")); new_x.setVariableType(new VariableType("Integer", "int", ""));
Variable new_y = new Variable(); Variable new_y = new Variable();
new_y.setName("new_y"); new_y.setName("new_y");
new_y.setVariableType(new VariableType("Integer", "int", "")); new_y.setVariableType(new VariableType("Integer", "int", ""));
Variable depth_axis = new Variable();
depth_axis.setName("depth_axis");
depth_axis.setVariableType(new VariableType("Integer", "int", ""));
method.addParameter(img); method.addParameter(img);
method.addParameter(depth_axis);
method.addParameter(new_x); method.addParameter(new_x);
method.addParameter(new_y); method.addParameter(new_y);
method.addParameter(depth_axis);
// add instructions
// add instructions
method.addInstruction(methodBody());
method.addInstruction(ifClauses());
return method; return method;
} }
...@@ -152,7 +122,7 @@ public class MathScaleCubeCommand extends MathCommand { ...@@ -152,7 +122,7 @@ public class MathScaleCubeCommand extends MathCommand {
return method; return method;
} }
private Instruction ifClauses() { private Instruction methodBody() {
return new Instruction() { return new Instruction() {
@Override @Override
public String getTargetLanguageInstruction() { public String getTargetLanguageInstruction() {
...@@ -169,10 +139,10 @@ public class MathScaleCubeCommand extends MathCommand { ...@@ -169,10 +139,10 @@ public class MathScaleCubeCommand extends MathCommand {
" arma::vec X = arma::regspace(1, cur_slice.n_cols);\n" + " arma::vec X = arma::regspace(1, cur_slice.n_cols);\n" +
" arma::vec Y = arma::regspace(1, cur_slice.n_rows);\n" + " arma::vec Y = arma::regspace(1, cur_slice.n_rows);\n" +
"\n" + "\n" +
" float scale_x = cur_slice.n_cols/new_x;\n" + " float scale_x = cur_slice.n_cols/float((new_x));\n" +
" float scale_y = cur_slice.n_rows/new_y;\n" + " float scale_y = cur_slice.n_rows/float((new_y));\n" +
" arma::vec XI = arma::regspace(1, new_x);\n" + " arma::vec XI = arma::regspace(1, new_x) * scale_x;\n" +
" arma::vec YI = arma::regspace(1, new_y);\n" + " arma::vec YI = arma::regspace(1, new_y) * scale_y;\n" +
"\n" + "\n" +
" arma::mat mat_out;\n" + " arma::mat mat_out;\n" +
"\n" + "\n" +
...@@ -186,7 +156,7 @@ public class MathScaleCubeCommand extends MathCommand { ...@@ -186,7 +156,7 @@ public class MathScaleCubeCommand extends MathCommand {
" r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);\n" + " r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);\n" +
" }\n" + " }\n" +
" \n" + " \n" +
" return r_img;"; " return r_img;\n";
} }
@Override @Override
......
...@@ -26,15 +26,15 @@ import java.util.Optional; ...@@ -26,15 +26,15 @@ import java.util.Optional;
*/ */
public class MathSumCommand extends MathCommand { public class MathSumCommand extends MathCommand {
//todo
private static final String SUM_SYNTAX_EXTENDED = "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )"; private static final String SUM_SYNTAX_EXTENDED = "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )";
private static final String CALC_SUM_METHOD_NAME = "scaleCube"; private static final String CALC_SUM_METHOD_NAME = "calcSum";
private static int sumCommandCounter = 0; private static int sumCommandCounter = 0;
public MathSumCommand() { public MathSumCommand() {
setMathCommandName("scaleCube"); setMathCommandName("sum");
//setTargetCommand("LALALA");
} }
@Override @Override
...@@ -66,7 +66,6 @@ public class MathSumCommand extends MathCommand { ...@@ -66,7 +66,6 @@ public class MathSumCommand extends MathCommand {
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls"); ((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
// error if using extended syntax here // error if using extended syntax here
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) { if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
//todo
Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED)); Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED));
} }
} }
...@@ -77,14 +76,17 @@ public class MathSumCommand extends MathCommand { ...@@ -77,14 +76,17 @@ public class MathSumCommand extends MathCommand {
BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint; BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint;
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()) for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP); MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) { if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 1) {
MathMatrixAccessSymbol cube = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0); convertAccuSumImplementationArmadillo(mathMatrixNameExpressionSymbol, bluePrintCPP);
MathMatrixAccessSymbol axis = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1); } else if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 2) {
MathMatrixAccessSymbol new_x = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2); convertSumImplementationArmadillo(mathMatrixNameExpressionSymbol, bluePrintCPP);
MathMatrixAccessSymbol new_y = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3); } else if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
convertExtendedScalerImplementationArmadillo(mathMatrixNameExpressionSymbol, cube, axis, new_x, new_y, bluePrintCPP); MathMatrixAccessSymbol func = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol sumVar = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol sumStart = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol sumEnd = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedSumImplementationArmadillo(mathMatrixNameExpressionSymbol, func, sumVar, sumStart, sumEnd, bluePrintCPP);
} else { } else {
//todo
Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\", \"sum(X,dim)\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED)); Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\", \"sum(X,dim)\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED));
} }
} }
...@@ -121,17 +123,17 @@ public class MathSumCommand extends MathCommand { ...@@ -121,17 +123,17 @@ public class MathSumCommand extends MathCommand {
* This syntax makes sum expressions easier to model. * This syntax makes sum expressions easier to model.
* *
* @param mathMatrixNameExpressionSymbol symbol to convert * @param mathMatrixNameExpressionSymbol symbol to convert
* @param cube expression from which the sum is calculates * @param func expression from which the sum is calculates
* @param axis name of the sum variable * @param sumVar name of the sum variable
* @param new_x start value of the sum variable * @param sumStart start value of the sum variable
* @param sumEnd end value of the sum variable * @param sumEnd end value of the sum variable
*/ */
private void convertExtendedScalerImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol cube, MathMatrixAccessSymbol axis, MathMatrixAccessSymbol new_x, MathMatrixAccessSymbol new_y, BluePrintCPP bluePrint) { private void convertExtendedSumImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol func, MathMatrixAccessSymbol sumVar, MathMatrixAccessSymbol sumStart, MathMatrixAccessSymbol sumEnd, BluePrintCPP bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol(""); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol(""); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear(); mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method // create method
Method calcSumMethod = getSumCalculationMethod(cube, axis, new_x, new_y, bluePrint); Method calcSumMethod = getSumCalculationMethod(func, sumVar, sumStart, sumEnd, bluePrint);
// create code string // create code string
String code = calcSumMethod.getTargetLanguageMethodCall(); String code = calcSumMethod.getTargetLanguageMethodCall();
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()); MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
......
...@@ -138,6 +138,11 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest { ...@@ -138,6 +138,11 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest {
testMathCommand("det"); testMathCommand("det");
} }
@Test
public void scaleCubeCommand() throws IOException {
testMathCommand("scaleCube");
}
@Test @Test
public void testDiagCommand() throws IOException { public void testDiagCommand() throws IOException {
testMathCommand("diag"); testMathCommand("diag");
......
#ifndef TEST_MATH_SCALECUBECOMMANDTEST
#define TEST_MATH_SCALECUBECOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo"
using namespace arma;
class test_math_scaleCubeCommandTest{
public:
cube img_in;
cube img_out;
void init()
{
img_in = cube(1, 28, 28);
img_out = cube(1, 64, 64);
}
cube scaleCube1(cube img, int depth_axis, int new_x, int new_y)
{
if (depth_axis == 0) {
img = arma::reshape(img, img.n_cols, img.n_slices, img.n_rows);
} else if(depth_axis == 1) {
img = arma::reshape(img, img.n_rows, img.n_slices, img.n_cols);
}
arma::cube r_img = arma::cube(64,64, img.n_slices);
for (int i = 0; i < img.n_slices; i++)
{
arma::mat cur_slice = img.slice(i);
arma::vec X = arma::regspace(1, cur_slice.n_cols);
arma::vec Y = arma::regspace(1, cur_slice.n_rows);
float scale_x = cur_slice.n_cols/float((new_x));
float scale_y = cur_slice.n_rows/float((new_y));
arma::vec XI = arma::regspace(1, new_x) * scale_x;
arma::vec YI = arma::regspace(1, new_y) * scale_y;
arma::mat mat_out;
arma::interp2(X, Y, cur_slice, XI, YI, mat_out);
r_img.slice(i) = mat_out;
}
if (depth_axis == 0) {
r_img = arma::reshape(r_img, r_img.n_slices, r_img.n_rows, r_img.n_cols);
} else if (depth_axis == 1) {
r_img = arma::reshape(r_img, r_img.n_rows, r_img.n_slices, r_img.n_cols);
}
return r_img;
}
void execute()
{
img_out = scaleCube1(img_in, 0, 64, 64);
}
};
#endif
/* (c) https://github.com/MontiCore/monticore */
package test.math;
component ScaleCubeCommandTest{
ports in Q(-1:1)^{1,28,28} img_in,
out Q(-1:1)^{1,64,64} img_out;
implementation Math{
img_out = scaleCube(img_in, 0, 64, 64);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment