Commit f8903b14 authored by Christoph Richter's avatar Christoph Richter
Browse files

fixup! Fixed element-wise multiplication generation armadillo (gitlab #13)

parent b83d0854
......@@ -37,6 +37,7 @@ public interface MathBackend {
String getDivisionEEString(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, String valueListString);
String getMultiplicationEEString(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, String valueListString);
/**
* Does the backend use 0-based or 1-based indexing for matrix element access?
*
......
......@@ -104,6 +104,12 @@ public class ArmadilloBackend implements MathBackend {
ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol.getRightExpression(), new ArrayList<>());
}
@Override
public String getMultiplicationEEString(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, String valueListString) {
return ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol.getLeftExpression(), new ArrayList<>()) + " % " +
ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol.getRightExpression(), new ArrayList<>());
}
@Override
public boolean usesZeroBasedIndexing() {
return true;
......
......@@ -89,6 +89,12 @@ public class LinalgBackend implements MathBackend {
return null;
}
@Override
public String getMultiplicationEEString(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, String valueListString) {
Log.debug("Not supported yet","Not Implemented");
return null;
}
@Override
public boolean usesZeroBasedIndexing() {
// TODO: check this! Do not know this backend...
......
......@@ -185,13 +185,8 @@ public class MathFunctionFixer extends BaseMathFunctionFixerHandler {
public static void fixMathFunctions(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, BluePrintCPP bluePrintCPP) {
fixMathFunctions(mathExpressionSymbol.getLeftExpression(), bluePrintCPP);
if (mathExpressionSymbol.getRightExpression() != null) {
if (mathExpressionSymbol.getRightExpression() != null)
fixMathFunctions(mathExpressionSymbol.getRightExpression(), bluePrintCPP);
// fix element wise multiplication
if (mathExpressionSymbol.getOperator().contentEquals(".*")) {
mathExpressionSymbol.setOperator("%");
}
}
}
public static void fixMathFunctions(MathMatrixNameExpressionSymbol mathExpressionSymbol, BluePrintCPP bluePrintCPP) {
......
......@@ -92,6 +92,12 @@ public class OctaveBackend implements MathBackend {
return OctaveHelper.getCallOctaveFunctionFirstResult(mathExpressionSymbol.getLeftExpression(), "ldivide", valueListString, false);
}
@Override
public String getMultiplicationEEString(MathMatrixArithmeticExpressionSymbol mathExpressionSymbol, String valueListString) {
Log.warn("Backend deprecated");
return OctaveHelper.getCallOctaveFunctionFirstResult(mathExpressionSymbol.getLeftExpression(), ".*", valueListString, false);
}
@Override
public boolean usesZeroBasedIndexing() {
return false;
......
......@@ -24,6 +24,8 @@ public class ExecuteMethodGeneratorMatrixExpressionHandler {
result = generateExecuteCodeMatrixEEPowerOf(mathMatrixArithmeticExpressionSymbol, includeStrings);
} else if (mathMatrixArithmeticExpressionSymbol.getMathOperator().equals("./")) {
result = generateExecuteCodeMatrixEEDivide(mathMatrixArithmeticExpressionSymbol, includeStrings);
} else if (mathMatrixArithmeticExpressionSymbol.getMathOperator().equals(".*")) {
result = generateExecuteCodeMatrixEEMult(mathMatrixArithmeticExpressionSymbol, includeStrings);
/*} else if (mathArithmeticExpressionSymbol.getMathOperator().equals("./")) {
Log.error("reace");
result += "\"ldivide\"";
......@@ -44,6 +46,11 @@ public class ExecuteMethodGeneratorMatrixExpressionHandler {
return result;
}
private static String generateExecuteCodeMatrixEEMult(MathMatrixArithmeticExpressionSymbol mathMatrixArithmeticExpressionSymbol, List<String> includeStrings) {
String valueListString = calculateValueListString(mathMatrixArithmeticExpressionSymbol);
return MathConverter.curBackend.getMultiplicationEEString(mathMatrixArithmeticExpressionSymbol, valueListString);
}
public static String calculateValueListString(IArithmeticExpression mathExpressionSymbol) {
List<MathExpressionSymbol> list = new ArrayList<MathExpressionSymbol>();
list.add(mathExpressionSymbol.getLeftExpression());
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment