Commit 31f126ba authored by Christoph Richter's avatar Christoph Richter
Browse files

MathCommands: Fixed ones and zeros command for cubes and vectors (fixes #58)

parent e6cd1f21
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessOperatorSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
/**
* Abstract base class for ones and zeros initialization function
*/
public abstract class MathInitCommand extends MathCommand {
protected abstract String getArmadilloInitCommandName();
protected abstract String getOctaveInitCommandName();
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
protected static String getMatrixTypeFromAccessOperatorSize(MathMatrixAccessOperatorSymbol mathExpressionSymbol) {
String result = "";
int size = mathExpressionSymbol.getMathMatrixAccessSymbols().size();
if (size == 1) {
result = MathConverter.curBackend.getColumnVectorTypeName();
} else if (size == 2) {
result = MathConverter.curBackend.getMatrixTypeName();
} else if (size == 3) {
result = MathConverter.curBackend.getCubeTypeName();
} else {
Log.error(String.format("Unknown matrix type for matrix access: %s", mathExpressionSymbol.getTextualRepresentation()));
}
return result;
}
private void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, getOctaveInitCommandName(), "Matrix", valueListString, "FirstResult", false, 1));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
}
protected void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
String type = getMatrixTypeFromAccessOperatorSize(mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol());
MathStringExpression stringExpression = new MathStringExpression(String.format("%s<%s>%s", getArmadilloInitCommandName(), type, valueListString));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
}
}
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import java.util.ArrayList;
import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathOnesCommand extends MathCommand {
public class MathOnesCommand extends MathInitCommand {
public MathOnesCommand() {
setMathCommandName("ones");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
protected String getArmadilloInitCommandName() {
return "ones";
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fones", "Matrix", valueListString, "FirstResult", false, 1));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
}
private void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression("ones<mat>" + valueListString);
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
//((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
@Override
protected String getOctaveInitCommandName() {
return "Fones";
}
}
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import java.util.ArrayList;
import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathZerosCommand extends MathCommand {
public class MathZerosCommand extends MathInitCommand {
public MathZerosCommand() {
setMathCommandName("zeros");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
private void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fzeros", "Matrix", valueListString, "FirstResult", false, 1));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
protected String getArmadilloInitCommandName() {
return "zeros";
}
private void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression("zeros<mat>" + valueListString);
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
//((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
@Override
protected String getOctaveInitCommandName() {
return "Fzeros";
}
}
......@@ -68,6 +68,16 @@ public class ArmadilloFunctionTest extends AbstractSymtabTest {
testMathCommand("zeros");
}
@Test
public void testOnesArmadilloOnlyCommand() throws IOException {
testMathCommand("onesArmadilloOnly");
}
@Test
public void testZerosArmadilloOnlyCommand() throws IOException {
testMathCommand("zerosArmadilloOnly");
}
@Test
public void testAbsCommand() throws IOException {
testMathCommand("abs");
......
#ifndef TEST_MATH_ONESARMADILLOONLYCOMMANDTEST
#define TEST_MATH_ONESARMADILLOONLYCOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
using namespace arma;
class test_math_onesArmadilloOnlyCommandTest{
public:
void init()
{
}
void execute()
{
colvec a = (ones<colvec>(3));
mat b = (ones<mat>(3, 3));
cube c = (ones<cube>(3, 3, 3));
}
};
#endif
#ifndef TEST_MATH_ZEROSARMADILLOONLYCOMMANDTEST
#define TEST_MATH_ZEROSARMADILLOONLYCOMMANDTEST
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
using namespace arma;
class test_math_zerosArmadilloOnlyCommandTest{
public:
void init()
{
}
void execute()
{
colvec a = (zeros<colvec>(3));
mat b = (zeros<mat>(3, 3));
cube c = (zeros<cube>(3, 3, 3));
}
};
#endif
package test.math;
component OnesArmadilloOnlyCommandTest{
implementation Math{
Q^{3} a = ones(3);
Q^{3,3} b = ones(3,3);
Q^{3,3, 3} c = ones(3,3,3);
}
}
\ No newline at end of file
package test.math;
component ZerosArmadilloOnlyCommandTest{
implementation Math{
Q^{3} a = zeros(3);
Q^{3,3} b = zeros(3,3);
Q^{3,3, 3} c = zeros(3,3,3);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment