Commit 1ad60862 authored by Alexander David Hellwig's avatar Alexander David Hellwig

Added Dynamic Parameters for automatic hyperparameter search

parent e124d95f
Pipeline #104770 failed with stages
in 20 minutes and 1 second
package de.monticore.lang.monticar.generator.middleware.cli;
import alice.tuprolog.Int;
import com.google.gson.*;
import com.google.gson.stream.JsonReader;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.*;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.*;
import de.se_rwth.commons.logging.Log;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
public class CliParametersLoader {
private CliParametersLoader() {
......@@ -16,16 +20,43 @@ public class CliParametersLoader {
public static CliParameters loadCliParameters(String filePath) throws FileNotFoundException {
JsonReader jsonReader = new JsonReader(new FileReader(filePath));
Gson gson = new GsonBuilder().registerTypeAdapter(AlgorithmCliParameters.class, new AlgorithmParametersInterfaceAdapter()).create();
JsonDeserializer<DynamicParameter> deserializer = new JsonDeserializer<DynamicParameter>() {
@Override
public DynamicParameter deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
if(jsonElement.isJsonObject()){
JsonObject jsonObj = (JsonObject) jsonElement;
Double min = jsonObj.has("min") ? jsonObj.get("min").getAsDouble() : null;
Double max = jsonObj.has("max") ? jsonObj.get("max").getAsDouble() : null;
Double step = jsonObj.has("step") ? jsonObj.get("step").getAsDouble() : null;
Integer count = jsonObj.has("count") ? jsonObj.get("count").getAsInt() : null;
return new GeneratorParamter(min,max,step,count);
}else if(jsonElement.isJsonArray()) {
JsonArray jArray = (JsonArray) jsonElement;
List<Double> values = new ArrayList<>();
for(int i = 0; i < jArray.size(); i++){
values.add(jArray.get(i).getAsDouble());
}
return new ListParameter(values);
}else{
return new ListParameter(jsonElement.getAsDouble());
}
}
};
Gson gson = new GsonBuilder()
.registerTypeAdapter(DynamicAlgorithmCliParameters.class, new AlgorithmParametersInterfaceAdapter())
.registerTypeAdapter(DynamicParameter.class, deserializer)
.create();
return gson.fromJson(jsonReader, CliParameters.class);
}
static final class AlgorithmParametersInterfaceAdapter implements JsonSerializer<AlgorithmCliParameters>, JsonDeserializer<AlgorithmCliParameters> {
public JsonElement serialize(AlgorithmCliParameters object, Type interfaceType, JsonSerializationContext context) {
static final class AlgorithmParametersInterfaceAdapter implements JsonSerializer<DynamicAlgorithmCliParameters>, JsonDeserializer<DynamicAlgorithmCliParameters> {
public JsonElement serialize(DynamicAlgorithmCliParameters object, Type interfaceType, JsonSerializationContext context) {
return context.serialize(object);
}
public AlgorithmCliParameters deserialize(JsonElement elem, Type interfaceType, JsonDeserializationContext context) throws JsonParseException {
public DynamicAlgorithmCliParameters deserialize(JsonElement elem, Type interfaceType, JsonDeserializationContext context) throws JsonParseException {
final Type actualType = typeForName(((JsonObject)elem).get("name"));
return context.deserialize(elem, actualType);
}
......@@ -33,13 +64,13 @@ public class CliParametersLoader {
private Type typeForName(final JsonElement typeElem){
String algoName = typeElem.getAsString().toLowerCase();
switch (algoName){
case "spectralclustering": return SpectralClusteringCliParameters.class;
case "dbscan": return DBScanCliParameters.class;
case "affinitypropagation": return AffinityPropagationCliParameters.class;
case "markov": return MarkovCliParameters.class;
case "spectralclustering": return DynamicSpectralClusteringCliParameters.class;
case "dbscan": return DynamicDBScanCliParameters.class;
case "affinitypropagation": return DynamicAffinityPropagationCliParameters.class;
case "markov": return DynamicMarkovCliParameters.class;
default:{
Log.warn("Loaded config of unknown clustering algorithm: " + algoName);
return UnknownAlgorithmCliParameters.class;
return DynamicUnknownAlgorithmCliParameters.class;
}
}
}
......
......@@ -2,17 +2,22 @@ package de.monticore.lang.monticar.generator.middleware.cli;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.SpectralClusteringCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.DynamicAlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.DynamicParameter;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.DynamicSpectralClusteringCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.ListParameter;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
public class ClusteringParameters {
private Integer numberOfClusters;
private Boolean flatten;
private Integer flattenLevel;
private ResultChoosingStrategy chooseBy = ResultChoosingStrategy.bestWithFittingN;
private List<AlgorithmCliParameters> algorithmParameters = new ArrayList<>();
private List<DynamicAlgorithmCliParameters> algorithmParameters = new ArrayList<>();
public ClusteringParameters() {
}
......@@ -25,18 +30,23 @@ public class ClusteringParameters {
return chooseBy;
}
public List<AlgorithmCliParameters> getAlgorithmParameters() {
//Override numberOfClusters for all spectral clustering parameters
public List<DynamicAlgorithmCliParameters> getDynamicAlgorithmCliParameters(){
if(getNumberOfClusters().isPresent()){
Integer n = getNumberOfClusters().get();
DynamicParameter n = new ListParameter(getNumberOfClusters().get().doubleValue());
algorithmParameters.stream()
.filter(a -> a.getName().equals(AlgorithmCliParameters.TYPE_SPECTRAL_CLUSTERING))
.forEach(a -> ((SpectralClusteringCliParameters)a).setNumberOfClusters(n));
.forEach(a -> ((DynamicSpectralClusteringCliParameters)a).setNumberOfClusters(n));
}
return algorithmParameters;
}
public List<AlgorithmCliParameters> getAlgorithmParameters() {
return getDynamicAlgorithmCliParameters()
.stream()
.flatMap(d -> d.getAll().stream())
.collect(Collectors.toList());
}
public boolean getFlatten(){
return flatten == null ? false : flatten;
}
......
......@@ -17,6 +17,11 @@ public class DBScanCliParameters extends AlgorithmCliParameters {
public DBScanCliParameters() {
}
public DBScanCliParameters(Integer min_pts, Double radius) {
this.min_pts = min_pts;
this.radius = radius;
}
@Override
public String getName() {
return TYPE_DBSCAN;
......
......@@ -17,6 +17,13 @@ public class MarkovCliParameters extends AlgorithmCliParameters {
public MarkovCliParameters() {
}
public MarkovCliParameters(Double max_residual, Double gamma_exp, Double loop_gain, Double zero_max) {
this.max_residual = max_residual;
this.gamma_exp = gamma_exp;
this.loop_gain = loop_gain;
this.zero_max = zero_max;
}
@Override
public String getName() {
return TYPE_MARKOV;
......
......@@ -16,6 +16,16 @@ public class SpectralClusteringCliParameters extends AlgorithmCliParameters {
public SpectralClusteringCliParameters() {
}
public SpectralClusteringCliParameters(int numberOfClusters) {
this.numberOfClusters = numberOfClusters;
}
public SpectralClusteringCliParameters(int numberOfClusters, int l, double sigma) {
this.numberOfClusters = numberOfClusters;
this.l = l;
this.sigma = sigma;
}
@Override
public String getName() {
return TYPE_SPECTRAL_CLUSTERING;
......
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AffinityPropagationCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import java.util.Arrays;
import java.util.List;
public class DynamicAffinityPropagationCliParameters extends DynamicAlgorithmCliParameters {
@Override
public List<AlgorithmCliParameters> getAll() {
return Arrays.asList(new AffinityPropagationCliParameters());
}
@Override
public boolean isValid() {
return true;
}
@Override
public String getName() {
return AlgorithmCliParameters.TYPE_AFFINITY_PROPAGATION;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import java.util.Arrays;
import java.util.List;
public abstract class DynamicAlgorithmCliParameters {
public abstract List<AlgorithmCliParameters> getAll();
public abstract boolean isValid();
public abstract String getName();
protected List<Double> getValuesOrSingleElement(DynamicParameter dynamicParameter, Double elem){
if(dynamicParameter == null){
return Arrays.asList(elem);
}else{
return dynamicParameter.getAll();
}
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.DBScanCliParameters;
import java.util.ArrayList;
import java.util.List;
public class DynamicDBScanCliParameters extends DynamicAlgorithmCliParameters {
private DynamicParameter min_pts;
private DynamicParameter radius;
@Override
public List<AlgorithmCliParameters> getAll() {
List<AlgorithmCliParameters> res = new ArrayList<>();
for(Integer pts : min_pts.getAllAsInt()){
for(Double rad: radius.getAll()){
res.add(new DBScanCliParameters(pts, rad));
}
}
return res;
}
@Override
public boolean isValid() {
return min_pts != null && min_pts.isValid() && radius != null && radius.isValid();
}
@Override
public String getName() {
return AlgorithmCliParameters.TYPE_DBSCAN;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.MarkovCliParameters;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class DynamicMarkovCliParameters extends DynamicAlgorithmCliParameters {
private DynamicParameter max_residual;
private DynamicParameter gamma_exp;
private DynamicParameter loop_gain;
private DynamicParameter zero_max;
@Override
public List<AlgorithmCliParameters> getAll() {
List<AlgorithmCliParameters> res = new ArrayList<>();
for(Double mr : getValuesOrSingleElement(max_residual, null)){
for(Double ge : getValuesOrSingleElement(gamma_exp, null)){
for(Double lg : getValuesOrSingleElement(loop_gain, null)){
for(Double zm : getValuesOrSingleElement(zero_max, null)){
res.add(new MarkovCliParameters(mr, ge, lg, zm));
}
}
}
}
return res;
}
@Override
public String getName() {
return AlgorithmCliParameters.TYPE_MARKOV;
}
@Override
public boolean isValid() {
return true;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import java.util.ArrayList;
import java.util.List;
public abstract class DynamicParameter{
public List<Integer> getAllAsInt(){
List<Double> asDouble = getAll();
List<Integer> res = new ArrayList<>();
asDouble.forEach(d -> res.add(d.intValue()));
return res;
}
public abstract List<Double> getAll();
public abstract boolean isValid();
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.SpectralClusteringCliParameters;
import java.util.ArrayList;
import java.util.List;
public class DynamicSpectralClusteringCliParameters extends DynamicAlgorithmCliParameters {
DynamicParameter numberOfClusters;
DynamicParameter l;
DynamicParameter sigma;
public DynamicParameter getNumberOfClusters() {
return numberOfClusters;
}
public void setNumberOfClusters(DynamicParameter numberOfClusters) {
this.numberOfClusters = numberOfClusters;
}
public DynamicParameter getL() {
return l;
}
public void setL(DynamicParameter l) {
this.l = l;
}
public DynamicParameter getSigma() {
return sigma;
}
public void setSigma(DynamicParameter sigma) {
this.sigma = sigma;
}
@Override
public List<AlgorithmCliParameters> getAll(){
if(!isValid()){
return new ArrayList<>();
}
List<AlgorithmCliParameters> res = new ArrayList<>();
if(l == null || sigma == null){
for (Integer n : numberOfClusters.getAllAsInt()) {
res.add(new SpectralClusteringCliParameters(n));
}
}else{
for(Integer a : numberOfClusters.getAllAsInt()){
for(Integer b : l.getAllAsInt()){
for(Double c : sigma.getAll()){
res.add(new SpectralClusteringCliParameters(a, b, c));
}
}
}
}
return res;
}
@Override
public boolean isValid(){
if(numberOfClusters == null || !numberOfClusters.isValid()){
return false;
}
if(l == null && sigma == null){
return true;
}
if(l != null && l.isValid() && sigma != null && sigma.isValid()){
return true;
}
return false;
}
@Override
public String getName() {
return AlgorithmCliParameters.TYPE_SPECTRAL_CLUSTERING;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.AlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.UnknownAlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.clustering.ClusteringAlgorithm;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class DynamicUnknownAlgorithmCliParameters extends DynamicAlgorithmCliParameters {
public DynamicUnknownAlgorithmCliParameters() {
}
@Override
public List<AlgorithmCliParameters> getAll() {
return Arrays.asList(new UnknownAlgorithmCliParameters());
}
@Override
public boolean isValid() {
return false;
}
@Override
public String getName() {
return AlgorithmCliParameters.TYPE_UNKOWN;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GeneratorParamter extends DynamicParameter {
public Double min;
public Double max;
public Double step;
public Integer count;
public GeneratorParamter() {
}
public GeneratorParamter(Double min) {
this.min = min;
}
public GeneratorParamter(Double min, Double max, Double step) {
this.min = min;
this.max = max;
this.step = step;
}
public GeneratorParamter(Double min, Double max, Integer count) {
this.min = min;
this.max = max;
this.count = count;
}
public GeneratorParamter(Double min, Double max, Double step, Integer count) {
this.min = min;
this.max = max;
this.step = step;
this.count = count;
}
public List<Integer> getAllAsInt(){
List<Double> asDouble = getAll();
List<Integer> res = new ArrayList<>();
asDouble.forEach(d -> res.add(d.intValue()));
return res;
}
public List<Double> getAll(){
if(!isValid()){
return new ArrayList<>();
}
if(isSingleValue()){
return Arrays.asList(min);
}
// min + (count - 1)*(mix - min)/(count-1)
if(count != null && count > 1){
List<Double> res = new ArrayList<>();
double d = (max - min) / (count - 1);
for(int i = 0; i < count; i++){
res.add(min + i*d);
}
return res;
}
if(step == null){
step = 1.0;
}
List<Double> res = new ArrayList<>();
Double cur = min;
while(max - cur > -0.000001){
res.add(cur);
cur += step;
}
return res;
}
public boolean isValid(){
// only one value
if(isSingleValue()){
return true;
}else
//multiple values
if(min != null && max != null){
if(max < min){
return false;
}else{
return (step != null && step > 0) || (step == null) || (count != null && count > 1);
}
}else{
return false;
}
}
private boolean isSingleValue() {
return min != null && max == null && step == null && count == null;
}
}
package de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic;
import java.util.Arrays;
import java.util.List;
public class ListParameter extends DynamicParameter {
private List<Double> values;
public ListParameter() {
}
public ListParameter(List<Double> values) {
this.values = values;
}
public ListParameter(Double value) {
this.values = Arrays.asList(value);
}
@Override
public List<Double> getAll() {
return values;
}
@Override
public boolean isValid() {
return !values.isEmpty();
}
}
......@@ -5,6 +5,8 @@ import de.monticore.lang.monticar.generator.middleware.cli.CliParametersLoader;
import de.monticore.lang.monticar.generator.middleware.cli.ClusteringParameters;
import de.monticore.lang.monticar.generator.middleware.cli.ResultChoosingStrategy;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.*;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.DynamicAlgorithmCliParameters;
import de.monticore.lang.monticar.generator.middleware.cli.algorithms.dynamic.DynamicSpectralClusteringCliParameters;
import org.junit.Test;
import java.io.FileNotFoundException;
......@@ -117,7 +119,58 @@ public class ParameterLoadingTest {
}
@Test
public void testDynamicClusteringArgs() throws FileNotFoundException {
CliParameters params = loadCliParameters("clusterDynamic");
ClusteringParameters clusteringParameters = params.getClusteringParameters().get();
List<DynamicAlgorithmCliParameters> dynm = clusteringParameters.getDynamicAlgorithmCliParameters();
assertEquals(2, dynm.size());
assertTrue(dynm.get(0) instanceof DynamicSpectralClusteringCliParameters);
DynamicSpectralClusteringCliParameters dynmSpectral = (DynamicSpectralClusteringCliParameters) dynm.get(0);
assertTrue(dynmSpectral.isValid());
assertEquals(8, dynmSpectral.getNumberOfClusters().getAllAsInt().size());
assertEquals(10, dynmSpectral.getL().getAllAsInt().size());
assertEquals(11, dynmSpectral.getSigma().getAllAsInt().size());
List<AlgorithmCliParameters> spectrals = dynmSpectral.getAll();
assertEquals(8 * 11 * 10 ,spectrals.size());
for(AlgorithmCliParameters s : spectrals){
System.out.println(s);
}
List<AlgorithmCliParameters> compatible = dynm.get(1).getAll();
assertEquals(1, compatible.size());
}
@Test
public void testListParameterClusteringArgs() throws FileNotFoundException {
CliParameters params = loadCliParameters("clusterDynamicList");
ClusteringParameters clusteringParameters = params.getClusteringParameters().get();
List<DynamicAlgorithmCliParameters> dynm = clusteringParameters.getDynamicAlgorithmCliParameters();
assertEquals(1, dynm.size());
assertTrue(dynm.get(0) instanceof DynamicSpectralClusteringCliParameters);
DynamicSpectralClusteringCliParameters dynmSpectralList = (DynamicSpectralClusteringCliParameters) dynm.get(0);
assertTrue(dynmSpectralList.isValid());
assertEquals(3, dynmSpectralList.getNumberOfClusters().getAllAsInt().size());
assertEquals(4, dynmSpectralList.getL().getAllAsInt().size());
assertEquals(1, dynmSpectralList.getSigma().getAllAsInt().size());
List<AlgorithmCliParameters> spectrals = dynmSpectralList.getAll();
assertEquals(3 * 4 * 1 ,spectrals.size());
for(AlgorithmCliParameters s : spectrals){
System.out.println(