Commit cf77f613 authored by Dinh-An Ho

Merge branch 'ML_clustering' of...

Merge branch 'ML_clustering' of https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMAM2Middleware into ML_clustering
parents 459c0be5 8b30b74c
package de.monticore.lang.monticar.generator.middleware.clustering;
import de.monticore.expressionsbasis._ast.ASTExpression;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ConnectorSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.PortSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.*;
import de.monticore.lang.embeddedmontiarc.tagging.middleware.ros.RosConnectionSymbol;
import de.monticore.lang.math._ast.ASTNumberExpression;
import de.monticore.lang.monticar.common2._ast.ASTCommonMatrixType;
import de.monticore.lang.monticar.ts.MCTypeSymbol;
import de.monticore.lang.monticar.ts.references.MCASTTypeSymbolReference;
import de.monticore.lang.monticar.ts.references.MCTypeReference;
import de.monticore.symboltable.CommonScope;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Symbol;
import de.monticore.symboltable.resolving.ResolvingFilter;
import de.se_rwth.commons.logging.Log;
import java.util.*;
import java.util.stream.Collector;
import java.util.stream.Collectors;
public class AutomaticClusteringHelper {
public static double[][] createAdjacencyMatrix(List<ExpandedComponentInstanceSymbol> subcomps, Collection<ConnectorSymbol> connectors, Map<String, Integer> subcompLabels) {
......@@ -127,4 +131,159 @@ public class AutomaticClusteringHelper {
return 50;
}
/**
 * Recursively rebuilds a component hierarchy so that the subcomponents of a nested
 * component are lifted into that component's enclosing component, with the connectors
 * re-wired accordingly.
 *
 * <p>Behaviour, as implemented below:
 * <ul>
 *   <li>A component without subcomponents is returned unchanged when it has no enclosing
 *       component; otherwise the result of {@code copySymbolWithSystemName(symbol)} is
 *       returned.</li>
 *   <li>For a component with subcomponents, every subcomponent is processed recursively.
 *       The result of each recursive call is assigned to {@code symbol}, so after the loop
 *       {@code symbol} refers to the result of the LAST recursive call.</li>
 *   <li>If that {@code symbol} has an enclosing component, a new enclosing component is
 *       built that contains the siblings of {@code symbol} plus the subcomponents of
 *       {@code symbol}, together with a merged connector set.</li>
 * </ul>
 *
 * @param symbol the component instance to process
 * @return either {@code symbol} itself or a newly constructed component instance
 */
public static ExpandedComponentInstanceSymbol flattenArchitecture(ExpandedComponentInstanceSymbol symbol){
// Base case: nothing to lift out of a component that has no subcomponents.
if (symbol.getSubComponents().isEmpty()){
if (symbol.getEnclosingComponent().isPresent()){
return copySymbolWithSystemName(symbol);
}
return symbol;
}
// Recurse into every subcomponent. The iteration runs over the collection obtained from the
// original symbol; the parameter is overwritten with each recursive result.
// NOTE(review): only the result of the last iteration survives the loop - confirm this is intended.
for (ExpandedComponentInstanceSymbol sym : symbol.getSubComponents()){
symbol = flattenArchitecture(sym);
}
if (symbol.getEnclosingComponent().isPresent()){
// Rebuild the enclosing component with symbol renamed to its full name ('.' replaced by '_'),
// then look that renamed copy up again; get() throws if no such subcomponent exists.
ExpandedComponentInstanceSymbol enclosingComponent = copySymbolWithSystemName(symbol);
symbol = enclosingComponent.getSubComponent(symbol.getFullName().replace(".", "_")).get();
// Effectively-final alias so the lambdas below can capture the current symbol.
ExpandedComponentInstanceSymbol thisSymbol = symbol;
// New subcomponent list: all siblings of symbol followed by symbol's own subcomponents.
List<ExpandedComponentInstanceSymbol> newSubcomponents = enclosingComponent.getSubComponents().stream()
.filter(e -> !e.getFullName().equals(thisSymbol.getFullName())).collect(Collectors.toList());
newSubcomponents.addAll(newSubcomponents.size(), new ArrayList<>(symbol.getSubComponents()));
// Full names of symbol's incoming and outgoing ports, used for the membership tests below.
HashSet<String> incomingPorts = new HashSet<>(symbol.getIncomingPorts().stream().map(p ->{
return p.getFullName();
}).collect(Collectors.toList()));
HashSet<String> outgoingPorts = new HashSet<>(symbol.getOutgoingPorts().stream().map(p ->{
return p.getFullName();
}).collect(Collectors.toList()));
// Connectors of symbol whose source, prefixed with "<fullName>.<name>_", is the full name
// of one of symbol's incoming ports.
Set<ConnectorSymbol> incomingConnectors = symbol.getConnectors().stream()
.filter(c -> incomingPorts.contains(thisSymbol.getFullName() + "." + thisSymbol.getName() + "_"+ c.getSource()))
.collect(Collectors.toSet());
// Connectors of the enclosing component whose target port belongs to symbol.
Set<ConnectorSymbol> incomingParentConnectors = enclosingComponent.getConnectors().stream()
.filter(c -> c.getTargetPort().getComponentInstance().get().getFullName().equals(thisSymbol.getFullName()))
.collect(Collectors.toSet());
// Connectors of symbol whose target, prefixed with "<fullName>.<name>_", is the full name
// of one of symbol's outgoing ports.
Set<ConnectorSymbol> outgoingConnectors = symbol.getConnectors().stream()
.filter(c -> outgoingPorts.contains(thisSymbol.getFullName() + "." + thisSymbol.getName() + "_"+ c.getTarget()))
.collect(Collectors.toSet());
// Connectors of the enclosing component whose source port belongs to symbol.
Set<ConnectorSymbol> outgoingParentConnectors = enclosingComponent.getConnectors().stream()
.filter(c -> c.getSourcePort().getComponentInstance().get().getFullName().equals(thisSymbol.getFullName()))
.collect(Collectors.toSet());
// Connectors of the enclosing component that do not touch symbol are kept as they are.
Set<ConnectorSymbol> newConnectors = enclosingComponent.getConnectors().stream()
.filter(c -> !(incomingParentConnectors.contains(c) || outgoingParentConnectors.contains(c)))
.collect(Collectors.toSet());
// Connectors of symbol that do not touch its own ports are taken over as well
// (the renaming step is currently commented out).
newConnectors.addAll(symbol.getConnectors().stream()
//.map(c -> {return mapToNewName(c);})
.filter(c -> !(incomingConnectors.contains(c) || outgoingConnectors.contains(c)))
.collect(Collectors.toSet()));
// Join step (incoming): for every pair of inner connector and parent connector that meet at the
// same port of symbol, create a direct connector from the parent's source to the inner target.
// The parent's target is reduced with the greedy regex ".*_", i.e. everything up to and
// including the LAST underscore is removed.
// NOTE(review): port names that themselves contain '_' would be truncated here - confirm.
for (ConnectorSymbol con : incomingConnectors){
for (ConnectorSymbol connectorSymbol : incomingParentConnectors){
if (con.getSource().equals(connectorSymbol.getTarget().replaceFirst(".*_", ""))){
ConnectorSymbol tmpConnector = ConnectorSymbol.builder()
.setSource(connectorSymbol.getSource())
.setTarget(con.getTarget())
.build();
newConnectors.add(tmpConnector);
}
}
}
// Join step (outgoing): mirror image of the loop above, connecting the inner source directly
// to the parent's target.
for (ConnectorSymbol con : outgoingConnectors){
for (ConnectorSymbol connectorSymbol : outgoingParentConnectors){
if (con.getTarget().equals(connectorSymbol.getSource().replaceFirst(".*_", ""))){
ConnectorSymbol tmpConnector = ConnectorSymbol.builder()
.setSource(con.getSource())
.setTarget(connectorSymbol.getTarget())
.build();
newConnectors.add(tmpConnector);
}
}
}
// Assemble the replacement for the enclosing component: same name and ports, new children
// and the merged connector set.
ExpandedComponentInstanceSymbol res = constructECIS(enclosingComponent, newSubcomponents, newConnectors,
enclosingComponent.getName(), new ArrayList<>(enclosingComponent.getPortsList()));
return res;
} else {
// No enclosing component: nothing to lift into, return the symbol as it is.
return symbol;
}
}
/**
 * Builds a copy of the enclosing component of {@code symbol} in which {@code symbol} is
 * replaced by a copy named after its full name (dots replaced by underscores) and whose
 * ports carry that new name as a prefix.
 *
 * <p>The connectors of the enclosing component are adjusted through their setters inside
 * the stream below, i.e. the existing connector objects are modified rather than copied.
 * NOTE(review): this presumably also changes the connectors seen through the original
 * enclosing component - confirm against {@code getConnectors()}.
 *
 * @param symbol component instance to rename; its enclosing component is fetched with
 *               {@code Optional.get()}, so callers must ensure it is present
 * @return a newly constructed component instance replacing the enclosing component
 */
private static ExpandedComponentInstanceSymbol copySymbolWithSystemName(ExpandedComponentInstanceSymbol symbol) {
ExpandedComponentInstanceSymbol enclosingComponent = symbol.getEnclosingComponent().get();
// Effectively-final alias for use inside the lambda below.
ExpandedComponentInstanceSymbol thisSymbol = symbol;
// All siblings of symbol; the renamed copy of symbol is appended further down.
List<ExpandedComponentInstanceSymbol> subcomponents = enclosingComponent.getSubComponents().stream()
.filter(e -> !e.getFullName().equals(thisSymbol.getFullName())).collect(Collectors.toList());
List<PortSymbol> ports = new ArrayList<>();
// New name: the full name with every '.' replaced by '_'.
String newName = symbol.getFullName().replace(".", "_");
// Computed but not used anywhere else in this method.
String newEnclosingName = enclosingComponent.getFullName().replace(".", "_");
// Re-create every port as "<newName>_<portName>". Constant ports keep their constant value;
// all other ports keep their config flag and middleware symbol.
for (PortSymbol port : symbol.getPortsList()) {
ports.add((PortSymbol)(port.isConstant() ?
(new EMAPortBuilder()).setName(newName + "_" + port.getName()).setDirection(port.isIncoming())
.setTypeReference(port.getTypeReference()).setConstantValue(((ConstantPortSymbol)port).getConstantValue())
.setASTNode(port.getAstNode()).buildConstantPort()
: (new EMAPortBuilder()).setName(newName + "_" + port.getName()).setDirection(port.isIncoming())
.setTypeReference(port.getTypeReference()).setASTNode(port.getAstNode()).setConfig(port.isConfig())
.setMiddlewareSymbol(port.getMiddlewareSymbol()).build()));
}
// Renamed copy of symbol: same subcomponents and connectors, new name and the new ports.
ExpandedComponentInstanceSymbol e = constructECIS(symbol, new ArrayList<>(symbol.getSubComponents()),
new HashSet<>(symbol.getConnectors()), newName, ports);
subcomponents.add(e);
// Full names of the ORIGINAL ports of symbol, used to detect connectors that touch it.
HashSet<String> incomingPorts = new HashSet<>(symbol.getIncomingPorts().stream().map(p ->{
return p.getFullName();
}).collect(Collectors.toList()));
HashSet<String> outgoingPorts = new HashSet<>(symbol.getOutgoingPorts().stream().map(p ->{
return p.getFullName();
}).collect(Collectors.toList()));
// Redirect connectors of the enclosing component that end at (or start from) a port of symbol.
// The lookup key is "<connector's component full name>.<endpoint with its first character lower-cased>".
// The endpoint is rewritten with the regex "[^.]*." : a run of non-dot characters followed by ANY
// single character (the dot is not escaped), replaced by "<e.getName()>.<newName>_".
// NOTE(review): the unescaped '.' presumably is meant to match a literal dot - confirm.
// c.setSource(c.getSource()) / c.setTarget(c.getTarget()) re-assign the current value unchanged.
Set<ConnectorSymbol> newConnectors = enclosingComponent.getConnectors().stream()
.map(c ->{
if (incomingPorts.contains(c.getComponentInstance().get().getFullName() + "." + c.getTarget().substring(0,1).toLowerCase()
+ c.getTarget().substring(1))){
c.setSource(c.getSource());
c.setTarget(c.getTarget().replaceFirst("[^.]*.",e.getName() + "." + newName + "_"));
} else if (outgoingPorts.contains(c.getComponentInstance().get().getFullName() + "." + c.getSource().substring(0,1).toLowerCase()
+ c.getSource().substring(1))){
c.setSource(c.getSource().replaceFirst("[^.]*.",e.getName() + "." + newName + "_"));
c.setTarget(c.getTarget());
}
return c;
})
.collect(Collectors.toSet());
// New enclosing component: same name and ports, updated children and connectors.
return constructECIS(enclosingComponent, subcomponents, newConnectors, enclosingComponent.getName(),
new ArrayList<>(enclosingComponent.getPortsList()));
}
/**
 * Creates a new component instance symbol modelled after {@code enclosingComponent}.
 *
 * <p>The new symbol takes its component type, resolution declaration symbols, enclosing
 * scope and resolving filters from {@code enclosingComponent}; name, ports, connectors and
 * subcomponents are supplied by the caller. The resolving filters are also installed on the
 * spanned scope of every given subcomponent.
 *
 * @param enclosingComponent template component providing type, scope and resolving filters
 * @param newSubcomponents   subcomponents of the new symbol
 * @param newConnectors      connectors of the new symbol
 * @param name               name of the new symbol
 * @param ports              ports of the new symbol
 * @return the newly built component instance symbol
 */
private static ExpandedComponentInstanceSymbol constructECIS(ExpandedComponentInstanceSymbol enclosingComponent,
        List<ExpandedComponentInstanceSymbol> newSubcomponents,
        Set<ConnectorSymbol> newConnectors, String name, List<PortSymbol> ports) {
    Set<ResolvingFilter<? extends Symbol>> filters = enclosingComponent.getSpannedScope().getResolvingFilters();

    // Give every child the same resolving filters as the template component.
    for (ExpandedComponentInstanceSymbol child : newSubcomponents) {
        CommonScope childScope = (CommonScope) child.getSpannedScope();
        childScope.setResolvingFilters(filters);
    }

    ExpandedComponentInstanceSymbol built = new ExpandedComponentInstanceBuilder()
            .setName(name)
            .setSymbolReference(enclosingComponent.getComponentType())
            .addPorts(ports)
            .addConnectors(newConnectors)
            .addSubComponents(newSubcomponents)
            .addResolutionDeclarationSymbols(enclosingComponent.getResolutionDeclarationSymbols())
            .build();

    // The new symbol resolves like the template and lives in the template's enclosing scope.
    CommonScope builtScope = (CommonScope) built.getSpannedScope();
    builtScope.setResolvingFilters(filters);
    MutableScope parentScope = (MutableScope) enclosingComponent.getEnclosingScope();
    built.setEnclosingScope(parentScope);
    return built;
}
}
package de.monticore.lang.monticar.generator.middleware.clustering;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.DBSCANClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.MarkovClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.SpectralClusteringAlgorithm;
import de.se_rwth.commons.logging.Log;
......@@ -14,6 +15,7 @@ public class ClusteringAlgorithmFactory {
switch (kind){
case SPECTRAL_CLUSTERER: return new SpectralClusteringAlgorithm();
case MARKOV_CLUSTERER: return new MarkovClusteringAlgorithm();
case DBSCAN_CLUSTERER: return new DBSCANClusteringAlgorithm();
default: Log.error("0x1D54C: No clustering algorithm found for ClusteringKind " + kind);
}
return null;
......
......@@ -2,5 +2,6 @@ package de.monticore.lang.monticar.generator.middleware.clustering;
/**
 * Kinds of clustering algorithms that can be requested from {@code ClusteringAlgorithmFactory}.
 */
public enum ClusteringKind {
    SPECTRAL_CLUSTERER,
    MARKOV_CLUSTERER,
    DBSCAN_CLUSTERER
}
package de.monticore.lang.monticar.generator.middleware.clustering.algorithms;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.monticar.generator.middleware.clustering.AutomaticClusteringHelper;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.helpers.ComponentHelper;
import de.se_rwth.commons.logging.Log;
import smile.clustering.DBSCAN;
import java.util.*;
// DBSCAN clusterer product implementation
public class DBSCANClusteringAlgorithm implements ClusteringAlgorithm {
@Override
public List<Set<ExpandedComponentInstanceSymbol>> cluster(ExpandedComponentInstanceSymbol component, Object... args) {
List<Set<ExpandedComponentInstanceSymbol>> res = new ArrayList<>();
// params
Integer minPts= null;
Double radius= null;
// find mandatory params
Map<DBSCANClusteringBuilder.DBSCANParameters, Boolean> mandatoryParams = new HashMap<DBSCANClusteringBuilder.DBSCANParameters, Boolean>();
DBSCANClusteringBuilder.DBSCANParameters[] dbscanParams = DBSCANClusteringBuilder.DBSCANParameters.values();
for (DBSCANClusteringBuilder.DBSCANParameters param : dbscanParams) {
// set all mandatory params to "unset"
if (param.isMandatory()) mandatoryParams.put(param, false);
}
// Handle (optional) params for DBSCANClustering.
// Params come as one or multiple key-value-pairs in the optional varargs array for this method,
// with key as a string (containing the name of the parameter to pass thru to the clusterer) followed by its value as an object
DBSCANClusteringBuilder.DBSCANParameters key;
Object value;
int v = 0;
while (v < args.length) {
if (args[v] instanceof DBSCANClusteringBuilder.DBSCANParameters) {
key = (DBSCANClusteringBuilder.DBSCANParameters)args[v];
if (v+1 < args.length) {
value = args[v + 1];
switch (key) {
case DBSCAN_MIN_PTS:
if (value instanceof Integer) {
minPts= (Integer) value;
}
break;
case DBSCAN_RADIUS:
if (value instanceof Double) {
radius= (Double) value;
}
break;
}
// set mandatory param to "set"
if (key.isMandatory()) mandatoryParams.replace(key, true);
}
}
v = v + 2;
}
// are all mandatory params set?
boolean error= false;
Iterator iterator = mandatoryParams.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry param = (Map.Entry) iterator.next();
if (!(Boolean)param.getValue()) error= true;
}
if (error) {
Log.error("DBSCANClusteringAlgorithm: Mandatory parameter(s) missing!");
} else {
List<ExpandedComponentInstanceSymbol> subcompsOrderedByName = ComponentHelper.getSubcompsOrderedByName(component);
Map<String, Integer> labelsForSubcomps = ComponentHelper.getLabelsForSubcomps(subcompsOrderedByName);
double[][] adjMatrix = AutomaticClusteringHelper.createAdjacencyMatrix(subcompsOrderedByName,
ComponentHelper.getInnerConnectors(component),
labelsForSubcomps);
DBSCAN clustering;
DBSCANClusteringBuilder builder = new DBSCANClusteringBuilder(adjMatrix, minPts, radius);
clustering = builder.build();
int[] labels = clustering.getClusterLabel();
for (int i = 0; i < clustering.getNumClusters(); i++) {
res.add(new HashSet<>());
}
subcompsOrderedByName.forEach(sc -> {
int curClusterLabel = labels[labelsForSubcomps.get(sc.getFullName())];
res.get(curClusterLabel).add(sc);
});
}
return res;
}
}
package de.monticore.lang.monticar.generator.middleware.clustering.algorithms;
import smile.clustering.DBSCAN;
/**
 * Builder for a Smile {@link DBSCAN} clustering that operates on a (weighted) adjacency matrix.
 */
public class DBSCANClusteringBuilder {

    /** Parameters understood by the DBSCAN clusterer; the flag tells whether one is mandatory. */
    public enum DBSCANParameters {
        DBSCAN_MIN_PTS(true),
        DBSCAN_RADIUS(true);

        private Boolean mandatory;

        DBSCANParameters(Boolean mandatory) {
            this.mandatory = mandatory;
        }

        /** @return whether this parameter has to be supplied */
        public Boolean isMandatory() {
            return mandatory;
        }
    }

    // expected: (weighted) adjacency matrix
    private double[][] data;
    private Integer minPts;
    private Double radius;

    /**
     * @param data   (weighted) adjacency matrix of the nodes to cluster
     * @param minPts DBSCAN minimum number of neighbours
     * @param radius DBSCAN neighbourhood radius
     */
    public DBSCANClusteringBuilder(double[][] data, int minPts, double radius) {
        this.data = data;
        this.minPts = minPts;
        this.radius = radius;
    }

    /** Replaces the adjacency matrix and returns this builder. */
    public DBSCANClusteringBuilder setData(double[][] data) {
        this.data = data;
        return this;
    }

    /** Replaces the minimum number of neighbours and returns this builder. */
    public DBSCANClusteringBuilder setMinPts(int minPts) {
        this.minPts = minPts;
        return this;
    }

    /** Replaces the neighbourhood radius and returns this builder. */
    public DBSCANClusteringBuilder setRadius(double radius) {
        this.radius = radius;
        return this;
    }

    /**
     * Runs DBSCAN over one pseudo point per matrix row. Both coordinates of a point are set to
     * the node number, which the distance function uses as an index into the adjacency matrix.
     *
     * @return the resulting clustering
     */
    public DBSCAN build() {
        final int nodeCount = data.length;
        double[][] pseudoPoints = new double[nodeCount][];
        int node = 0;
        while (node < nodeCount) {
            pseudoPoints[node] = new double[] {node, node};
            node++;
        }
        DBSCANDistance distance = new DBSCANDistance(data);
        return new DBSCAN(pseudoPoints, distance, minPts, radius);
    }
}
package de.monticore.lang.monticar.generator.middleware.clustering.algorithms;
import smile.math.distance.Metric;
import java.io.Serializable;
public class DBSCANDistance implements Metric<double[]>, Serializable {
private static final long serialVersionUID = 1L;
private double[][] weightedAdjacencyMatrix = null;
public DBSCANDistance(double[][] weightedAdjacencyMatrix) {
this.weightedAdjacencyMatrix = weightedAdjacencyMatrix;
}
public String toString() {
return String.format("DBSCAN distance");
}
public double d(double[] x, double[] y) {
System.out.println((int)x[0] + ", " + (int)y[0] + ": " + weightedAdjacencyMatrix[(int)x[0]][(int)y[0]]);
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("Arrays have different length: x[%d], y[%d]", x.length, y.length));
} else {
double ret= weightedAdjacencyMatrix[(int)x[0]][(int)y[0]];
return ret > 0 ? ret : Double.MAX_VALUE; // set zeros in adj. matrix to infinity
}
}
}
package de.monticore.lang.monticar.generator.middleware;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ConnectorSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.PortSymbol;
import de.monticore.lang.monticar.generator.middleware.clustering.AutomaticClusteringHelper;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringAlgorithmFactory;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringKind;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.MarkovClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.SpectralClusteringAlgorithm;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.SpectralClusteringBuilder;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.*;
import de.monticore.lang.monticar.generator.middleware.helpers.ComponentHelper;
import de.monticore.lang.monticar.generator.middleware.impls.CPPGenImpl;
import de.monticore.lang.monticar.generator.middleware.impls.RosCppGenImpl;
......@@ -21,8 +20,10 @@ import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.DenseInstance;
import net.sf.javaml.core.Instance;
import org.junit.Test;
import smile.clustering.DBSCAN;
import smile.clustering.KMeans;
import smile.clustering.SpectralClustering;
import smile.math.distance.MinkowskiDistance;
import java.io.IOException;
import java.util.*;
......@@ -98,6 +99,36 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
}
/**
 * Flattens {@code lab.overallSystem} and checks the number of subcomponents and connectors
 * of the result.
 */
@Test
public void testFlattenAlgorithm1(){
    TaggingResolver resolver = AbstractSymtabTest.createSymTabAndTaggingResolver(TEST_PATH);

    ExpandedComponentInstanceSymbol original = resolver
            .<ExpandedComponentInstanceSymbol>resolve("lab.overallSystem", ExpandedComponentInstanceSymbol.KIND)
            .orElse(null);
    assertNotNull(original);

    ExpandedComponentInstanceSymbol flattened = AutomaticClusteringHelper.flattenArchitecture(original);
    assertNotNull(flattened);

    int subComponentCount = flattened.getSubComponents().size();
    int connectorCount = flattened.getConnectors().size();
    assertEquals(10, subComponentCount);
    assertEquals(20, connectorCount);
}
/**
 * Flattens {@code lab.spanningSystem} and checks the number of subcomponents and connectors
 * of the result.
 */
@Test
public void testFlattenAlgorithm2(){
    TaggingResolver resolver = AbstractSymtabTest.createSymTabAndTaggingResolver(TEST_PATH);

    ExpandedComponentInstanceSymbol original = resolver
            .<ExpandedComponentInstanceSymbol>resolve("lab.spanningSystem", ExpandedComponentInstanceSymbol.KIND)
            .orElse(null);
    assertNotNull(original);

    ExpandedComponentInstanceSymbol flattened = AutomaticClusteringHelper.flattenArchitecture(original);
    assertNotNull(flattened);

    int subComponentCount = flattened.getSubComponents().size();
    int connectorCount = flattened.getConnectors().size();
    assertEquals(20, subComponentCount);
    assertEquals(40, connectorCount);
}
// todo: gotta move this thing later, just temporarily here for testing purposes
public static Dataset[] getClustering(Dataset data, SparseMatrix matrix) {
......@@ -148,6 +179,82 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
return output;
}
@Test
public void testDBSCANClustering(){
/*
0----1----4---6
| \/ | \ /
| /\ | 5
2----3
expected: 2 clusters a, b with a={0,1,2,3} and b={4,5,6}
*/
// for DBSCAN this could be directly weighted
double[][] adjacencyMatrix =
{
{0, 1, 1, 1, 0, 0, 0},
{1, 0, 1, 1, 1, 0, 0},
{1, 1, 0, 1, 0, 0, 0},
{1, 1, 1, 0, 0, 0, 0},
{0, 1, 0, 0, 0, 1, 1},
{0, 0, 0, 0, 1, 0, 1},
{0, 0, 0, 0, 1, 1, 0},
};
// |nodes| instances of data with pseudo x,y coords. set to node no.
double[][] data = new double[adjacencyMatrix.length][2];
for (int i=0; i<data.length; i++) {
data[i][0]= i;
data[i][1]= i;
}
// mission critical
int minPts= 2;
double radius= 5;
DBSCAN clustering = new DBSCAN(data, new DBSCANDistance(adjacencyMatrix), minPts, radius);
int[] labels = clustering.getClusterLabel();
for (int label : labels) {
System.out.println(label);
}
assertEquals(7, labels.length);
assertTrue(labels[0] == labels[1]);
assertTrue(labels[0] == labels[2]);
assertTrue(labels[0] == labels[3]);
assertTrue(labels[1] == labels[0]);
assertTrue(labels[1] == labels[2]);
assertTrue(labels[1] == labels[3]);
assertTrue(labels[1] != labels[4]); // expected cut
assertTrue(labels[2] == labels[0]);
assertTrue(labels[2] == labels[1]);
assertTrue(labels[2] == labels[3]);
assertTrue(labels[3] == labels[0]);
assertTrue(labels[3] == labels[1]);
assertTrue(labels[3] == labels[2]);
assertTrue(labels[4] != labels[1]); // expected cut
assertTrue(labels[4] == labels[5]);
assertTrue(labels[4] == labels[6]);
assertTrue(labels[5] == labels[4]);
assertTrue(labels[5] == labels[6]);
assertTrue(labels[6] == labels[4]);