Commit a57778c1 authored by Dmytro Semenchenko's avatar Dmytro Semenchenko
Browse files

Merge branch 'master' into onnx-dmytro

parents 69e2a60a 65a4d7a1
Pipeline #649193 passed with stage
in 1 minute and 24 seconds
......@@ -245,3 +245,4 @@
</snapshotRepository>
</distributionManagement>
</project>
......@@ -90,6 +90,10 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
/** @return the value of this layer's 'pdf' (probability density function) parameter. */
public String getPdf() {
    // NOTE(review): assumes the parameter is always present; Optional.get() throws otherwise.
    return getLayerSymbol().getStringValue(AllPredefinedLayers.PDF_NAME).get();
}

/** @return the value of this layer's 'num_embeddings' parameter. */
public int getNumEmbeddings() {
    return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_EMBEDDINGS_NAME).get();
}

/** @return the value of this layer's 'groups' parameter. */
public int getGroups() {
    return getLayerSymbol().getIntValue(AllPredefinedLayers.GROUPS_NAME).get();
}
......@@ -314,7 +318,11 @@ public class ArchitectureElementData {
/** @return the value of this layer's 'values_dim' parameter. */
// NOTE(review): Optional.get() without a presence check — assumes the schema guarantees the value.
public int getValuesDim(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.VALUES_DIM_NAME).get();
}
/** @return the value of this layer's 'nodes' parameter. */
public int getNodes(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.NODES_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
......
......@@ -106,6 +106,7 @@ public abstract class CNNArchTemplateController {
/** @return whether the underlying architecture contains an AdaNet element (delegates to the architecture symbol). */
public boolean containsAdaNet(){
return this.architecture.containsAdaNet();
}
/** @return the generated name for the given architecture element, as assigned by the name manager. */
public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer);
}
......
......@@ -6,6 +6,7 @@ import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.generator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.generator.annotations.Range;
import de.monticore.lang.monticar.cnnarch.generator.training.RlAlgorithm;
import de.monticore.lang.monticar.cnnarch.generator.training.NetworkType;
import de.monticore.lang.monticar.cnnarch.generator.training.TrainingComponentsContainer;
import de.monticore.lang.monticar.cnnarch.generator.training.TrainingConfiguration;
......@@ -39,6 +40,10 @@ public abstract class ConfigurationData {
return trainingConfiguration.isGanLearning();
}
/** @return whether the training configuration selects the VAE learning method. */
public Boolean isVaeLearning() {
return trainingConfiguration.isVaeLearning();
}
/** @return whether the training configuration selects reinforcement learning. */
public Boolean isReinforcementLearning() {
return trainingConfiguration.isReinforcementLearning();
}
......@@ -73,6 +78,16 @@ public abstract class ConfigurationData {
return loadPretrainedOpt.orElse(null);
}
/** @return the configured KL-divergence loss weight, or {@code null} when unset. */
public Double getKlLossWeight() {
    return trainingConfiguration.getKlLossWeight().orElse(null);
}

/** @return the configured reconstruction-loss name, or {@code null} when unset. */
public String getReconLossName() {
    return trainingConfiguration.getReconLossName().orElse(null);
}
// COMPARE WITH CNNTRAIN IMPLEMENTATION IN GluonConfigurationData
public Boolean getPreprocessor() {
Optional<String> preprocessorOpt = trainingConfiguration.getPreprocessor();
......@@ -89,6 +104,21 @@ public abstract class ConfigurationData {
return onnxExport.orElse(null);
}
/** @return the 'multi_graph' flag, or {@code null} when unset. */
public Boolean getMultiGraph() {
    return trainingConfiguration.getMultiGraph().orElse(null);
}

/** @return the training-mask node indices, or {@code null} when unset. */
public List<Integer> getTrainMask() {
    return trainingConfiguration.getTrainMask().orElse(null);
}

/** @return the test-mask node indices, or {@code null} when unset. */
public List<Integer> getTestMask() {
    return trainingConfiguration.getTestMask().orElse(null);
}
public Boolean getShuffleData() {
Optional<Boolean> shuffleDataOpt = trainingConfiguration.getShuffleData();
return shuffleDataOpt.orElse(null);
......@@ -352,7 +382,11 @@ public abstract class ConfigurationData {
// public Map<String, Map<String, Object>> getConstraintLosses() { // TODO
// return getMultiParamMapEntry(CONSTRAINT_LOSS, "name");
// }
/**
 * Parameter 'self_play', added for cooperative driving.
 *
 * @return the configured self-play mode, or {@code null} when unset
 */
public String getSelfPlay() {
    return trainingConfiguration.getSelfPlay().orElse(null);
}
public String getRlAlgorithm() {
Optional<RlAlgorithm> rlAlgorithmOpt = trainingConfiguration.getRlAlgorithm();
if (!rlAlgorithmOpt.isPresent()) {
......@@ -370,6 +404,16 @@ public abstract class ConfigurationData {
return DQN;
}
/**
 * Resolves the configured network type to its template constant
 * ({@code GNN} for {@link NetworkType#GNN}).
 *
 * @return the template constant for the network type, or {@code null} when no
 *     network type is configured or it has no template counterpart
 */
public String getNetworkType() {
    Optional<NetworkType> networkTypeOpt = trainingConfiguration.getNetworkType();
    // 'network_type' is an optional parameter: calling get() unconditionally (as before)
    // threw NoSuchElementException for every configuration that omits it. Mirror the
    // presence check used by getRlAlgorithm() and return null like the other getters.
    if (!networkTypeOpt.isPresent()) {
        return null;
    }
    if (NetworkType.GNN.equals(networkTypeOpt.get())) {
        return GNN;
    }
    return null;
}
// protected Object getDefaultValueOrElse(String parameterKey, Object elseValue) {
// if (schema == null) {
// return elseValue;
......@@ -631,4 +675,4 @@ public abstract class ConfigurationData {
}
return object.toString();
}
}
\ No newline at end of file
}
......@@ -4,7 +4,8 @@ public enum LearningMethod {
SUPERVISED("supervised"),
REINFORCEMENT("reinforcement"),
GAN("gan");
GAN("gan"),
VAE("vae");
String method;
......
package de.monticore.lang.monticar.cnnarch.generator.training;
/** Supported neural-network families selectable via the 'network_type' training parameter. */
public enum NetworkType {

    GNN("gnn");

    /** Textual form used in training configurations. */
    String type;

    NetworkType(String type) {
        this.type = type;
    }

    /**
     * Resolves the constant whose textual form equals {@code type}.
     *
     * @param type textual network-type name, e.g. "gnn"
     * @return the matching constant
     * @throws IllegalArgumentException if no constant matches {@code type}
     */
    public static NetworkType networkType(String type) {
        NetworkType[] candidates = values();
        for (int i = 0; i < candidates.length; i++) {
            if (candidates[i].type.equals(type)) {
                return candidates[i];
            }
        }
        throw new IllegalArgumentException(String.valueOf(type));
    }

    /** @return the textual form of this network type. */
    public String getType() {
        return type;
    }
}
\ No newline at end of file
......@@ -20,6 +20,8 @@ public class TrainingComponentsContainer {
private ArchitectureAdapter actorNetwork;
private ArchitectureAdapter criticNetwork;
private ArchitectureAdapter generatorNetwork;
private ArchitectureAdapter encoderNetwork;
private ArchitectureAdapter decoderNetwork;
private ArchitectureAdapter discriminatorNetwork;
private ArchitectureAdapter qNetwork;
private EMAComponentInstanceSymbol rewardFunction;
......@@ -43,11 +45,18 @@ public class TrainingComponentsContainer {
return Optional.ofNullable(generatorNetwork);
}
/** @return the GAN discriminator network adapter, if one was registered. */
public Optional<ArchitectureAdapter> getDiscriminatorNetwork() {
return Optional.ofNullable(discriminatorNetwork);
}
/** @return the VAE decoder network adapter, if one was registered. */
public Optional<ArchitectureAdapter> getDecoderNetwork() {
return Optional.ofNullable(decoderNetwork);
}
/** @return the VAE encoder network adapter, if one was registered. */
public Optional<ArchitectureAdapter> getEncoderNetwork() {
return Optional.ofNullable(encoderNetwork);
}
/** @return the Q-network adapter, if one was registered. */
public Optional<ArchitectureAdapter> getQNetwork() {
return Optional.ofNullable(qNetwork);
}
......@@ -80,6 +89,8 @@ public class TrainingComponentsContainer {
}
} else if (trainingConfiguration.isGanLearning()) {
setGeneratorNetwork(trainedArchitecture);
} else if (trainingConfiguration.isVaeLearning()) {
setDecoderNetwork(trainedArchitecture);
}
}
......@@ -108,6 +119,16 @@ public class TrainingComponentsContainer {
addTrainingComponent(QNETWORK, qNetwork);
}
/** Registers the VAE decoder network and records it under the DECODER component key. */
public void setDecoderNetwork(ArchitectureAdapter decoderNetwork) {
this.decoderNetwork = decoderNetwork;
addTrainingComponent(DECODER, decoderNetwork);
}
/** Registers the VAE encoder network and records it under the ENCODER component key. */
public void setEncoderNetwork(ArchitectureAdapter encoderNetwork) {
this.encoderNetwork = encoderNetwork;
addTrainingComponent(ENCODER, encoderNetwork);
}
public void setRewardFunction(EMAComponentInstanceSymbol rewardFunction) {
this.rewardFunction = rewardFunction;
addTrainingComponent(REWARD_FUNCTION, rewardFunction.getComponentType());
......
......@@ -7,10 +7,7 @@ import conflang._symboltable.ConfigurationEntrySymbol;
import conflang._symboltable.NestedConfigurationEntrySymbol;
import schemalang._symboltable.SchemaDefinitionSymbol;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
import static de.monticore.lang.monticar.cnnarch.generator.training.TrainingParameterConstants.*;
......@@ -50,6 +47,10 @@ public class TrainingConfiguration {
return getParameterValue(CONTEXT);
}
/** @return the 'self_play' parameter value, if configured. */
public Optional<String> getSelfPlay() {
return getParameterValue(SELF_PLAY);
}
public Optional<LearningMethod> getLearningMethod() {
Optional<ConfigurationEntry> learningMethodOpt =
configurationSymbol.getConfigurationEntry(LEARNING_METHOD);
......@@ -83,6 +84,15 @@ public class TrainingConfiguration {
return LearningMethod.GAN.equals(learningMethod);
}
/**
 * @return true iff the configured learning method is {@link LearningMethod#VAE};
 *     also false when no learning method is configured at all (the original
 *     author noted that defaulting to false in that case is debatable)
 */
public boolean isVaeLearning() {
    return getLearningMethod().filter(LearningMethod.VAE::equals).isPresent();
}
public boolean isReinforcementLearning() {
Optional<LearningMethod> learningMethodOpt = getLearningMethod();
if (!learningMethodOpt.isPresent()) {
......@@ -101,6 +111,15 @@ public class TrainingConfiguration {
return Optional.of(RlAlgorithm.rlAlgorithm(rlAlgorithm));
}
/**
 * Looks up the optional 'network_type' configuration entry and maps its string
 * value onto the {@link NetworkType} enum.
 *
 * @return the configured network type, or empty when the entry is absent
 * @throws IllegalArgumentException if the configured string names no known network type
 */
public Optional<NetworkType> getNetworkType() {
    return configurationSymbol.getConfigurationEntry(NETWORK_TYPE)
            .map(entry -> NetworkType.networkType((String) entry.getValue()));
}

/** @return the configured batch size, if set. */
public Optional<Integer> getBatchSize() {
    return getParameterValue(BATCH_SIZE);
}
......@@ -141,6 +160,18 @@ public class TrainingConfiguration {
return getParameterValue(SHUFFLE_DATA);
}
/** @return the 'multi_graph' GNN flag, if configured. */
public Optional<Boolean> getMultiGraph() {
return getParameterValue(MULTI_GRAPH);
}
/** @return the 'train_mask' node indices for GNN training, if configured. */
public Optional<List<Integer>> getTrainMask() {
return getParameterValue(TRAIN_MASK);
}
/** @return the 'test_mask' node indices for GNN evaluation, if configured. */
public Optional<List<Integer>> getTestMask() {
return getParameterValue(TEST_MASK);
}
/** @return the global gradient-norm clipping threshold, if configured. */
public Optional<Double> getClipGlobalGradNorm() {
return getParameterValue(CLIP_GLOBAL_GRAD_NORM);
}
......@@ -353,6 +384,20 @@ public class TrainingConfiguration {
return getObjectParameterParameters(DISCRIMINATOR_OPTIMIZER);
}
/** @return whether an 'encoder' component is declared in the configuration. */
public Boolean hasEncoderName() {
    return hasParameter(ENCODER);
}

/** @return the declared encoder component name, if present. */
public Optional<String> getEncoderName() {
    return getObjectParameterValue(ENCODER);
}

/** @return the configured reconstruction-loss name, if present. */
public Optional<String> getReconLossName() {
    return getObjectParameterValue(RECON_LOSS);
}

/** @return the configured KL-divergence loss weight, if present. */
public Optional<Double> getKlLossWeight() {
    return getParameterValue(KL_LOSS_WEIGHT);
}

/** @return whether a 'strategy' parameter is declared in the configuration. */
public boolean hasStrategy() {
    return hasParameter(STRATEGY);
}
......@@ -513,4 +558,5 @@ public class TrainingConfiguration {
}
return keyValues;
}
}
\ No newline at end of file
......@@ -15,6 +15,7 @@ public class TrainingParameterConstants {
public static final String SUPERVISED = "supervised";
public static final String REINFORCEMENT = "reinforcement";
public static final String GAN = "gan";
public static final String VAE = "vae";
/*
* Optimizers
......@@ -41,6 +42,7 @@ public class TrainingParameterConstants {
public static final String NORMALIZE = "normalize";
public static final String CONTEXT = "context";
public static final String SHUFFLE_DATA = "shuffle_data";
public static final String CLIP_GLOBAL_GRAD_NORM = "clip_global_grad_norm";
public static final String USE_TEACHER_FORCING = "use_teacher_forcing";
public static final String SAVE_ATTENTION_IMAGE = "save_attention_image";
......@@ -56,6 +58,13 @@ public class TrainingParameterConstants {
public static final String DQN = "dqn";
public static final String DDPG = "ddpg";
public static final String TD3 = "td3";
public static final String SELF_PLAY = "self_play";
public static final String MULTI_GRAPH = "multi_graph";
public static final String TRAIN_MASK = "train_mask";
public static final String TEST_MASK = "test_mask";
public static final String GNN = "gnn";
public static final String NETWORK_TYPE = "network_type";
public static final String LEARNING_METHOD = "learning_method";
public static final String EVAL_METRIC = "eval_metric";
......@@ -111,4 +120,9 @@ public class TrainingParameterConstants {
public static final String GENERATOR_LOSS_WEIGHT = "generator_loss_weight";
public static final String DISCRIMINATOR_LOSS_WEIGHT = "discriminator_loss_weight";
public static final String PRINT_IMAGES = "print_images";
public static final String ENCODER = "encoder";
public static final String DECODER = "decoder";
public static final String KL_LOSS_WEIGHT = "kl_loss_weight";
public static final String RECON_LOSS = "reconstruction_loss";
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
/*
 * Training schema for graph neural networks. Extends the supervised schema with
 * node-mask parameters (integer index lists) and a multi-graph flag — presumably
 * masks select train/test nodes within a graph; TODO confirm against the generator.
 */
schema GNN extends Supervised {
train_mask: Z*
test_mask: Z*
multi_graph: B
}
......@@ -4,7 +4,7 @@ import Optimizer;
schema General {
learning_method = supervised: schema {
supervised, reinforcement, gan;
supervised, reinforcement, gan, vae;
}
context: enum {
......
......@@ -8,6 +8,10 @@ schema Reinforcement extends General {
dqn, ddpg, td3;
}
self_play: enum {
no, yes;
}
agent_name: string
num_episodes = 50: N1
num_max_steps = 99999: N
......@@ -21,4 +25,4 @@ schema Reinforcement extends General {
actor_optimizer: optimizer_type
environment: environment_type!
replay_memory = buffer: replay_memory_type
}
\ No newline at end of file
}
......@@ -4,6 +4,10 @@ import Loss;
schema Supervised extends General {
network_type: schema {
gnn;
}
batch_size: N1
num_epoch: N
normalize: B
......
/* (c) https://github.com/MontiCore/monticore */
import Optimizer;
/*
 * Training schema for (conditional) variational autoencoders. Ties configurations
 * to the VAE/CVAE reference models and adds the VAE-specific loss parameters:
 * a reconstruction loss (bce or mse, defaulting to mse) and a KL-divergence weight.
 */
schema VAE extends General {
reference-model: referencemodels.vae.VAE, referencemodels.vae.CVAE
batch_size: N1
num_epoch: N1
normalize: B
checkpoint_period = 5: N
load_checkpoint: B
load_pretrained: B
log_period: N
reconstruction_loss = mse: reconLoss_type
print_images = false: B
kl_loss_weight: Q
reconLoss_type {
values:
bce,
mse;
}
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
package referencemodels.vae;
/*
 * Reference model of a conditional VAE: an encoder maps data (plus a condition
 * label) to a latent encoding, and a decoder reconstructs data from that encoding
 * (plus the label). Only the latent encoding is connected here; the data and
 * label ports are presumably wired by the surrounding tooling — TODO confirm.
 */
component CVAE {
component Encoder {
ports
in X data,
in W^{1} label,
out D encoding;
}
component Decoder {
ports
in D encoding,
in W^{1} label,
out X data;
}
instance Encoder encoder;
instance Decoder decoder;
connect encoder.encoding -> decoder.encoding;
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
package referencemodels.vae;
/*
 * Reference model of a plain VAE: an encoder maps data to a latent encoding and a
 * decoder reconstructs data from it. Only the latent encoding is connected here;
 * the outer data ports are presumably wired by the surrounding tooling — TODO confirm.
 */
component VAE {
component Encoder{
ports
in X data,
out D encoding;
}
component Decoder{
ports
in D encoding,
out X data;
}
instance Encoder encoder;
instance Decoder decoder;
connect encoder.encoding -> decoder.encoding;
}
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment