Commit c5486ebe authored by Christoph Richter's avatar Christoph Richter
Browse files

Fixed wrong constant declaration for variable matrices (fixes #18)

parent 641211ac
......@@ -2,11 +2,9 @@ package de.monticore.lang.monticar.generator.cpp.converter;
import de.monticore.lang.math._symboltable.expression.IArithmeticExpression;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math._symboltable.expression.MathNameExpressionSymbol;
import de.monticore.lang.math._symboltable.matrix.*;
import de.monticore.lang.monticar.generator.cpp.MathCommandRegisterCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.StringValueListExtractorUtil;
import de.monticore.lang.monticar.generator.cpp.*;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
......@@ -244,7 +242,68 @@ public class ExecuteMethodGeneratorMatrixExpressionHandler {
}
public static String generateExecuteCode(MathMatrixArithmeticValueSymbol mathMatrixArithmeticValueSymbol, List<String> includeStrings) {
return MathConverter.getConstantConversion(mathMatrixArithmeticValueSymbol);
String result;
if (matrixValueContainsVariables(mathMatrixArithmeticValueSymbol)) {
result = generateVariableMatrixCode(mathMatrixArithmeticValueSymbol, includeStrings);
} else {
// if it does not contain any variables
result = MathConverter.getConstantConversion(mathMatrixArithmeticValueSymbol);
}
return result;
}
private static String generateVariableMatrixCode(MathMatrixArithmeticValueSymbol mathMatrixArithmeticValueSymbol, List<String> includeStrings) {
String result = "";
if (MathConverter.curBackend instanceof ArmadilloBackend) {
if ((mathMatrixArithmeticValueSymbol.getVectors().size() == 1) || (mathMatrixArithmeticValueSymbol.getVectors().get(0).getMathMatrixAccessSymbols().size() == 1))
result = generateCodeForVecDependentOnVar(mathMatrixArithmeticValueSymbol, includeStrings);
else {
StringBuilder sb = new StringBuilder();
for (MathMatrixAccessOperatorSymbol vec : mathMatrixArithmeticValueSymbol.getVectors()) {
sb.append("{");
for (MathMatrixAccessSymbol elem : vec.getMathMatrixAccessSymbols()) {
sb.append(ExecuteMethodGenerator.generateExecuteCode(elem, includeStrings));
sb.append(",");
}
sb.deleteCharAt(sb.length() - 1);
sb.append("},");
}
sb.deleteCharAt(sb.length() - 1);
result = String.format("%s({%s})", MathConverter.curBackend.getMatrixTypeName(), sb.toString());
}
} else {
Log.error("Not supported backend!");
}
return result;
}
private static String generateCodeForVecDependentOnVar(MathMatrixArithmeticValueSymbol matVal, List<String> includeStrings) {
StringBuilder sb = new StringBuilder();
for (MathMatrixAccessOperatorSymbol vec : matVal.getVectors()) {
for (MathMatrixAccessSymbol elem : vec.getMathMatrixAccessSymbols()) {
sb.append(ExecuteMethodGenerator.generateExecuteCode(elem, includeStrings));
sb.append(",");
}
}
sb.deleteCharAt(sb.length() - 1);
String result = String.format("%s({%s})", MathConverter.curBackend.getRowVectorTypeName(), sb.toString());
if ((matVal.getVectors().size() > 1) && (matVal.getVectors().get(0).getMathMatrixAccessSymbols().size() == 1)) {
result = String.format("(%s.t())", result); // transpose to colvec
}
return result;
}
private static boolean matrixValueContainsVariables(MathMatrixArithmeticValueSymbol mathMatrixArithmeticValueSymbol) {
for (MathMatrixAccessOperatorSymbol vec : mathMatrixArithmeticValueSymbol.getVectors()) {
for (MathMatrixAccessSymbol elem : vec.getMathMatrixAccessSymbols()) {
if (elem.getMathExpressionSymbol().isPresent()) {
MathExpressionSymbol elemSymbol = elem.getMathExpressionSymbol().get();
if (elemSymbol instanceof MathNameExpressionSymbol || elemSymbol instanceof MathMatrixNameExpressionSymbol)
return true;
}
}
}
return false;
}
public static String generateExecuteCode(MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol, List<String> includeStrings) {
......
......@@ -251,4 +251,17 @@ public class BasicGenerationArmadilloTest extends AbstractSymtabTest {
String restPath = "armadillo/detectionObjectDetector" + number + "/l3/";
testFilesAreEqual(files, restPath);
}
@Test
public void testPortInMatrixDefinition() throws IOException {
TaggingResolver symtab = createSymTabAndTaggingResolver("src/test/resources");
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("test.portInMatrixDefinition", ExpandedComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(componentSymbol);
GeneratorCPP generatorCPP = new GeneratorCPP();
generatorCPP.useArmadilloBackend();
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/armadillo/portInMatrixDefinition");
List<File> files = generatorCPP.generateFiles(symtab, componentSymbol, symtab);
String restPath = "armadillo/portInMatrixDefinition/";
testFilesAreEqual(files, restPath);
}
}
#ifndef HELPERA_H
#define HELPERA_H
#include <iostream>
#include "armadillo.h"
#include <stdarg.h>
#include <initializer_list>
using namespace arma;
class HelperA{
public:
static mat getEigenVectors(mat A){
vec eigenValues;
mat eigenVectors;
eig_sym(eigenValues,eigenVectors,A);
return eigenVectors;
}
static vec getEigenValues(mat A){
vec eigenValues;
mat eigenVectors;
eig_sym(eigenValues,eigenVectors,A);
return eigenValues;
}
static mat getKMeansClusters(mat A, int k){
mat clusters;
kmeans(clusters,A.t(),k,random_subset,20,true);
/*printf("cluster centroid calculation done\n");
std::ofstream myfile;
myfile.open("data after cluster.txt");
myfile << A;
myfile.close();
std::ofstream myfile2;
myfile2.open("cluster centroids.txt");
myfile2 << clusters;
myfile2.close();*/
mat indexedData=getKMeansClustersIndexData(A.t(), clusters);
/*std::ofstream myfile3;
myfile3.open("data after index.txt");
myfile3 << indexedData;
myfile3.close();
*/
return indexedData;
}
static mat getKMeansClustersIndexData(mat A, mat centroids){
mat result=mat(A.n_cols, 1);
for(int i=0;i<A.n_cols;++i){
result(i, 0) = getIndexForClusterCentroids(A, i, centroids);
}
return result;
}
static int getIndexForClusterCentroids(mat A, int colIndex, mat centroids){
int index=1;
double lowestDistance=getEuclideanDistance(A, colIndex, centroids, 0);
for(int i=1;i<centroids.n_cols;++i){
double curDistance=getEuclideanDistance(A, colIndex, centroids, i);
if(curDistance<lowestDistance){
lowestDistance=curDistance;
index=i+1;
}
}
return index;
}
static double getEuclideanDistance(mat A, int colIndexA, mat B, int colIndexB){
double distance=0;
for(int i=0;i<A.n_rows;++i){
double elementA=A(i,colIndexA);
double elementB=B(i,colIndexB);
double diff=elementA-elementB;
distance+=diff*diff;
}
return sqrt(distance);
}
static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat getSqrtMatDiag(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = sqrt(curVal);
}
return A;
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
#ifndef TEST_PORTINMATRIXDEFINITION
#define TEST_PORTINMATRIXDEFINITION
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
#include "HelperA.h"
using namespace arma;
class test_portInMatrixDefinition{
public:
double in1;
double out1;
void init()
{
}
void execute()
{
colvec a = (rowvec({0,in1,1}).t());
a = a+1;
rowvec b = rowvec({a(1-1),a(2-1),a(3-1)});
mat c = mat({{a(1-1),a(2-1)},{b(1-1),b(2-1)}});
out1 = (accu(c));
}
};
#endif
package test;
component PortInMatrixDefinition{
ports in Q in1,
out Q out1;
implementation Math{
Q^{3} a = [0; in1; 1];
a = a + 1;
Q^{1, 3} b = [a(1), a(2), a(3)];
Q^{2, 2} c = [a(1), a(2); b(1), b(2)];
out1 = sum(c);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment