Commit 69d859ae authored by Philipp Görick

Merge remote-tracking branch 'origin/ML_clustering' into ML_clustering

parents 198ba439 f2d1534c
Pipeline #110952 failed with stages
in 34 minutes and 26 seconds
......@@ -2,6 +2,8 @@ package de.monticore.lang.monticar.generator.middleware.cli.algorithms;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringAlgorithm;
import java.util.Optional;
public abstract class AlgorithmCliParameters {
public static final String TYPE_SPECTRAL_CLUSTERING = "SpectralClustering";
public static final String TYPE_UNKOWN = "Unkown";
......@@ -22,4 +24,8 @@ public abstract class AlgorithmCliParameters {
public abstract Object[] asAlgorithmArgs();
public abstract boolean isValid();
/**
 * Returns the number of clusters this parameter set expects the algorithm to produce.
 * The default implementation always returns an empty Optional, i.e. no expectation;
 * subclasses can override this to report a concrete count.
 *
 * @return the expected cluster count, or an empty Optional if there is none
 */
public Optional<Integer> expectedClusterCount(){
return Optional.empty();
}
}
......@@ -68,6 +68,11 @@ public class SpectralClusteringCliParameters extends AlgorithmCliParameters {
return numberOfClusters != null;
}
/**
 * Returns the configured number of clusters as the expected cluster count.
 * Delegates to {@link #getNumberOfClusters()} so that a missing value yields an
 * empty Optional instead of a NoSuchElementException from an unchecked get().
 *
 * @return the configured cluster count, or an empty Optional if none was set
 */
@Override
public Optional<Integer> expectedClusterCount() {
    return getNumberOfClusters();
}
/**
 * Returns the numberOfClusters field wrapped in an Optional.
 *
 * @return the configured number of clusters, or an empty Optional when the field is null
 */
public Optional<Integer> getNumberOfClusters() {
return Optional.ofNullable(numberOfClusters);
}
......
......@@ -15,7 +15,9 @@ import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
public class ClusteringResult {
......@@ -23,27 +25,51 @@ public class ClusteringResult {
private EMAComponentInstanceSymbol component;
private AlgorithmCliParameters parameters;
private List<Set<EMAComponentInstanceSymbol>> clustering;
private long durration;
private long duration;
private int componentNumber;
private boolean valid;
/**
 * Stores the given values in the fields of the same name; no validation or copying is done.
 * Private: instances are created through the static fromParameters factory.
 * NOTE(review): this listing shows both the pre-change parameter line (durration, no valid flag)
 * and the post-change one (duration, valid), as well as both field assignments.
 */
private ClusteringResult(EMAComponentInstanceSymbol component, AlgorithmCliParameters parameters,
List<Set<EMAComponentInstanceSymbol>> clustering, long durration, int componentNumber) {
List<Set<EMAComponentInstanceSymbol>> clustering, long duration, int componentNumber, boolean valid) {
this.component = component;
this.parameters = parameters;
this.clustering = clustering;
this.durration = durration;
this.duration = duration;
this.componentNumber = componentNumber;
this.valid = valid;
}
/**
 * Runs the clustering algorithm described by the parameters on the component and wraps the
 * outcome in a ClusteringResult.
 *
 * The result is flagged invalid when the algorithm throws, or when the parameters report an
 * expected cluster count that differs from the number of non-empty clusters produced.
 * NOTE(review): this listing interleaves pre- and post-change lines of the diff (two signature
 * lines, two clusterWithState calls, two return statements).
 *
 * @param component  the component whose subcomponents are clustered
 * @param parameters the algorithm configuration; supplies the algorithm and the expected cluster count
 * @return the clustering together with its wall-clock duration in milliseconds and validity flag
 */
public static ClusteringResult fromParameters(EMAComponentInstanceSymbol component, AlgorithmCliParameters parameters){
public static ClusteringResult fromParameters(EMAComponentInstanceSymbol component, AlgorithmCliParameters parameters) {
List<Set<EMAComponentInstanceSymbol>> res;
long startTime = System.currentTimeMillis();
List<Set<EMAComponentInstanceSymbol>> res = parameters.asClusteringAlgorithm().clusterWithState(component);
try {
res = parameters.asClusteringAlgorithm().clusterWithState(component);
} catch (Exception e) {
// Any failure of the algorithm yields an invalid result: empty clustering, duration -1,
// and the component number taken from the component's subcomponent count.
Log.warn("Marking this result as invalid. Error clustering the component.", e);
return new ClusteringResult(component, parameters, new ArrayList<>(), -1, component.getSubComponents().size(), false);
}
long endTime = System.currentTimeMillis();
boolean curValid = true;
// Drop empty clusters before counting, and warn if any were removed.
int clustersBefore = res.size();
res.removeIf(Set::isEmpty);
if (clustersBefore != res.size()) {
Log.warn("Removed " + (clustersBefore - res.size()) + " empty clusters for algorithm " + parameters.toString());
}
// Compare against the expected cluster count only if the parameters define one.
Optional<Integer> expClusters = parameters.expectedClusterCount();
if(expClusters.isPresent() && !expClusters.get().equals(res.size())){
curValid = false;
Log.warn("Marking this result as invalid. The actual number of clusters(" + res.size() + ") does not equal the expected number(" + expClusters.get() +")");
}
// The component number is the total count of components over all remaining clusters.
int componentNumber = 0;
for (Set<EMAComponentInstanceSymbol> cluster : res) {
componentNumber += cluster.size();
}
return new ClusteringResult(component, parameters, res, endTime - startTime, componentNumber);
return new ClusteringResult(component, parameters, res, endTime - startTime, componentNumber, curValid);
}
public double getScore(){
......@@ -69,8 +95,8 @@ public class ClusteringResult {
return clustering.size();
}
// Returns the stored duration value; callers in this class emit it labelled as milliseconds.
// NOTE(review): both the pre-change (getDurration) and post-change (getDuration) lines of the diff appear here.
private long getDurration() {
return this.durration;
private long getDuration() {
return this.duration;
}
public int getComponentNumber() {
......@@ -87,7 +113,7 @@ public class ClusteringResult {
String prefix = "//Algorithm: " + this.getParameters().toString() + "\n" +
"//Number of clusters: " + this.getNumberOfClusters() + "\n" +
"//Score: " + this.getScore() + "\n" +
"//Durration in ms: " + this.getDurration() + "\n";
"//Durration in ms: " + this.getDuration() + "\n";
String content = MiddlewareTagGenImpl.getFileContent(component, this.clustering);
res.setFileContent(prefix + content);
......@@ -133,9 +159,16 @@ public class ClusteringResult {
result.addProperty("Algorithm", this.getParameters().toString());
result.addProperty("NumberOfClusters", this.getNumberOfClusters());
result.addProperty("Score", this.getScore());
result.addProperty("DurationInMs", this.getDurration());
result.addProperty("DurationInMs", this.getDuration());
result.addProperty("ComponentNumber", this.getComponentNumber());
return result;
}
/**
 * Returns the validity flag of this result.
 *
 * @return the current value of the valid field
 */
public boolean isValid() {
return valid;
}
/**
 * Overwrites the validity flag of this result.
 *
 * @param valid the new value of the valid field
 */
public void setValid(boolean valid) {
this.valid = valid;
}
}
\ No newline at end of file
package de.monticore.lang.monticar.generator.middleware.clustering.qualityMetric;
import java.util.HashSet;
import java.util.Set;
/**
 * Silhouette-style quality index of a clustering, computed from a pairwise distance matrix.
 *
 * <p>{@code distanceMatrix[i][j]} is the distance between objects {@code i} and {@code j};
 * {@code clusteringLabels[i]} is the cluster label of object {@code i}.
 *
 * <p>Note: the intra-cluster average ({@link #distA(int)}) includes the object itself
 * (its own matrix entry, normally 0), so values differ slightly from the textbook definition
 * that averages over the other members only.
 */
public class SilhouetteIndex {
    private final double[][] distanceMatrix;
    private final int[] clusteringLabels;

    /**
     * @param distanceMatrix   pairwise distances, indexed {@code [object][object]}
     * @param clusteringLabels cluster label per object, same indexing as the matrix rows
     */
    public SilhouetteIndex(double[][] distanceMatrix, int[] clusteringLabels) {
        this.distanceMatrix = distanceMatrix;
        this.clusteringLabels = clusteringLabels;
    }

    /**
     * Mean of {@link #S(int)} over all objects.
     *
     * @return the average silhouette value, or 0 for an empty labelling (instead of NaN from 0/0)
     */
    public double getSilhouetteScore() {
        if (clusteringLabels.length == 0) {
            return 0d;
        }
        double sum = 0d;
        for (int i = 0; i < clusteringLabels.length; i++) {
            sum += S(i);
        }
        return sum / clusteringLabels.length;
    }

    /**
     * Silhouette value of a single object: {@code (b - a) / max(a, b)}.
     *
     * @param o index of the object
     * @return the silhouette value; 0 when there is no cluster other than the object's own
     *         (the index is not meaningful for a single cluster), and 0 when both a and b
     *         are non-positive (avoids 0/0)
     */
    public double S(int o) {
        // Without a second cluster distB would fall back to Double.MAX_VALUE and the
        // quotient would come out as ~1.0, i.e. a bogus "perfect" score.
        if (!hasOtherCluster(clusteringLabels[o])) {
            return 0d;
        }
        double a = distA(o);
        double b = distB(o);
        if (a <= 0d && b <= 0d) {
            return 0d;
        }
        return (b - a) / Math.max(a, b);
    }

    /**
     * Average distance from the object to all members of its own cluster (itself included).
     *
     * @param o index of the object
     */
    public double distA(int o) {
        return averageDistToCluster(o, clusteringLabels[o]);
    }

    /**
     * Smallest average distance from the object to any cluster other than its own.
     *
     * @param o index of the object
     * @return the minimum average distance, or {@code Double.MAX_VALUE} if no other cluster exists
     */
    public double distB(int o) {
        Set<Integer> processedLabels = new HashSet<>();
        int clusterOfO = clusteringLabels[o];
        double minDist = Double.MAX_VALUE;
        for (int label : clusteringLabels) {
            // Set.add returns false for labels already handled, so each cluster is averaged once.
            if (label != clusterOfO && processedLabels.add(label)) {
                minDist = Math.min(minDist, averageDistToCluster(o, label));
            }
        }
        return minDist;
    }

    /** Returns true if at least one object carries a label different from the given one. */
    private boolean hasOtherCluster(int ownLabel) {
        for (int label : clusteringLabels) {
            if (label != ownLabel) {
                return true;
            }
        }
        return false;
    }

    /** Average of the distances from object {@code o} to every object labelled {@code clusterLabel}. */
    private double averageDistToCluster(int o, int clusterLabel) {
        int numInCluster = 0;
        double sum = 0d;
        for (int i = 0; i < clusteringLabels.length; i++) {
            if (clusteringLabels[i] == clusterLabel) {
                sum += distanceMatrix[o][i];
                numInCluster++;
            }
        }
        return sum / ((double) numInCluster);
    }
}
......@@ -3,12 +3,11 @@ package de.monticore.lang.monticar.generator.middleware;
import com.clust4j.algo.AffinityPropagation;
import com.clust4j.algo.AffinityPropagationParameters;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAConnectorInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAPortInstanceSymbol;
import de.monticore.lang.monticar.generator.middleware.clustering.*;
import de.monticore.lang.monticar.generator.middleware.clustering.algorithms.*;
import de.monticore.lang.monticar.generator.middleware.clustering.qualityMetric.SilhouetteIndex;
import de.monticore.lang.monticar.generator.middleware.clustering.visualization.ModelVisualizer;
import de.monticore.lang.monticar.generator.middleware.clustering.visualization.SimpleModelViewer;
import de.monticore.lang.monticar.generator.middleware.helpers.ComponentHelper;
import de.monticore.lang.monticar.generator.middleware.impls.CPPGenImpl;
import de.monticore.lang.monticar.generator.middleware.impls.RosCppGenImpl;
......@@ -22,16 +21,11 @@ import net.sf.javaml.core.DenseInstance;
import net.sf.javaml.core.Instance;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.graphstream.graph.Edge;
import org.graphstream.graph.Graph;
import org.graphstream.graph.Node;
import org.graphstream.graph.implementations.SingleGraph;
import org.graphstream.stream.file.FileSinkImages;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import smile.clustering.DBSCAN;
import smile.clustering.SpectralClustering;
import smile.plot.Palette;
import java.io.IOException;
import java.util.*;
......@@ -83,11 +77,11 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
//sorted by full name: alex, combine, dinhAn, michael, philipp
double[][] expRes = {{0,10,0,0,0}
,{10,0,10,10,10}
,{0,10,0,0,0}
,{0,10,0,0,0}
,{0,10,0,0,0}};
double[][] expRes = {{0,8,0,0,0}
,{8,0,8,8,8}
,{0,8,0,0,0}
,{0,8,0,0,0}
,{0,8,0,0,0}};
for(int i = 0; i< expRes.length; i++){
for(int j = 0; j < expRes[i].length;j++){
......@@ -110,9 +104,9 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
//sorted full name: sub1, sub2, sub3
double[][] expRes = {{0,10,20} //sub1
,{10,0,0} //sub2
,{20,0,0}}; //sub3
double[][] expRes = {{0,8,16} //sub1
,{8,0,0} //sub2
,{16,0,0}}; //sub3
for(int i = 0; i< expRes.length; i++){
for(int j = 0; j < expRes[i].length;j++){
......@@ -122,6 +116,70 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
}
@Test
public void testDistanceMatrixCreation(){
    // Weighted path a -(10)- b -(20)- c, plus a vertex d without any edges.
    double[][] adjacency = {
            {0, 10, 0, 0},
            {10, 0, 20, 0},
            {0, 20, 0, 0},
            {0, 0, 0, 0}};

    // The test expects Double.MAX_VALUE for every pair involving the edgeless vertex d.
    final double unreachable = Double.MAX_VALUE;
    double[][] expected = {
            {0, 10, 30, unreachable},
            {10, 0, 20, unreachable},
            {30, 20, 0, unreachable},
            {unreachable, unreachable, unreachable, 0}};

    double[][] actual = AutomaticClusteringHelper.getDistanceMatrix(adjacency);

    // Entry-wise exact comparison of the computed matrix against the expectation.
    for (int row = 0; row < expected.length; row++) {
        double[] expectedRow = expected[row];
        for (int col = 0; col < expectedRow.length; col++) {
            assertTrue(expectedRow[col] == actual[row][col]);
        }
    }
}
@Test
public void testSilhouetteIndex(){
    // Two tight pairs: {a, b} and {c, d} are 10 apart internally,
    // while every edge between the pairs has weight 1000.
    double[][] adjacency = {
            {0, 10, 1000, 1000},
            {10, 0, 1000, 1000},
            {1000, 1000, 0, 10},
            {1000, 1000, 10, 0}
    };
    double[][] distances = AutomaticClusteringHelper.getDistanceMatrix(adjacency);

    // Labelling that matches the two pairs: every object must score well.
    int[] correctClustering = {0, 0, 1, 1};
    SilhouetteIndex matchingIndex = new SilhouetteIndex(distances, correctClustering);
    assertTrue(matchingIndex.S(0) > 0.5);
    assertTrue(matchingIndex.S(1) > 0.5);
    assertTrue(matchingIndex.S(2) > 0.5);
    assertTrue(matchingIndex.S(3) > 0.5);

    // Labelling that puts c into the cluster of a and b: only c must score badly.
    int[] badClustering = {0, 0, 0, 1};
    SilhouetteIndex mismatchedIndex = new SilhouetteIndex(distances, badClustering);
    assertTrue(mismatchedIndex.S(0) > 0.5);
    assertTrue(mismatchedIndex.S(1) > 0.5);
    assertTrue(mismatchedIndex.S(2) < -0.5);
    assertTrue(mismatchedIndex.S(3) > 0.5);

    // The overall score must rank the matching labelling above the mismatched one.
    double matchingScore = matchingIndex.getSilhouetteScore();
    double mismatchedScore = mismatchedIndex.getSilhouetteScore();
    assertTrue(matchingScore > mismatchedScore);
}
@Test
public void testSpectralClustering(){
......@@ -155,6 +213,18 @@ public class AutomaticClusteringTest extends AbstractSymtabTest{
}
@Ignore
@Test
public void testSpectralClusteringIsolatedVertex(){
    // Star around vertex 0 (edges 0-1 and 0-2); vertex 3 has no edges at all.
    final int requestedClusters = 2;
    double[][] adjacencyWithIsolatedVertex = {
            {0, 1, 1, 0},
            {1, 0, 0, 0},
            {1, 0, 0, 0},
            {0, 0, 0, 0}};
    // Only construction is exercised; the test makes no assertions on the result.
    SpectralClustering spectral = new SpectralClustering(adjacencyWithIsolatedVertex, requestedClusters);
}
// todo: gotta move this thing later, just temporarily here for testing purposes
public static Dataset[] getClustering(Dataset data, SparseMatrix matrix) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment