Commit f8d4da88 authored by Sascha Niklas Schneiders's avatar Sascha Niklas Schneiders
Browse files

added diagonal square matrix optimization

parent eaa3e350
......@@ -83,9 +83,22 @@ public class ArmadilloHelperSource {
"}\n" +
"\n" +
"static mat getSqrtMat(mat A){\n" +
"cx_mat result=sqrtmat(A);\n" +
"return real(result);\n" +
"\n" +
"for(int i=0;i<A.n_rows;++i){\n" +
" double curVal = A(i,i);\n" +
" A(i,i) = sqrt(curVal);\n" +
"}\n" +
"return A;\n" +
"}\n" +
"\n" +
"static mat getSqrtMatDiag(mat A){\n" +
"for(int i=0;i<A.n_rows;++i){\n" +
" double curVal = A(i,i);\n" +
" A(i,i) = sqrt(curVal);\n" +
"}\n" +
"return A;\n" +
"}\n" +
"\n" +
"static mat invertDiagMatrix(mat A){\n" +
"for(int i=0;i<A.n_rows;++i){\n" +
" double curVal = A(i,i);\n" +
......
......@@ -48,5 +48,6 @@ public class MathCommandRegisterCPP extends MathCommandRegister {
registerMathCommand(new MathDetCommand());
registerMathCommand(new MathKMeansCommand());
registerMathCommand(new MathSqrtmCommand());
registerMathCommand(new MathSqrtmDiagCommand());
}
}
package de.monticore.lang.monticar.generator.cpp.commands;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.BluePrint;
import de.monticore.lang.monticar.generator.MathCommand;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.OctaveHelper;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.monticore.lang.monticar.generator.cpp.converter.ExecuteMethodGenerator;
import de.monticore.lang.monticar.generator.cpp.MathFunctionFixer;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
import de.monticore.lang.monticar.generator.cpp.symbols.MathStringExpression;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathSqrtmDiagCommand extends MathCommand {
public MathSqrtmDiagCommand() {
setMathCommandName("sqrtmdiag");
}
@Override
public void convert(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
String backendName = MathConverter.curBackend.getBackendName();
if (backendName.equals("OctaveBackend")) {
//convertUsingOctaveBackend(mathExpressionSymbol, bluePrint);
Log.error("OctaveBackend does not support command sqrtdiag yet");
} else if (backendName.equals("ArmadilloBackend")) {
convertUsingArmadilloBackend(mathExpressionSymbol, bluePrint);
}
}
public void convertUsingOctaveBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression(OctaveHelper.getCallBuiltInFunction(mathExpressionSymbol, "Fsqrt", "Double", valueListString, "FirstResult", false, 1));
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("octave/builtin-defun-decls");
}
public void convertUsingArmadilloBackend(MathExpressionSymbol mathExpressionSymbol, BluePrint bluePrint) {
MathMatrixNameExpressionSymbol mathMatrixNameExpressionSymbol = (MathMatrixNameExpressionSymbol) mathExpressionSymbol;
mathMatrixNameExpressionSymbol.setNameToAccess("");
String valueListString = "";
for (MathMatrixAccessSymbol accessSymbol : mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols())
MathFunctionFixer.fixMathFunctions(accessSymbol, (BluePrintCPP) bluePrint);
valueListString += ExecuteMethodGenerator.generateExecuteCode(mathExpressionSymbol, new ArrayList<String>());
//OctaveHelper.getCallOctaveFunction(mathExpressionSymbol, "sum","Double", valueListString));
List<MathMatrixAccessSymbol> newMatrixAccessSymbols = new ArrayList<>();
MathStringExpression stringExpression = new MathStringExpression("HelperA::getSqrtMatDiag" + valueListString);
newMatrixAccessSymbols.add(new MathMatrixAccessSymbol(stringExpression));
mathMatrixNameExpressionSymbol.getMathMatrixAccessOperatorSymbol().setMathMatrixAccessSymbols(newMatrixAccessSymbols);
((BluePrintCPP) bluePrint).addAdditionalIncludeString("HelperA");
}
}
\ No newline at end of file
......@@ -6,7 +6,6 @@ import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessOperatorS
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixAccessSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixExpressionSymbol;
import de.monticore.lang.math.math._symboltable.matrix.MathMatrixNameExpressionSymbol;
import de.monticore.lang.monticar.generator.cpp.BluePrintCPP;
import de.monticore.lang.monticar.generator.cpp.converter.ComponentConverter;
import de.se_rwth.commons.logging.Log;
......@@ -15,7 +14,7 @@ import java.util.List;
/**
* @author Sascha Schneiders
*/
public class MathDiagonalMatrixInversionOptimization implements MathOptimizationRule {
public class MathDiagonalMatrixOptimizations implements MathOptimizationRule {
MathStatementsSymbol currentMathStatementsSymbol = null;
@Override
......@@ -72,11 +71,18 @@ public class MathDiagonalMatrixInversionOptimization implements MathOptimization
public void optimize(MathMatrixNameExpressionSymbol mathExpressionSymbol, List<MathExpressionSymbol> precedingExpressions) {
if (mathExpressionSymbol.getNameToAccess().equals("inv")) {
//ComponentConverter.currentBluePrint.getMathInformationRegister().isDiagonalMatrix()
boolean invertsDiagonalMatrix = invertsDiagonalMatrix(mathExpressionSymbol);
boolean invertsDiagonalMatrix = isDiagonalMatrix(mathExpressionSymbol);
if (invertsDiagonalMatrix) {
mathExpressionSymbol.setNameToAccess("invdiag");
}
Log.debug("Found inv and replaced with invdiag", "optimizeMathMatrixNameExp");
} else if (mathExpressionSymbol.getNameToAccess().equals("sqrtm")) {
boolean isDiagMatrix = isDiagonalMatrix(mathExpressionSymbol);
if (isDiagMatrix) {
mathExpressionSymbol.setNameToAccess("sqrtmdiag");
}
Log.debug("Found sqrtm and replaced with sqrtdiag", "optimizeMathMatrixNameExp");
}
if (mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().isPresent()) {
optimize(mathExpressionSymbol.getMathMatrixAccessOperatorSymbol(), precedingExpressions);
......@@ -84,12 +90,14 @@ public class MathDiagonalMatrixInversionOptimization implements MathOptimization
Log.debug("Not handled: EndOperator", "optimizeMathMatrixNameExpr");
}
private boolean invertsDiagonalMatrix(MathMatrixNameExpressionSymbol mathExpressionSymbol) {
private boolean isDiagonalMatrix(MathMatrixNameExpressionSymbol mathExpressionSymbol) {
boolean invertsDiagonalMatrix = false;
if (mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().isPresent()) {
//optimize(mathExpressionSymbol.getMathMatrixAccessOperatorSymbol(), precedingExpressions);
//Log.error(ComponentConverter.currentBluePrint.getMathInformationRegister().getVariable("degree").getProperties().toString());
String name = getMatrixName((MathMatrixAccessSymbol) mathExpressionSymbol.getAstMathMatrixNameExpression().getMathMatrixAccessExpression().get().getMathMatrixAccesss().get(0).getSymbol().get());//TODO handle all possible cases
//System.out.println("isDiagonalMatrix: " + name);
//System.out.println(mathExpressionSymbol.getTextualRepresentation());
invertsDiagonalMatrix = ComponentConverter.currentBluePrint.getMathInformationRegister().getVariable(name).getProperties().contains("diag");
}
return invertsDiagonalMatrix;
......@@ -110,6 +118,12 @@ public class MathDiagonalMatrixInversionOptimization implements MathOptimization
} else {
Log.debug("Not handled getMatrixName", "MissingImplementation");
}
} else if (curMathExp.isMatrixExpression()) {
if (((MathMatrixExpressionSymbol) curMathExp).isMatrixNameExpression()) {
return getMatrixName(((MathMatrixNameExpressionSymbol) curMathExp).getMathMatrixAccessOperatorSymbol().getMathMatrixAccessSymbols().get(0));
} else {
Log.debug("Not handled getMatrixName", "MissingImplementation");
}
} else {
Log.debug("Not handled getMatrixName", "MissingImplementation");
}
......
......@@ -234,7 +234,7 @@ public class MathOptimizer {
static {
addOptimizationRule(new MathMultiplicationAddition());
addOptimizationRule(new MathMatrixMultiplicationOrder());
addOptimizationRule(new MathDiagonalMatrixInversionOptimization());
addOptimizationRule(new MathDiagonalMatrixOptimizations());
addOptimizationRule(new MathAssignmentPartResultReuse());
}
......
......@@ -187,8 +187,8 @@ public class BasicGenerationArmadilloTest extends AbstractSymtabTest {
ExpandedComponentInstanceSymbol componentSymbol = symtab.<ExpandedComponentInstanceSymbol>resolve("detection.objectDetector" + number, ExpandedComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(componentSymbol);
GeneratorCPP generatorCPP = new GeneratorCPP();
generatorCPP.setUseThreadingOptimization(true);
generatorCPP.setUseAlgebraicOptimizations(true);
generatorCPP.setUseThreadingOptimization(true);
generatorCPP.useArmadilloBackend();
generatorCPP.setGenerationTargetPath("./target/generated-sources-cpp/armadillo/detectionObjectDetector" + number + "/l3");
List<File> files = generatorCPP.generateFiles(symtab, componentSymbol, symtab);
......
#ifndef HELPER_H
#define HELPER_H
#define _GLIBCXX_USE_CXX11_ABI 0
#include <iostream>
#include <octave/oct.h>
#include <octave/octave.h>
#include <octave/parse.h>
#include <octave/interpreter.h>
#include <stdarg.h>
#include <initializer_list>
class Helper
{
public:
static void init()
{
string_vector argv(2);
argv(0) = "embedded";
argv(1) = "-q";
octave_main(2, argv.c_str_vec(), 1);
//octave_debug=1;
//feval ("pkg", ovl ("load", "all"), 0);
}
static octave_value_list convertToOctaveValueList(double a)
{
octave_value_list in;
in(0) = a;
return in;
}
static octave_value_list convertToOctaveValueList(Matrix a)
{
octave_value_list in;
in(0) = a;
return in;
}
static octave_value_list convertToOctaveValueList(RowVector a)
{
octave_value_list in;
in(0) = a;
return in;
}
static octave_value_list convertToOctaveValueList(ColumnVector a)
{
octave_value_list in;
in(0) = a;
return in;
}
static octave_value_list convertToOctaveValueList(double a, double b)
{
octave_value_list in;
in(0) = a;
in(1) = b;
return in;
}
static octave_value_list convertToOctaveValueList(std::initializer_list<double> args)
{
octave_value_list in;
int counter = 0;
for(double element : args) {
in(counter) = octave_value(element);
++counter;
}
return in;
}
static octave_value_list convertToOctaveValueList(Matrix a, double b)
{
octave_value_list in;
in(0) = a;
in(1) = b;
return in;
}
static octave_value_list convertToOctaveValueList(RowVector a, double b)
{
octave_value_list in;
in(0) = a;
in(1) = b;
return in;
}
static octave_value_list convertToOctaveValueList(ColumnVector a, double b)
{
octave_value_list in;
in(0) = a;
in(1) = b;
return in;
}
static octave_value_list callOctaveFunction(octave_value_list in, std::string functionName,int argsOut)
{
/*octave_idx_type n = 2;
octave_value_list in;
for(octave_idx_type i = 0; i < n; i++)
in(i) = octave_value(5 * (i + 2));
octave_value_list out = feval("gcd", in, 1);
if(!error_state && out.length() > 0)
std::cout << "GCD of [" << in(0).int_value() << ", " << in(1).int_value() << "] is " << out(0).int_value()
<< std::endl;
else
std::cout << "invalid\n";
clean_up_and_exit(0);*/
/* if(functionName=="eigs")
return feval(functionName, in, 2);
else if(functionName=="kmeans")
return feval(functionName, in, 2);
*/
return feval(functionName, in, argsOut);
}
static int callOctaveFunctionIntFirstResult(octave_value_list in, std::string functionName, int argsOut)
{
// printf("callOctaveFunctionIntFirstResult pre return functionName: %s\n",functionName.c_str());
return callOctaveFunction(in, functionName,argsOut)(0).int_value();
}
static double callOctaveFunctionDoubleFirstResult(octave_value_list in, std::string functionName, int argsOut)
{
// printf("callOctaveFunctionDoubleFirstResult pre return functionName: %s\n",functionName.c_str());
return callOctaveFunction(in, functionName,argsOut)(0).double_value();
}
static Matrix callOctaveFunctionMatrixFirstResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(0).matrix_value();
}
static ColumnVector callOctaveFunctionColumnVectorFirstResult(octave_value_list in, std::string functionName, int argsOut)
{
printf("pre Call %s\n", functionName.c_str());
try {
in=octave_value_list();
octave_value_list list = callOctaveFunction(in, functionName,argsOut);
printf("post Call %s\n", functionName.c_str());
return list(0).array_value().as_column();
} catch(const std::exception& e) {
printf("%s\n", e.what());
}
return ColumnVector();
}
static RowVector callOctaveFunctionRowVectorFirstResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(0).array_value().as_row();
}
static int callOctaveFunctionIntSecondResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(1).int_value();
}
static double callOctaveFunctionDoubleSecondResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(1).double_value();
}
static Matrix callOctaveFunctionMatrixSecondResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(1).matrix_value();
}
static ColumnVector callOctaveFunctionColumnVectorSecondResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(1).array_value().as_column();
}
static RowVector callOctaveFunctionRowVectorSecondResult(octave_value_list in, std::string functionName, int argsOut)
{
return callOctaveFunction(in, functionName,argsOut)(1).array_value().as_row();
}
static Matrix getMatrixFromOctaveListFirstResult(octave_value_list list){
return list(0).matrix_value();
}
static RowVector getRowVectorFromOctaveListFirstResult(octave_value_list list){
return list(0).array_value().as_row();
}
static ColumnVector getColumnVectorFromOctaveListFirstResult(octave_value_list list){
return list(0).array_value().as_column();
}
static double getDoubleFromOctaveListFirstResult(octave_value_list list){
return list(0).double_value();
}
static int getIntFromOctaveListFirstResult(octave_value_list list){
return list(0).int_value();
}
};
#endif // HELPER_H
\ No newline at end of file
#ifndef DETECTION_OBJECTDETECTOR_SPECTRALCLUSTERER_1__SIMILARITY
#define DETECTION_OBJECTDETECTOR_SPECTRALCLUSTERER_1__SIMILARITY
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
#include "Helper.h"
#include "octave/builtin-defun-decls.h"
using namespace arma;
class detection_objectDetector_spectralClusterer_1__similarity{
const int n = 2500;
public:
mat data;
mat similarity;
mat degree;
void init()
{
data=mat(n,3);
similarity=mat(n,n);
degree=mat(n,n);
}
void execute()
{
for( auto i=1;i<=(Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(data, 1),1)));++i){
for( auto j=1;j<=(Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(data, 1),1)));++j){
double dist = (Helper::getDoubleFromOctaveListFirstResult(Fsqrt(Helper::convertToOctaveValueList(Helper::getDoubleFromOctaveListFirstResult(Fmpower(Helper::convertToOctaveValueList(Helper::getDoubleFromOctaveListFirstResult(Fmpower(Helper::convertToOctaveValueList(Helper::getDoubleFromOctaveListFirstResult(Fmpower(Helper::convertToOctaveValueList((data(i-1, 1-1)-data(j-1, 1-1)),2+(data(i-1, 2-1)-data(j-1, 2-1))),1)),2+(data(i-1, 3-1)-data(j-1, 3-1))),1)),2),1))),1)));
similarity(i-1, j-1) = (Helper::getDoubleFromOctaveListFirstResult(Fexp(Helper::convertToOctaveValueList((0-dist)/(2)),1)));
}
}
for( auto k=1;k<=(Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(similarity, 1),1)));++k){
degree(k-1, k-1) = (Helper::getDoubleFromOctaveListFirstResult(Fsum(Helper::convertToOctaveValueList(similarity.row(k-1)),1)));
}
}
};
#endif
......@@ -77,9 +77,22 @@ static double getEuclideanDistance(mat A, int colIndexA, mat B, int colIndexB){
}
static mat getSqrtMat(mat A){
cx_mat result=sqrtmat(A);
return real(result);
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = sqrt(curVal);
}
return A;
}
static mat getSqrtMatDiag(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
A(i,i) = sqrt(curVal);
}
return A;
}
static mat invertDiagMatrix(mat A){
for(int i=0;i<A.n_rows;++i){
double curVal = A(i,i);
......
......@@ -32,12 +32,12 @@ spectralClusterer[3].init();
void execute()
{
spectralClusterer[0].imgMatrix = imgFront;
spectralClusterer[1].imgMatrix = imgRight;
spectralClusterer[2].imgMatrix = imgLeft;
spectralClusterer[3].imgMatrix = imgBack;
spectralClusterer[0].execute();
spectralClusterer[1].imgMatrix = imgRight;
spectralClusterer[1].execute();
spectralClusterer[2].imgMatrix = imgLeft;
spectralClusterer[2].execute();
spectralClusterer[3].imgMatrix = imgBack;
spectralClusterer[3].execute();
clusters[0] = spectralClusterer[0].clusters;
clusters[1] = spectralClusterer[1].clusters;
......
......@@ -11,6 +11,8 @@
using namespace arma;
class detection_objectDetector_spectralClusterer_1_{
const int n = 2500;
const int k = 4;
const int maximumClusters = 4;
public:
mat imgMatrix;
mat clusters;
......
......@@ -4,11 +4,11 @@
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
#include "Helper.h"
#include "octave/builtin-defun-decls.h"
#include "HelperA.h"
using namespace arma;
class detection_objectDetector_spectralClusterer_1__eigenSolver{
const int n = 2500;
const int targetEigenvectors = 4;
public:
mat matrix;
mat eigenvectors;
......@@ -19,10 +19,10 @@ eigenvectors=mat(n,targetEigenvectors);
}
void execute()
{
mat eigenVectors = (Helper::getColumnVectorFromOctaveListFirstResult(Feig(Helper::convertToOctaveValueList(matrix),2)));
mat eigenVectors = (HelperA::getEigenVectors((matrix)));
double counter = 1;
double start = (Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(eigenVectors, 2),1)))-(targetEigenvectors-1);
for( auto i=start;i<=(Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(eigenVectors, 1),1)));++i){
double start = (eigenVectors.n_cols)-(targetEigenvectors-1);
for( auto i=start;i<=(eigenVectors.n_rows);++i){
eigenvectors.col(counter-1) = eigenVectors.col(i-1);
counter = counter+1;
}
......
......@@ -4,11 +4,12 @@
#define M_PI 3.14159265358979323846
#endif
#include "armadillo.h"
#include "Helper.h"
#include "octave/builtin-defun-decls.h"
#include "HelperA.h"
using namespace arma;
class detection_objectDetector_spectralClusterer_1__kMeansClustering{
const int n = 2500;
const int amountVectors = 4;
const int maximumClusters = 4;
public:
mat vectors;
mat clusters;
......@@ -19,13 +20,13 @@ clusters=mat(n,1);
}
void execute()
{
mat UMatrix;
for( auto i=1;i<=(Helper::getDoubleFromOctaveListFirstResult(Fsize(Helper::convertToOctaveValueList(vectors, 1),1)));++i){