Commit b2dfd565 authored by Christoph Richter's avatar Christoph Richter
Browse files

Fixed wrong 1-based matrix element access for armadillo (fixes #53)

parent ce277f76
...@@ -2,6 +2,9 @@ package de.monticore.lang.monticar.generator; ...@@ -2,6 +2,9 @@ package de.monticore.lang.monticar.generator;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol; import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import java.util.HashSet;
/** /**
* @author Sascha Schneiders. * @author Sascha Schneiders.
...@@ -9,6 +12,8 @@ import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol; ...@@ -9,6 +12,8 @@ import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
public abstract class MathCommand { public abstract class MathCommand {
protected String mathCommandName; protected String mathCommandName;
private HashSet<String> targetLanguageCommandNames = new HashSet<>();
public MathCommand() { public MathCommand() {
} }
...@@ -25,5 +30,23 @@ public abstract class MathCommand { ...@@ -25,5 +30,23 @@ public abstract class MathCommand {
this.mathCommandName = mathCommandName; this.mathCommandName = mathCommandName;
} }
public abstract void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint); protected abstract void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint);
public void convertAndSetTargetLanguageName(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
convert(mathExpressionSymbol, bluePrint);
if (mathExpressionSymbol instanceof MathMatrixNameExpressionSymbol) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
targetLanguageCommandNames.add(mathMatrixNameExpressionSymbol.getTextualRepresentation());
}
}
/**
* Gets the mathCommandName converted to the target language possibly contains multiple
* commands
*
* @return targetLanguageCommandName
*/
public HashSet<String> getTargetLanguageCommandNames() {
return targetLanguageCommandNames;
}
} }
...@@ -9,7 +9,7 @@ import java.util.List; ...@@ -9,7 +9,7 @@ import java.util.List;
public abstract class MathCommandRegister { public abstract class MathCommandRegister {
public List<MathCommand> mathCommands = new ArrayList<>(); public List<MathCommand> mathCommands = new ArrayList<>();
public MathCommandRegister(){ public MathCommandRegister() {
init(); init();
} }
...@@ -24,6 +24,26 @@ public abstract class MathCommandRegister { ...@@ -24,6 +24,26 @@ public abstract class MathCommandRegister {
} }
return null; return null;
} }
public boolean isMathCommand(String functionName) {
boolean isMathCommand = false;
if (!functionName.isEmpty()) {
if (getMathCommand(functionName) != null) {
isMathCommand = true;
} else {
for (MathCommand mathCommand : mathCommands) {
for (String s : mathCommand.getTargetLanguageCommandNames()) {
if (s.contains(functionName)) {
isMathCommand = true;
break;
}
}
}
}
}
return isMathCommand;
}
protected abstract void init(); protected abstract void init();
} }
package de.monticore.lang.monticar.generator.cpp; package de.monticore.lang.monticar.generator.cpp;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol; import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.Generator;
import de.monticore.lang.monticar.generator.MathCommandRegister; import de.monticore.lang.monticar.generator.MathCommandRegister;
import de.monticore.lang.monticar.generator.cpp.commands.*; import de.monticore.lang.monticar.generator.cpp.commands.*;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
/** /**
...@@ -82,7 +78,7 @@ public class MathCommandRegisterCPP extends MathCommandRegister { ...@@ -82,7 +78,7 @@ public class MathCommandRegisterCPP extends MathCommandRegister {
fullName = removeTrailingStrings(fullName, "("); fullName = removeTrailingStrings(fullName, "(");
String name = calculateName(fullName); String name = calculateName(fullName);
Log.info("" + input + " name: " + name, "containsCommandExpression"); Log.info("" + input + " name: " + name, "containsCommandExpression");
if (GeneratorCPP.currentInstance.getMathCommandRegister().getMathCommand(name) != null) { if (GeneratorCPP.currentInstance.getMathCommandRegister().isMathCommand(name)) {
return true; return true;
} }
fullName = fullName.substring(name.length() + 1); fullName = fullName.substring(name.length() + 1);
......
...@@ -160,7 +160,7 @@ public class MathFunctionFixer { ...@@ -160,7 +160,7 @@ public class MathFunctionFixer {
if (mathCommand != null) { if (mathCommand != null) {
if (MathConverter.curBackend.getBackendName().equals("OctaveBackend")) if (MathConverter.curBackend.getBackendName().equals("OctaveBackend"))
bluePrintCPP.addAdditionalIncludeString("Helper"); bluePrintCPP.addAdditionalIncludeString("Helper");
mathCommand.convert(mathExpressionSymbol, bluePrintCPP); mathCommand.convertAndSetTargetLanguageName(mathExpressionSymbol, bluePrintCPP);
} }
if (fixForLoopAccess(mathExpressionSymbol, variable, bluePrintCPP)) { if (fixForLoopAccess(mathExpressionSymbol, variable, bluePrintCPP)) {
for (MathMatrixAccessSymbol mathMatrixAccessSymbol : mathExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()) { for (MathMatrixAccessSymbol mathMatrixAccessSymbol : mathExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()) {
......
...@@ -2,6 +2,7 @@ package de.monticore.lang.monticar.generator.cpp.converter; ...@@ -2,6 +2,7 @@ package de.monticore.lang.monticar.generator.cpp.converter;
import de.monticore.lang.math.math._symboltable.MathForLoopHeadSymbol; import de.monticore.lang.math.math._symboltable.MathForLoopHeadSymbol;
import de.monticore.lang.math.math._symboltable.expression.*; import de.monticore.lang.math.math._symboltable.expression.*;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.Variable; import de.monticore.lang.monticar.generator.Variable;
import de.monticore.lang.monticar.generator.cpp.MathCommandRegisterCPP; import de.monticore.lang.monticar.generator.cpp.MathCommandRegisterCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer; import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
...@@ -241,19 +242,57 @@ public class ExecuteMethodGeneratorHandler { ...@@ -241,19 +242,57 @@ public class ExecuteMethodGeneratorHandler {
} }
}*/ }*/
result = mathAssignmentExpressionSymbol.getNameOfMathValue(); result = mathAssignmentExpressionSymbol.getNameOfMathValue();
result += ExecuteMethodGenerator.getCorrectAccessString(mathAssignmentExpressionSymbol.getNameOfMathValue(), mathAssignmentExpressionSymbol.getMathMatrixAccessOperatorSymbol(), includeStrings); result += ExecuteMethodGenerator.getCorrectAccessString(mathAssignmentExpressionSymbol.getNameOfMathValue(), mathAssignmentExpressionSymbol.getMathMatrixAccessOperatorSymbol(), includeStrings);
result += mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator() + " "; result += mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator() + " ";
result += StringIndexHelper.modifyContentBetweenBracketsByAdding(ExecuteMethodGenerator.generateExecuteCode(mathAssignmentExpressionSymbol.getExpressionSymbol(), includeStrings) + ";\n", "-1"); result += StringIndexHelper.modifyContentBetweenBracketsByAdding(ExecuteMethodGenerator.generateExecuteCode(mathAssignmentExpressionSymbol.getExpressionSymbol(), includeStrings) + ";\n", "-1");
Log.info("result2: " + result, "MathAssignmentExpressionSymbol"); Log.info("result2: " + result, "MathAssignmentExpressionSymbol");
}
} else {
result = generateExecuteCodeForNonMatrixElementAssignments(mathAssignmentExpressionSymbol, includeStrings);
}
return result;
}
private static String generateExecuteCodeForNonMatrixElementAssignments(MathAssignmentExpressionSymbol mathAssignmentExpressionSymbol, List<String> includeStrings) {
String name = mathAssignmentExpressionSymbol.getNameOfMathValue();
String op = mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator();
String assignment;
MathExpressionSymbol assignmentSymbol = mathAssignmentExpressionSymbol.getExpressionSymbol();
if (assignmentSymbol instanceof MathMatrixNameExpressionSymbol) {
MathMatrixNameExpressionSymbol matrixAssignmentSymbol = (MathMatrixNameExpressionSymbol) assignmentSymbol;
if (useZeroBasedIndexing(matrixAssignmentSymbol)) {
String matrixName = matrixAssignmentSymbol.getNameToAccess();
String matrixAccess = ExecuteMethodGenerator.getCorrectAccessString(matrixAssignmentSymbol.getNameToAccess(), matrixAssignmentSymbol.getMathMatrixAccessOperatorSymbol(), includeStrings);
assignment = String.format("%s%s", matrixName, matrixAccess);
} else {
assignment = ExecuteMethodGenerator.generateExecuteCode(assignmentSymbol, includeStrings);
} }
} else { } else {
result = mathAssignmentExpressionSymbol.getNameOfMathValue() + " " + mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator() + " " + ExecuteMethodGenerator.generateExecuteCode(mathAssignmentExpressionSymbol.getExpressionSymbol(), includeStrings) + ";\n"; assignment = ExecuteMethodGenerator.generateExecuteCode(assignmentSymbol, includeStrings);
Log.info("result3: " + result, "MathAssignmentExpressionSymbol");
} }
String result = String.format("%s %s %s;\n", name, op, assignment);
Log.info("result3: " + result, "MathAssignmentExpressionSymbol");
return result; return result;
} }
private static boolean useZeroBasedIndexing(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol) {
boolean isZeroBased = false;
if (MathConverter.curBackend.usesZeroBasedIndexing()) {
if (!isFunctionCall(mathMatrixNameExpressionSymbol)) {
isZeroBased = true;
}
}
return isZeroBased;
}
private static boolean isFunctionCall(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol) {
boolean isFunctionCall = false;
if (MathCommandRegisterCPP.containsCommandExpression(mathMatrixNameExpressionSymbol, mathMatrixNameExpressionSymbol.getTextualRepresentation())) {
isFunctionCall = true;
}
return isFunctionCall;
}
public static String generateExecuteCode(MathForLoopExpressionSymbol mathForLoopExpressionSymbol, List<String> includeStrings) { public static String generateExecuteCode(MathForLoopExpressionSymbol mathForLoopExpressionSymbol, List<String> includeStrings) {
String result = ""; String result = "";
......
...@@ -88,4 +88,18 @@ public class BasicMathGenerationArmadilloTest extends AbstractSymtabTest { ...@@ -88,4 +88,18 @@ public class BasicMathGenerationArmadilloTest extends AbstractSymtabTest {
String restPath = "armadillo/testMath/l0/"; String restPath = "armadillo/testMath/l0/";
testFilesAreEqual(files, restPath); testFilesAreEqual(files, restPath);
} }
@Test
public void armadilloIndexTest() throws IOException {
TaggingResolver symtab = createSymTabAndTaggingResolver("src/test/resources");
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("test.math.armadilloIndexTest", ExpandedComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(componentSymbol);
GeneratorCPP generatorCPP = new GeneratorCPP();
generatorCPP.useArmadilloBackend();
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/armadillo/testMath/l0");
List<File> files = generatorCPP.generateFiles(symtab, componentSymbol, symtab);
String restPath = "armadillo/testMath/l0/";
testFilesAreEqual(files, restPath);
}
} }
#ifndef TEST_MATH_ARMADILLOINDEXTEST
#define TEST_MATH_ARMADILLOINDEXTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
using namespace arma;
class test_math_armadilloIndexTest{
public:
mat in1;
mat out1;
mat CONSTANTCONSTANTVECTOR0;
rowvec CONSTANTCONSTANTVECTOR1;
void init()
{
in1=mat(2,2);
out1=mat(4,4);
CONSTANTCONSTANTVECTOR0 = mat(2,2);
CONSTANTCONSTANTVECTOR0(0,0) = 1;
CONSTANTCONSTANTVECTOR0(0,1) = 2;
CONSTANTCONSTANTVECTOR0(1,0) = 3;
CONSTANTCONSTANTVECTOR0(1,1) = 4;
CONSTANTCONSTANTVECTOR1 = rowvec(2);
CONSTANTCONSTANTVECTOR1(0,0) = 11;
CONSTANTCONSTANTVECTOR1(0,1) = 12;
}
void execute()
{
mat A=mat(2,2);
A = (ones<mat>(2, 2));
A = CONSTANTCONSTANTVECTOR0;
colvec b=colvec(2);
b = (zeros<colvec>(2));
b = CONSTANTCONSTANTVECTOR1;
out1 = (zeros<mat>(4, 4));
A(2-1, 2-1) = 5;
A(2-1) = b(2-1);
A(2-1, 1-1) = out1(1-1, 1-1);
b(2-1) = 13;
b(1-1) = A(1-1);
b(1-1) = A(2-1, 2-1);
b(2-1) = out1(4-1, 4-1);
double x = 0;
x = A(1-1, 2-1) ;
x = b(2-1) ;
x = in1(1-1, 1-1) ;
out1(1-1, 1-1) = in1(1-1, 1-1);
mat test = (zeros<mat>(4, 4));
out1 = test;
out1(1-1) = test(1-1);
out1(2-1, 2-1) = A(2-1, 2-1);
}
};
#endif
package test.math;
component ArmadilloIndexTest{
ports in Q^{2, 2} in1,
out Q^{4, 4} out1;
implementation Math{
// test matrix initialization
Q^{2, 2} A;
A = ones(2,2);
A = [1 2; 3 4];
// test vector initialization
Q^{2} b;
b = zeros(2);
b = [11 12];
// test port initialization
out1 = zeros(4, 4);
// test matrix element assignment
A(2, 2) = 5;
A(2) = b(2);
A(1, 1) = b(2);
A(2, 1) = out1(1, 1);
// test vector element assignment
b(2) = 13;
b(1) = A(1);
b(1) = A(2, 2);
b(2) = out1(4, 4);
// test scalar element assignment
Q x = 0;
x = A(1, 2);
x = b(2);
x = in1(1, 1);
// test port element assignment
out1(1, 1) = in1(1, 1);
Q^{4, 4} test = zeros(4,4);
out1 = test;
out1(1) = test(1);
out1(2, 2) = A(2, 2);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment