Commit b2dfd565 authored by Christoph Richter's avatar Christoph Richter
Browse files

Fixed wrong 1-based matrix element access for armadillo (fixes #53)

parent ce277f76
......@@ -2,6 +2,9 @@ package de.monticore.lang.monticar.generator;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import java.util.HashSet;
/**
* @author Sascha Schneiders.
......@@ -9,6 +12,8 @@ import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
public abstract class MathCommand {
protected String mathCommandName;
private HashSet<String> targetLanguageCommandNames = new HashSet<>();
public MathCommand() {
}
......@@ -25,5 +30,23 @@ public abstract class MathCommand {
this.mathCommandName = mathCommandName;
}
public abstract void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint);
protected abstract void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint);
public void convertAndSetTargetLanguageName(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
convert(mathExpressionSymbol, bluePrint);
if (mathExpressionSymbol instanceof MathMatrixNameExpressionSymbol) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
targetLanguageCommandNames.add(mathMatrixNameExpressionSymbol.getTextualRepresentation());
}
}
/**
* Gets the mathCommandName converted to the target language possibly contains multiple
* commands
*
* @return targetLanguageCommandName
*/
public HashSet<String> getTargetLanguageCommandNames() {
return targetLanguageCommandNames;
}
}
......@@ -9,7 +9,7 @@ import java.util.List;
public abstract class MathCommandRegister {
public List<MathCommand> mathCommands = new ArrayList<>();
public MathCommandRegister(){
public MathCommandRegister() {
init();
}
......@@ -24,6 +24,26 @@ public abstract class MathCommandRegister {
}
return null;
}
public boolean isMathCommand(String functionName) {
boolean isMathCommand = false;
if (!functionName.isEmpty()) {
if (getMathCommand(functionName) != null) {
isMathCommand = true;
} else {
for (MathCommand mathCommand : mathCommands) {
for (String s : mathCommand.getTargetLanguageCommandNames()) {
if (s.contains(functionName)) {
isMathCommand = true;
break;
}
}
}
}
}
return isMathCommand;
}
protected abstract void init();
}
package de.monticore.lang.monticar.generator.cpp;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.Generator;
import de.monticore.lang.monticar.generator.MathCommandRegister;
import de.monticore.lang.monticar.generator.cpp.commands.*;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
/**
......@@ -82,7 +78,7 @@ public class MathCommandRegisterCPP extends MathCommandRegister {
fullName = removeTrailingStrings(fullName, "(");
String name = calculateName(fullName);
Log.info("" + input + " name: " + name, "containsCommandExpression");
if (GeneratorCPP.currentInstance.getMathCommandRegister().getMathCommand(name) != null) {
if (GeneratorCPP.currentInstance.getMathCommandRegister().isMathCommand(name)) {
return true;
}
fullName = fullName.substring(name.length() + 1);
......
......@@ -160,7 +160,7 @@ public class MathFunctionFixer {
if (mathCommand != null) {
if (MathConverter.curBackend.getBackendName().equals("OctaveBackend"))
bluePrintCPP.addAdditionalIncludeString("Helper");
mathCommand.convert(mathExpressionSymbol, bluePrintCPP);
mathCommand.convertAndSetTargetLanguageName(mathExpressionSymbol, bluePrintCPP);
}
if (fixForLoopAccess(mathExpressionSymbol, variable, bluePrintCPP)) {
for (MathMatrixAccessSymbol mathMatrixAccessSymbol : mathExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols()) {
......
......@@ -2,6 +2,7 @@ package de.monticore.lang.monticar.generator.cpp.converter;
import de.monticore.lang.math.math._symboltable.MathForLoopHeadSymbol;
import de.monticore.lang.math.math._symboltable.expression.*;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.Variable;
import de.monticore.lang.monticar.generator.cpp.MathCommandRegisterCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
......@@ -248,12 +249,50 @@ public class ExecuteMethodGeneratorHandler {
Log.info("result2: " + result, "MathAssignmentExpressionSymbol");
}
} else {
result = mathAssignmentExpressionSymbol.getNameOfMathValue() + " " + mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator() + " " + ExecuteMethodGenerator.generateExecuteCode(mathAssignmentExpressionSymbol.getExpressionSymbol(), includeStrings) + ";\n";
Log.info("result3: " + result, "MathAssignmentExpressionSymbol");
result = generateExecuteCodeForNonMatrixElementAssignments(mathAssignmentExpressionSymbol, includeStrings);
}
return result;
}
private static String generateExecuteCodeForNonMatrixElementAssignments(MathAssignmentExpressionSymbol mathAssignmentExpressionSymbol, List<String> includeStrings) {
String name = mathAssignmentExpressionSymbol.getNameOfMathValue();
String op = mathAssignmentExpressionSymbol.getAssignmentOperator().getOperator();
String assignment;
MathExpressionSymbol assignmentSymbol = mathAssignmentExpressionSymbol.getExpressionSymbol();
if (assignmentSymbol instanceof MathMatrixNameExpressionSymbol) {
MathMatrixNameExpressionSymbol matrixAssignmentSymbol = (MathMatrixNameExpressionSymbol) assignmentSymbol;
if (useZeroBasedIndexing(matrixAssignmentSymbol)) {
String matrixName = matrixAssignmentSymbol.getNameToAccess();
String matrixAccess = ExecuteMethodGenerator.getCorrectAccessString(matrixAssignmentSymbol.getNameToAccess(), matrixAssignmentSymbol.getMathMatrixAccessOperatorSymbol(), includeStrings);
assignment = String.format("%s%s", matrixName, matrixAccess);
} else {
assignment = ExecuteMethodGenerator.generateExecuteCode(assignmentSymbol, includeStrings);
}
} else {
assignment = ExecuteMethodGenerator.generateExecuteCode(assignmentSymbol, includeStrings);
}
String result = String.format("%s %s %s;\n", name, op, assignment);
Log.info("result3: " + result, "MathAssignmentExpressionSymbol");
return result;
}
private static boolean useZeroBasedIndexing(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol) {
boolean isZeroBased = false;
if (MathConverter.curBackend.usesZeroBasedIndexing()) {
if (!isFunctionCall(mathMatrixNameExpressionSymbol)) {
isZeroBased = true;
}
}
return isZeroBased;
}
private static boolean isFunctionCall(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol) {
boolean isFunctionCall = false;
if (MathCommandRegisterCPP.containsCommandExpression(mathMatrixNameExpressionSymbol, mathMatrixNameExpressionSymbol.getTextualRepresentation())) {
isFunctionCall = true;
}
return isFunctionCall;
}
public static String generateExecuteCode(MathForLoopExpressionSymbol mathForLoopExpressionSymbol, List<String> includeStrings) {
String result = "";
......
......@@ -88,4 +88,18 @@ public class BasicMathGenerationArmadilloTest extends AbstractSymtabTest {
String restPath = "armadillo/testMath/l0/";
testFilesAreEqual(files, restPath);
}
@Test
public void armadilloIndexTest() throws IOException {
TaggingResolver symtab = createSymTabAndTaggingResolver("src/test/resources");
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("test.math.armadilloIndexTest", ExpandedComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(componentSymbol);
GeneratorCPP generatorCPP = new GeneratorCPP();
generatorCPP.useArmadilloBackend();
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/armadillo/testMath/l0");
List<File> files = generatorCPP.generateFiles(symtab, componentSymbol, symtab);
String restPath = "armadillo/testMath/l0/";
testFilesAreEqual(files, restPath);
}
}
#ifndef TEST_MATH_ARMADILLOINDEXTEST
#define TEST_MATH_ARMADILLOINDEXTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
using namespace arma;
class test_math_armadilloIndexTest{
public:
mat in1;
mat out1;
mat CONSTANTCONSTANTVECTOR0;
rowvec CONSTANTCONSTANTVECTOR1;
void init()
{
in1=mat(2,2);
out1=mat(4,4);
CONSTANTCONSTANTVECTOR0 = mat(2,2);
CONSTANTCONSTANTVECTOR0(0,0) = 1;
CONSTANTCONSTANTVECTOR0(0,1) = 2;
CONSTANTCONSTANTVECTOR0(1,0) = 3;
CONSTANTCONSTANTVECTOR0(1,1) = 4;
CONSTANTCONSTANTVECTOR1 = rowvec(2);
CONSTANTCONSTANTVECTOR1(0,0) = 11;
CONSTANTCONSTANTVECTOR1(0,1) = 12;
}
void execute()
{
mat A=mat(2,2);
A = (ones<mat>(2, 2));
A = CONSTANTCONSTANTVECTOR0;
colvec b=colvec(2);
b = (zeros<colvec>(2));
b = CONSTANTCONSTANTVECTOR1;
out1 = (zeros<mat>(4, 4));
A(2-1, 2-1) = 5;
A(2-1) = b(2-1);
A(2-1, 1-1) = out1(1-1, 1-1);
b(2-1) = 13;
b(1-1) = A(1-1);
b(1-1) = A(2-1, 2-1);
b(2-1) = out1(4-1, 4-1);
double x = 0;
x = A(1-1, 2-1) ;
x = b(2-1) ;
x = in1(1-1, 1-1) ;
out1(1-1, 1-1) = in1(1-1, 1-1);
mat test = (zeros<mat>(4, 4));
out1 = test;
out1(1-1) = test(1-1);
out1(2-1, 2-1) = A(2-1, 2-1);
}
};
#endif
package test.math;
component ArmadilloIndexTest{
ports in Q^{2, 2} in1,
out Q^{4, 4} out1;
implementation Math{
// test matrix initialization
Q^{2, 2} A;
A = ones(2,2);
A = [1 2; 3 4];
// test vector initialization
Q^{2} b;
b = zeros(2);
b = [11 12];
// test port initialization
out1 = zeros(4, 4);
// test matrix element assignment
A(2, 2) = 5;
A(2) = b(2);
A(1, 1) = b(2);
A(2, 1) = out1(1, 1);
// test vector element assignment
b(2) = 13;
b(1) = A(1);
b(1) = A(2, 2);
b(2) = out1(4, 4);
// test scalar element assignment
Q x = 0;
x = A(1, 2);
x = b(2);
x = in1(1, 1);
// test port element assignment
out1(1, 1) = in1(1, 1);
Q^{4, 4} test = zeros(4,4);
out1 = test;
out1(1) = test(1);
out1(2, 2) = A(2, 2);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment