Commit bfdd83c3 authored by Christoph Richter's avatar Christoph Richter
Browse files

MathSumCommand: Overloaded sum function

parent 2b4a42a7
......@@ -39,6 +39,22 @@ public class Method {
parameters.add(v);
}
public boolean addParameterUnique(Variable v) {
boolean added = !containsParameter(v);
if (added) {
addParameter(v);
}
return added;
}
private boolean containsParameter(Variable v) {
boolean found = false;
for (Variable param : getParameters()) {
found |= param.getName().contentEquals(v.getName());
}
return found;
}
public List<Variable> getParameters() {
return parameters;
}
......@@ -62,4 +78,14 @@ public class Method {
public void setInstructions(List<Instruction> instructions) {
this.instructions = instructions;
}
public String getTargetLanguageMethodCall() {
String args = "";
int size = getParameters().size();
for (int i = 0; i < size - 1; i++) {
args += String.format("%s, ", parameters.get(i).getNameTargetLanguageFormat());
}
args += parameters.get(size - 1).getNameTargetLanguageFormat();
return String.format("%s(%s)", name, args);
}
}
......@@ -3,23 +3,33 @@ package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.*;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
* @author Sascha Schneiders
* @author Christoph Richter
* Overloaded syntax to more math convinient way:
* sum(function, sum_variable, start_value, end_value)
*/
public class MathSumCommand extends MathCommand {
private static final String SUM_SYNTAX_EXTENDED = "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )";
private static final String CALC_SUM_METHOD_NAME = "calcSum";
private static int sumCommandCounter = 0;
public MathSumCommand() {
setMathCommandName("sum");
//setTargetCommand("LALALA");
......@@ -44,33 +54,196 @@ public class MathSumCommand extends MathCommand {
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fsum", "Double", valueListString, "FirstResult", false, 1),mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fsum", "Double", valueListString, "FirstResult", false, 1), mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
// error if using extended syntax here
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
Log.error(String.format("Syntax: \"%s\" is not supported when using deprecated backend Octave", SUM_SYNTAX_EXTENDED));
}
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
BluePrintCPP bluePrintCPP = (BluePrintCPP) bluePrint;
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
MathFunctionFixer.fixMathFunctions(accessSymbol, bluePrintCPP);
if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 1) {
convertAccuSumImplementationArmadillo(mathMatrixNameExpressionSymbol, bluePrintCPP);
} else if (mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().size() == 4) {
MathMatrixAccessSymbol func = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0);
MathMatrixAccessSymbol sumVar = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(1);
MathMatrixAccessSymbol sumStart = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(2);
MathMatrixAccessSymbol sumEnd = mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(3);
convertExtendedSumImplementationArmadillo(mathMatrixNameExpressionSymbol, func, sumVar, sumStart, sumEnd, bluePrintCPP);
} else {
Log.error(String.format("No implementation found for sum operation: \"sum(%s)\". Possible syntax is \"sum( X )\" or \"%s\"", mathExpressionSymbol.getTextualRepresentation(), SUM_SYNTAX_EXTENDED));
}
}
/**
* Implements the sum command using Armadillos accu command
*
* @param mathMatrixNameExpressionSymbol MathMatrixNameExpressionSymbol passed to convert
* @param bluePrint BluePrint of current code generation
* @see <a href="http://arma.sourceforge.net/docs.html#accu">Armadillo Documentation</a>
*/
private void convertAccuSumImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, BluePrintCPP bluePrint) {
String valueListString = ExecuteMethodGenerator.generateExecuteCode(mathMatrixNameExpressionSymbol, new ArrayList<>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
MathStringExpression stringExpression = new MathStringExpression("accu" + valueListString, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression("accu"+valueListString,mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("HelperA");
bluePrint.addAdditionalIncludeString("HelperA"); // question: why? (CR)
}
/**
* Implements a sum function with syntax "sum( EXPRESSION , SUM_VARIABLE , START_VALUE , END_VALUE )"
* This syntax makes sum expressions easier to model.
*
* @param mathMatrixNameExpressionSymbol symbol to convert
* @param func expression from which the sum is calculates
* @param sumVar name of the sum variable
* @param sumStart start value of the sum variable
* @param sumEnd end value of the sum variable
*/
private void convertExtendedSumImplementationArmadillo(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, MathMatrixAccessSymbol func, MathMatrixAccessSymbol sumVar, MathMatrixAccessSymbol sumStart, MathMatrixAccessSymbol sumEnd, BluePrintCPP bluePrint) {
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessStartSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setAccessEndSymbol("");
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().clear();
// create method
Method calcSumMethod = getSumCalculationMethod(func, sumVar, sumStart, sumEnd, bluePrint);
// create code string
String code = calcSumMethod.getTargetLanguageMethodCall();
MathStringExpression codeExpr = new MathStringExpression(code, mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols());
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().add(new MathMatrixAccessSymbol(codeExpr));
// add method to bluePrint
bluePrint.addMethod(calcSumMethod);
}
private Method getSumCalculationMethod(MathMatrixAccessSymbol func, MathMatrixAccessSymbol sumVar, MathMatrixAccessSymbol sumStart, MathMatrixAccessSymbol sumEnd, BluePrintCPP bluePrint) {
// create new method
Method method = getNewEmptySumCalculationMethod();
// generate function code
String f = ExecuteMethodGenerator.generateExecuteCode(func, new ArrayList<>());
String varString = ExecuteMethodGenerator.generateExecuteCode(sumVar, new ArrayList<>());
String start = ExecuteMethodGenerator.generateExecuteCode(sumStart, new ArrayList<>());
String end = ExecuteMethodGenerator.generateExecuteCode(sumEnd, new ArrayList<>());
// add loop var
Variable loopVar = generateLoopVariable(varString, bluePrint);
// parameters
setParameters(method, bluePrint);
// add instructions
method.addInstruction(accumulatorInitialization());
method.addInstruction(forLoopHeader(varString, start, end));
method.addInstruction(forLoopBody(f));
method.addInstruction(returnAccumulator());
// add loopvar to children
addLoopVarParamToMethod(method, loopVar, bluePrint);
// delete loop var
bluePrint.getMathInformationRegister().getVariables().remove(loopVar);
return method;
}
private Method getNewEmptySumCalculationMethod() {
sumCommandCounter++;
Method method = new Method();
method.setName(CALC_SUM_METHOD_NAME + sumCommandCounter);
method.setReturnTypeName("double");
return method;
}
private void setParameters(Method method, BluePrint bluePrint) {
List<Variable> vars = bluePrint.getMathInformationRegister().getVariables();
for (int i = 0; i < vars.size() - 2; i++) { // the last variable is the one we are assigning now
method.addParameterUnique(vars.get(i));
}
}
private Variable generateLoopVariable(String name, BluePrint bluePrint) {
Variable loopVar = new Variable(name, Variable.FORLOOPINFO);
loopVar.setVariableType(new VariableType("Integer", "int", ""));
bluePrint.getMathInformationRegister().addVariable(loopVar);
return loopVar;
}
private Instruction accumulatorInitialization() {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return " double res = 0; \n";
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
private Instruction forLoopHeader(String sumVar, String sumStart, String sumEnd) {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return String.format(" for (int %s = %s - 1; %s <= %s - 1; %s++)\n", sumVar, sumStart, sumVar, sumEnd, sumVar);
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
private Instruction forLoopBody(String func) {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return String.format(" res += %s;\n", func);
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
private Instruction returnAccumulator() {
return new Instruction() {
@Override
public String getTargetLanguageInstruction() {
return " return res;\n";
}
@Override
public boolean isConnectInstruction() {
return false;
}
};
}
private void addLoopVarParamToMethod(Method method, Variable loopVar, BluePrintCPP bluePrint) {
String func = method.getInstructions().get(2).getTargetLanguageInstruction();
if (func.contains(CALC_SUM_METHOD_NAME)) {
String[] split1 = func.split(CALC_SUM_METHOD_NAME);
String[] split2 = split1[1].split("[)]");
func = CALC_SUM_METHOD_NAME + split2[0] + ", " + loopVar.getNameTargetLanguageFormat() + ")";
// and change the method signiture of the calc sum function
String mName = CALC_SUM_METHOD_NAME + split1[1].substring(0, split1[1].indexOf("("));
Optional<Method> affectedMethod = bluePrint.getMethod(mName);
if (affectedMethod.isPresent()) {
affectedMethod.get().addParameterUnique(loopVar);
addLoopVarParamToMethod(affectedMethod.get(), loopVar, bluePrint);
}
method.getInstructions().set(2, forLoopBody(func));
}
}
}
......@@ -183,4 +183,9 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest {
public void testTanhCommand() throws IOException {
testMathCommand("tanh");
}
@Test
public void testSumExtendedCommand() throws IOException {
testMathCommand("sumExtended");
}
}
#ifndef TEST_MATH_SUMEXTENDEDCOMMANDTEST
#define TEST_MATH_SUMEXTENDEDCOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
using namespace arma;
class test_math_sumExtendedCommandTest{
public:
colvec CONSTANTCONSTANTVECTOR0;
void init()
{
CONSTANTCONSTANTVECTOR0 = colvec(3);
CONSTANTCONSTANTVECTOR0(0,0) = 1;
CONSTANTCONSTANTVECTOR0(1,0) = 2;
CONSTANTCONSTANTVECTOR0(2,0) = 3;
}
double calcSum1(colvec A)
{
double res = 0;
for (int i = 1 - 1; i <= 3 - 1; i++)
res += A(i);
return res;
}
double calcSum2(colvec A, double x, int j)
{
double res = 0;
for (int i = 1 - 1; i <= 3 - 1; i++)
res += A(i)*A(j);
return res;
}
double calcSum3(colvec A, double x)
{
double res = 0;
for (int j = 1 - 1; j <= 2 - 1; j++)
res += calcSum2(A, x, j);
return res;
}
double calcSum4(colvec A, double x, double y, int j, int k)
{
double res = 0;
for (int i = 1 - 1; i <= 3 - 1; i++)
res += A(i)*A(j)*A(k);
return res;
}
double calcSum5(colvec A, double x, double y, int k)
{
double res = 0;
for (int j = 1 - 1; j <= 2 - 1; j++)
res += calcSum4(A, x, y, j, k);
return res;
}
double calcSum6(colvec A, double x, double y)
{
double res = 0;
for (int k = 1 - 1; k <= 1 - 1; k++)
res += calcSum5(A, x, y, k);
return res;
}
void execute()
{
colvec A = CONSTANTCONSTANTVECTOR0;
double x = calcSum1(A);
double y = calcSum3(A, x);
double z = calcSum6(A, x, y);
}
};
#endif
package test.math;
component SumExtendedCommandTest{
implementation Math{
Q^{3} A = [1; 2; 3];
Q x = sum(A(i), i, 1, 3);
Q y = sum(sum(A(i) * A(j), i, 1, 3), j, 1, 2);
Q z = sum(sum(sum(A(i) * A(j) * A(k), i, 1, 3), j, 1, 2) k, 1, 1);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment