Commit c481f475 authored by Sascha Niklas Schneiders's avatar Sascha Niklas Schneiders
Browse files

worked on optimizations

parent 1f4b3dc9
......@@ -29,6 +29,7 @@ public class Variable {
Optional<String> constantValue = Optional.empty();
List<String> additionalInformation = new ArrayList<>();
List<String> dimensionalInformation = new ArrayList<>();
List<String> properties = new ArrayList<>();
Optional<String> customTypeName = Optional.empty();
public Variable() {
......@@ -206,6 +207,18 @@ public class Variable {
return parameterVariable;
}
public List<String> getProperties() {
return properties;
}
public void addProperties(List<String> properties) {
this.properties.addAll(properties);
}
public void addProperty(String property) {
this.properties.add(property);
}
public void setIsParameterVariable(boolean parameterVariable) {
this.parameterVariable = parameterVariable;
}
......
......@@ -86,6 +86,13 @@ public class ArmadilloHelperSource {
"cx_mat result=sqrtmat(A);\n" +
"return real(result);\n" +
"}\n" +
"static mat invertDiagMatrix(mat A){\n" +
"for(int i=0;i<A.n_rows;++i){\n" +
" double curVal = A(i,i);\n" +
" A(i,i) = 1/curVal;\n" +
"}\n" +
"return A;\n" +
"}\n" +
"};\n" +
"#endif\n";
}
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathInvDiagCommand extends MathCommand {
public MathInvDiagCommand() {
setMathCommandName("invdiag");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
//convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
Log.error("OctaveBackend does not support command invdiag yet");
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Finv", "Matrix", valueListString, "FirstResult", false, 1));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression("HelperA::invertDiagMatrix"+valueListString);
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("HelperA");
}
}
......@@ -75,7 +75,9 @@ public class ComponentConverterMethodGeneration {
List<MathExpressionSymbol> visitedMathExpressionSymbol = new ArrayList<>();
int lastIndex = 0;
boolean swapNextInstructions = false;
for (currentGenerationIndex = 0; currentGenerationIndex < mathStatementsSymbol.getMathExpressionSymbols().size(); ++currentGenerationIndex) {
int beginIndex = currentGenerationIndex;
MathExpressionSymbol mathExpressionSymbol = mathStatementsSymbol.getMathExpressionSymbols().get(currentGenerationIndex);
if (!visitedMathExpressionSymbol.contains(mathExpressionSymbol)) {
if (generatorCPP.useAlgebraicOptimizations()) {
......@@ -90,16 +92,18 @@ public class ComponentConverterMethodGeneration {
String result = ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, includeStrings);
TargetCodeMathInstruction instruction = new TargetCodeMathInstruction(result, mathExpressionSymbol);
method.addInstruction(instruction);
if (lastIndex == currentGenerationIndex - 1) {
//Log.error("ad");
Instruction lastInstruction = method.getInstructions().get(currentGenerationIndex);
method.getInstructions().remove(currentGenerationIndex);
method.addInstruction(lastInstruction);
}
visitedMathExpressionSymbol.add(mathExpressionSymbol);
System.out.println("lastIndex: "+lastIndex+" current: "+currentGenerationIndex);
System.out.println("lastIndex: " + lastIndex + " current: " + currentGenerationIndex);
lastIndex = currentGenerationIndex;
}
if (swapNextInstructions) {
swapNextInstructions = false;
//Log.error("ad");
Instruction lastInstruction = method.getInstructions().get(currentGenerationIndex);
method.getInstructions().remove(currentGenerationIndex);
method.addInstruction(lastInstruction);
}
if (beginIndex != currentGenerationIndex) swapNextInstructions = true;
}
}
......
......@@ -3,8 +3,11 @@ package de.monticore.lang.monticar.generator.cpp.converter;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ConnectorSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ConstantPortSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.PortSymbol;
import de.monticore.lang.math.math._ast.ASTAssignmentType;
import de.monticore.lang.monticar.generator.Variable;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.ts.references.MCASTTypeSymbolReference;
import de.monticore.lang.monticar.types2._ast.ASTType;
/**
* @author Sascha Schneiders
......@@ -48,7 +51,13 @@ public class PortConverter {
String typeNameMontiCar = portSymbol.getTypeReference().getName();
if (portSymbol.getTypeReference().getReferencedSymbol() instanceof MCASTTypeSymbolReference) {
MCASTTypeSymbolReference typeSymbolReference = (MCASTTypeSymbolReference) portSymbol.getTypeReference().getReferencedSymbol();
ASTType astType = typeSymbolReference.getAstType();
ASTAssignmentType astAssignmentType = (ASTAssignmentType) astType;
//if (astAssignmentType.getMatrixProperty().size() > 0) Log.error(astType.toString());
variable.addProperties(astAssignmentType.getMatrixProperty());
}
if (portSymbol.isIncoming())
{
......@@ -84,7 +93,6 @@ public class PortConverter {
}
public static String getPortNameWithoutArrayBracketPart(String name) {
String nameWithOutArrayBracketPart = name;
if (nameWithOutArrayBracketPart.endsWith("]")) {
......
......@@ -22,7 +22,7 @@ public class MathAssignmentPartResultReuse implements MathOptimizationRule {
MathStatementsSymbol currentMathStatementsSymbol = null;
List<MathExpressionSymbol> encounteredSymbolInstances = new ArrayList<>();
Map<MathExpressionSymbol, String> symbolMap = new HashMap();
int currentId;
int currentId = 0;
MathExpressionSymbol startMathExpressionSymbol = null;
@Override
......@@ -45,7 +45,6 @@ public class MathAssignmentPartResultReuse implements MathOptimizationRule {
currentMathStatementsSymbol = mathStatementsSymbol;
encounteredSymbolInstances.clear();
symbolMap.clear();
currentId = 0;
startMathExpressionSymbol = mathExpressionSymbol;
optimize(mathExpressionSymbol, precedingExpressions);
}
......@@ -72,7 +71,7 @@ public class MathAssignmentPartResultReuse implements MathOptimizationRule {
System.out.println("Found Same Symbol");
String name = "";
if (!symbolMap.containsKey(mathExpressionSymbol)) {
symbolMap.put(mathExpressionSymbol, name = getReplacementName(currentId));
symbolMap.put(mathExpressionSymbol, name = getReplacementName(currentId++));
} else {
name = symbolMap.get(mathExpressionSymbol);
}
......
package de.monticore.lang.monticar.generator.optimization;
import de.monticore.lang.math.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.math.math._symboltable.expression.*;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessOperatorSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.se_rwth.commons.logging.Log;
import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathDiagonalMatrixInversionOptimization implements MathOptimizationRule {
MathStatementsSymbol currentMathStatementsSymbol = null;
@Override
public void optimize(MathExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
if (mathExpressionSymbol == null) {
} else if (mathExpressionSymbol.isAssignmentExpression()) {
optimize((MathAssignmentExpressionSymbol) mathExpressionSymbol, precedingExpressions);
} else if (mathExpressionSymbol.isMatrixExpression()) {
optimize((MathMatrixExpressionSymbol) mathExpressionSymbol, precedingExpressions);
} else if (mathExpressionSymbol.isArithmeticExpression()) {
optimize((MathArithmeticExpressionSymbol) mathExpressionSymbol, precedingExpressions);
} else {
Log.debug("Not handled: " + mathExpressionSymbol.getClass().getName() + " " + mathExpressionSymbol.getTextualRepresentation(),
"optimizeMathExpressionSymbol");
}
}
@Override
public void optimize(MathExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions, MathStatementsSymbol mathStatementsSymbol) {
currentMathStatementsSymbol = mathStatementsSymbol;
optimize(mathExpressionSymbol, precedingExpressions);
}
public void optimize(MathAssignmentExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
optimize(mathExpressionSymbol.getExpressionSymbol(), precedingExpressions);
}
public void optimize(MathArithmeticExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
optimize(mathExpressionSymbol.getLeftExpression(), precedingExpressions);
optimize(mathExpressionSymbol.getRightExpression(), precedingExpressions);
}
public void optimize(MathMatrixExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
if (mathExpressionSymbol.isMatrixNameExpression()) {
optimize((MathMatrixNameExpressionSymbol) mathExpressionSymbol, precedingExpressions);
} else if (mathExpressionSymbol.isMatrixAccessExpression()) {
optimize((MathMatrixAccessSymbol) mathExpressionSymbol, precedingExpressions);
} else {
Log.debug("Not handled: " + mathExpressionSymbol.getClass().getName() + " " + mathExpressionSymbol.
getTextualRepresentation(), "optimizeMathMatrixExpr");
}
}
public void optimize(MathMatrixAccessSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
if (mathExpressionSymbol.getMathExpressionSymbol().isPresent()) {
optimize(mathExpressionSymbol.getMathExpressionSymbol().get(), precedingExpressions);
} else {
Log.debug("Not handled further", "optimizeMathMatrixAccess");
}
}
public void optimize(MathMatrixNameExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
if (mathExpressionSymbol.getNameToAccess().equals("inv")) {
//ComponentConverter.currentBluePrint.getMathInformationRegister().isDiagonalMatrix()
boolean invertsDiagonalMatrix = false;
if (mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().isPresent()) {
//optimize(mathExpressionSymbol.getMathMatrixAccessOperatorSymbol(), precedingExpressions);
//Log.error(ComponentConverter.currentBluePrint.getMathInformationRegister().getVariable("degree").getProperties().toString());
String name = getMatrixName((MathMatrixAccessSymbol) mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().get().getMathMatrixAccesss().get(0).getSymbol().get());//TODO handle all possible cases
invertsDiagonalMatrix = ComponentConverter.currentBluePrint.getMathInformationRegister().getVariable(name).getProperties().contains("diag");
}
if (invertsDiagonalMatrix)
mathExpressionSymbol.setNameToAccess("invdiag");
Log.debug("Found inv and replaced with invdiag", "optimizeMathMatrixNameExp");
}
if (mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().isPresent()) {
optimize(mathExpressionSymbol.getMathMatrixAccessOperatorSymbol(), precedingExpressions);
} else if (mathExpressionSymbol.getAstMathMatrixNameExpression().getEndOperator().isPresent())
Log.debug("Not handled: EndOperator", "optimizeMathMatrixNameExpr");
}
public void optimize(MathMatrixAccessOperatorSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
for (MathExpressionSymbol subExpr : mathExpressionSymbol.getMathMatrixAccessSymbols()) {
optimize(subExpr, precedingExpressions);
}
}
public static String getMatrixName(MathMatrixAccessSymbol mathExpressionSymbol) {
assert mathExpressionSymbol.getMathExpressionSymbol().isPresent();
MathExpressionSymbol curMathExp = mathExpressionSymbol.getMathExpressionSymbol().get();
if (curMathExp.isValueExpression()) {
if (((MathValueExpressionSymbol) curMathExp).isNameExpression()) {
return ((MathNameExpressionSymbol) curMathExp).getNameToResolveValue();
} else {
Log.debug("Not handled getMatrixName", "MissingImplementation");
}
} else {
Log.debug("Not handled getMatrixName", "MissingImplementation");
}
return "";
}
}
......@@ -148,6 +148,7 @@ public class MathInformationRegister {
for (MathExpressionSymbol dimension : mathValueSymbol.getType().getDimensions())
var.addDimensionalInformation(dimension.getTextualRepresentation());
this.variables.add(var);
var.addProperties(mathValueSymbol.getType().getProperties());
}
public Variable getVariable(String name) {
......
......@@ -231,6 +231,7 @@ public class MathOptimizer {
static {
addOptimizationRule(new MathMultiplicationAddition());
addOptimizationRule(new MathMatrixMultiplicationOrder());
addOptimizationRule(new MathDiagonalMatrixInversionOptimization());
addOptimizationRule(new MathAssignmentPartResultReuse());
}
......
......@@ -309,7 +309,7 @@ public class MathOptimizerTest extends AbstractSymtabTest {
}
@Test
public void testMathAssignmentOptimization1() throws IOException{
public void testMathAssignmentOptimization1Octave() throws IOException{
TaggingResolver symtab = createSymTabAndTaggingResolver("src/test/resources");
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("detection.normalizedLaplacianInstance", ExpandedComponentInstanceSymbol.KIND).orElse(null);
......@@ -319,4 +319,17 @@ public class MathOptimizerTest extends AbstractSymtabTest {
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/optimizer/l1");
generatorCPP.generateFiles(componentSymbol, symtab);
}
@Test
public void testMathAssignmentOptimization1Armadillo() throws IOException{
TaggingResolver symtab = createSymTabAndTaggingResolver("src/test/resources");
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("detection.normalizedLaplacianInstance", ExpandedComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(componentSymbol);
GeneratorCPP generatorCPP = new GeneratorCPP();
generatorCPP.useArmadilloBackend();
generatorCPP.setUseAlgebraicOptimizations(true);
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/armadillo/optimizer/l1");
generatorCPP.generateFiles(componentSymbol, symtab);
}
}
......@@ -2,7 +2,7 @@ package detection;
component NormalizedLaplacian<N1 n = 1>{
ports in Q(-oo:oo)^{n,n} degree,
ports in diag Q(-oo:oo)^{n,n} degree,
in Q(-oo:oo)^{n,n} W,
out Q(-oo:oo)^{n,n} nLaplacian;
......@@ -13,6 +13,7 @@ component NormalizedLaplacian<N1 n = 1>{
end
end*/
nLaplacian = sqrtm(inv(degree)) * W * sqrtm(inv(degree));
nLaplacian = sqrtm(inv(degree)*2) * W * sqrtm(inv(degree)*2 );
}
}
\ No newline at end of file
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
......@@ -80,5 +80,12 @@ static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = 1/curVal;
}
return A;
}
};
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment