Commit 643960d4 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'ba_celik' into 'master'

Merge branch 'ba_celik' into 'master'

See merge request !29
parents 54775014 2f94ddc2
Pipeline #638465 passed with stage
in 1 minute and 6 seconds
......@@ -12,7 +12,7 @@
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.4.8-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.4.9-SNAPSHOT</CNNArch.version>
<conflang.version>1.0.0-SNAPSHOT</conflang.version>
<schemalang.version>1.0.0-SNAPSHOT</schemalang.version>
<embedded-montiarc-math-generator>0.4.9</embedded-montiarc-math-generator>
......@@ -245,3 +245,4 @@
</snapshotRepository>
</distributionManagement>
</project>
......@@ -90,6 +90,10 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public String getPdf() { return getLayerSymbol().getStringValue(AllPredefinedLayers.PDF_NAME).get(); }
public int getNumEmbeddings() { return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_EMBEDDINGS_NAME).get();
}
public int getGroups(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.GROUPS_NAME).get();
}
......
......@@ -39,6 +39,10 @@ public abstract class ConfigurationData {
return trainingConfiguration.isGanLearning();
}
public Boolean isVaeLearning() {
return trainingConfiguration.isVaeLearning();
}
public Boolean isReinforcementLearning() {
return trainingConfiguration.isReinforcementLearning();
}
......@@ -73,6 +77,16 @@ public abstract class ConfigurationData {
return loadPretrainedOpt.orElse(null);
}
public Double getKlLossWeight() {
Optional<Double> klLossWeightOpt = trainingConfiguration.getKlLossWeight();
return klLossWeightOpt.orElse(null);
}
public String getReconLossName() {
Optional<String> reconLossNameOpt = trainingConfiguration.getReconLossName();
return reconLossNameOpt.orElse(null);
}
// COMPARE WITH CNNTRAIN IMPLEMENTATION IN GluonConfigurationData
public Boolean getPreprocessor() {
Optional<String> preprocessorOpt = trainingConfiguration.getPreprocessor();
......
......@@ -4,7 +4,8 @@ public enum LearningMethod {
SUPERVISED("supervised"),
REINFORCEMENT("reinforcement"),
GAN("gan");
GAN("gan"),
VAE("vae");
String method;
......
......@@ -20,6 +20,8 @@ public class TrainingComponentsContainer {
private ArchitectureAdapter actorNetwork;
private ArchitectureAdapter criticNetwork;
private ArchitectureAdapter generatorNetwork;
private ArchitectureAdapter encoderNetwork;
private ArchitectureAdapter decoderNetwork;
private ArchitectureAdapter discriminatorNetwork;
private ArchitectureAdapter qNetwork;
private EMAComponentInstanceSymbol rewardFunction;
......@@ -43,11 +45,18 @@ public class TrainingComponentsContainer {
return Optional.ofNullable(generatorNetwork);
}
public Optional<ArchitectureAdapter> getDiscriminatorNetwork() {
return Optional.ofNullable(discriminatorNetwork);
}
public Optional<ArchitectureAdapter> getDecoderNetwork() {
return Optional.ofNullable(decoderNetwork);
}
public Optional<ArchitectureAdapter> getEncoderNetwork() {
return Optional.ofNullable(encoderNetwork);
}
public Optional<ArchitectureAdapter> getQNetwork() {
return Optional.ofNullable(qNetwork);
}
......@@ -80,6 +89,8 @@ public class TrainingComponentsContainer {
}
} else if (trainingConfiguration.isGanLearning()) {
setGeneratorNetwork(trainedArchitecture);
} else if (trainingConfiguration.isVaeLearning()) {
setDecoderNetwork(trainedArchitecture);
}
}
......@@ -108,6 +119,16 @@ public class TrainingComponentsContainer {
addTrainingComponent(QNETWORK, qNetwork);
}
public void setDecoderNetwork(ArchitectureAdapter decoderNetwork) {
this.decoderNetwork = decoderNetwork;
addTrainingComponent(DECODER, decoderNetwork);
}
public void setEncoderNetwork(ArchitectureAdapter encoderNetwork) {
this.encoderNetwork = encoderNetwork;
addTrainingComponent(ENCODER, encoderNetwork);
}
public void setRewardFunction(EMAComponentInstanceSymbol rewardFunction) {
this.rewardFunction = rewardFunction;
addTrainingComponent(REWARD_FUNCTION, rewardFunction.getComponentType());
......
......@@ -7,10 +7,7 @@ import conflang._symboltable.ConfigurationEntrySymbol;
import conflang._symboltable.NestedConfigurationEntrySymbol;
import schemalang._symboltable.SchemaDefinitionSymbol;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
import static de.monticore.lang.monticar.cnnarch.generator.training.TrainingParameterConstants.*;
......@@ -83,6 +80,15 @@ public class TrainingConfiguration {
return LearningMethod.GAN.equals(learningMethod);
}
public boolean isVaeLearning() {
Optional<LearningMethod> learningMethodOpt = getLearningMethod();
if (!learningMethodOpt.isPresent()) {
return false; // Not correct here to return false..
}
LearningMethod learningMethod = learningMethodOpt.get();
return LearningMethod.VAE.equals(learningMethod);
}
public boolean isReinforcementLearning() {
Optional<LearningMethod> learningMethodOpt = getLearningMethod();
if (!learningMethodOpt.isPresent()) {
......@@ -349,6 +355,20 @@ public class TrainingConfiguration {
return getObjectParameterParameters(DISCRIMINATOR_OPTIMIZER);
}
public Boolean hasEncoderName() {
return hasParameter(ENCODER);
}
public Optional<String> getEncoderName() {
return getObjectParameterValue(ENCODER);
}
public Optional<String> getReconLossName() {
return getObjectParameterValue(RECON_LOSS);
}
public Optional<Double> getKlLossWeight() { return getParameterValue(KL_LOSS_WEIGHT); }
public boolean hasStrategy() {
return hasParameter(STRATEGY);
}
......@@ -509,4 +529,5 @@ public class TrainingConfiguration {
}
return keyValues;
}
}
\ No newline at end of file
......@@ -15,6 +15,7 @@ public class TrainingParameterConstants {
public static final String SUPERVISED = "supervised";
public static final String REINFORCEMENT = "reinforcement";
public static final String GAN = "gan";
public static final String VAE = "vae";
/*
* Optimizers
......@@ -110,4 +111,9 @@ public class TrainingParameterConstants {
public static final String GENERATOR_LOSS_WEIGHT = "generator_loss_weight";
public static final String DISCRIMINATOR_LOSS_WEIGHT = "discriminator_loss_weight";
public static final String PRINT_IMAGES = "print_images";
public static final String ENCODER = "encoder";
public static final String DECODER = "decoder";
public static final String KL_LOSS_WEIGHT = "kl_loss_weight";
public static final String RECON_LOSS = "reconstruction_loss";
}
\ No newline at end of file
......@@ -4,7 +4,7 @@ import Optimizer;
schema General {
learning_method = supervised: schema {
supervised, reinforcement, gan;
supervised, reinforcement, gan, vae;
}
context: enum {
......
/* (c) https://github.com/MontiCore/monticore */
import Optimizer;
schema VAE extends General {
reference-model: referencemodels.vae.VAE, referencemodels.vae.CVAE
batch_size: N1
num_epoch: N1
normalize: B
checkpoint_period = 5: N
load_checkpoint: B
load_pretrained: B
log_period: N
reconstruction_loss = mse: reconLoss_type
print_images = false: B
kl_loss_weight: Q
reconLoss_type {
values:
bce,
mse;
}
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
package referencemodels.vae;
component CVAE {
component Encoder {
ports
in X data,
in W^{1} label,
out D encoding;
}
component Decoder {
ports
in D encoding,
in W^{1} label,
out X data;
}
instance Encoder encoder;
instance Decoder decoder;
connect encoder.encoding -> decoder.encoding;
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
package referencemodels.vae;
component VAE {
component Encoder{
ports
in X data,
out D encoding;
}
component Decoder{
ports
in D encoding,
out X data;
}
instance Encoder encoder;
instance Decoder decoder;
connect encoder.encoding -> decoder.encoding;
}
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment