Commit 80c0e869 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'ba-baumann' into 'master'

GNN and DGL support, merge after CNNArchLang

See merge request !28
parents 643960d4 79649129
Pipeline #639851 passed with stage
in 1 minute and 54 seconds
......@@ -314,7 +314,11 @@ public class ArchitectureElementData {
public int getValuesDim(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.VALUES_DIM_NAME).get();
}
public int getNodes(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.NODES_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
......
......@@ -106,6 +106,7 @@ public abstract class CNNArchTemplateController {
public boolean containsAdaNet(){
return this.architecture.containsAdaNet();
}
public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer);
}
......
......@@ -6,6 +6,7 @@ import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.generator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.generator.annotations.Range;
import de.monticore.lang.monticar.cnnarch.generator.training.RlAlgorithm;
import de.monticore.lang.monticar.cnnarch.generator.training.NetworkType;
import de.monticore.lang.monticar.cnnarch.generator.training.TrainingComponentsContainer;
import de.monticore.lang.monticar.cnnarch.generator.training.TrainingConfiguration;
......@@ -98,6 +99,21 @@ public abstract class ConfigurationData {
return normalizeOpt.orElse(null);
}
public Boolean getMultiGraph() {
Optional<Boolean> multiGraphOpt = trainingConfiguration.getMultiGraph();
return multiGraphOpt.orElse(null);
}
public List<Integer> getTrainMask() {
Optional<List<Integer>> trainMaskOpt = trainingConfiguration.getTrainMask();
return trainMaskOpt.orElse(null);
}
public List<Integer> getTestMask() {
Optional<List<Integer>> testMaskOpt = trainingConfiguration.getTestMask();
return testMaskOpt.orElse(null);
}
public Boolean getShuffleData() {
Optional<Boolean> shuffleDataOpt = trainingConfiguration.getShuffleData();
return shuffleDataOpt.orElse(null);
......@@ -379,6 +395,16 @@ public abstract class ConfigurationData {
return DQN;
}
public String getNetworkType() {
Optional<NetworkType> networkTypeOpt = trainingConfiguration.getNetworkType();
NetworkType networkType = networkTypeOpt.get();
if (networkType.equals(NetworkType.GNN)) {
return GNN;
}
return null;
}
// protected Object getDefaultValueOrElse(String parameterKey, Object elseValue) {
// if (schema == null) {
// return elseValue;
......
package de.monticore.lang.monticar.cnnarch.generator.training;
public enum NetworkType {
GNN("gnn");
String type;
NetworkType(String type) {
this.type = type;
}
public static NetworkType networkType(String type) {
for (NetworkType nt : values()) {
if (nt.type.equals(type)) {
return nt;
}
}
throw new IllegalArgumentException(String.valueOf(type));
}
public String getType() {
return type;
}
}
\ No newline at end of file
......@@ -107,6 +107,15 @@ public class TrainingConfiguration {
return Optional.of(RlAlgorithm.rlAlgorithm(rlAlgorithm));
}
public Optional<NetworkType> getNetworkType() {
Optional<ConfigurationEntry> networkTypeOpt = configurationSymbol.getConfigurationEntry(NETWORK_TYPE);
if (!networkTypeOpt.isPresent()) {
return Optional.empty();
}
String networkType = (String) networkTypeOpt.get().getValue();
return Optional.of(NetworkType.networkType(networkType));
}
public Optional<Integer> getBatchSize() {
return getParameterValue(BATCH_SIZE);
}
......@@ -143,6 +152,18 @@ public class TrainingConfiguration {
return getParameterValue(SHUFFLE_DATA);
}
public Optional<Boolean> getMultiGraph() {
return getParameterValue(MULTI_GRAPH);
}
public Optional<List<Integer>> getTrainMask() {
return getParameterValue(TRAIN_MASK);
}
public Optional<List<Integer>> getTestMask() {
return getParameterValue(TEST_MASK);
}
public Optional<Double> getClipGlobalGradNorm() {
return getParameterValue(CLIP_GLOBAL_GRAD_NORM);
}
......
......@@ -42,6 +42,7 @@ public class TrainingParameterConstants {
public static final String NORMALIZE = "normalize";
public static final String CONTEXT = "context";
public static final String SHUFFLE_DATA = "shuffle_data";
public static final String CLIP_GLOBAL_GRAD_NORM = "clip_global_grad_norm";
public static final String USE_TEACHER_FORCING = "use_teacher_forcing";
public static final String SAVE_ATTENTION_IMAGE = "save_attention_image";
......@@ -57,6 +58,12 @@ public class TrainingParameterConstants {
public static final String DDPG = "ddpg";
public static final String TD3 = "td3";
public static final String MULTI_GRAPH = "multi_graph";
public static final String TRAIN_MASK = "train_mask";
public static final String TEST_MASK = "test_mask";
public static final String GNN = "gnn";
public static final String NETWORK_TYPE = "network_type";
public static final String LEARNING_METHOD = "learning_method";
public static final String EVAL_METRIC = "eval_metric";
public static final String NUM_EPISODES = "num_episodes";
......
/* (c) https://github.com/MontiCore/monticore */
schema GNN extends Supervised {
train_mask: Z*
test_mask: Z*
multi_graph: B
}
......@@ -4,6 +4,10 @@ import Loss;
schema Supervised extends General {
network_type: schema {
gnn;
}
batch_size: N1
num_epoch: N1
normalize: B
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment