Commit cafac521 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

added new parameters for GAN

parent 7574b6e6
...@@ -256,13 +256,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -256,13 +256,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue; NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions // GANs Extensions
KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue;
ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue;
NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue;
interface MultiParamValueMapConfigEntry extends ConfigEntry; interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue; interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue; interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue; DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
DiscriminatorOptimizerEntry implements ConfigEntry = name:"discriminator_optimizer" ":" value:OptimizerValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue; QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue; PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
......
...@@ -123,7 +123,13 @@ class ParameterAlgorithmMapping { ...@@ -123,7 +123,13 @@ class ParameterAlgorithmMapping {
ASTQNetworkEntry.class, ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class, ASTNoiseDistributionEntry.class,
ASTConstraintDistributionEntry.class, ASTConstraintDistributionEntry.class,
ASTConstraintLossEntry.class ASTConstraintLossEntry.class,
ASTDiscriminatorOptimizerEntry.class,
ASTKValueEntry.class,
ASTGeneratorLossEntry.class,
ASTConditionalInputEntry.class,
ASTNoiseInputEntry.class
); );
ParameterAlgorithmMapping() { ParameterAlgorithmMapping() {
......
...@@ -110,6 +110,24 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -110,6 +110,24 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
} }
@Override
public void visit(ASTDiscriminatorOptimizerEntry node) {
OptimizerSymbol optimizerSymbol = new OptimizerSymbol(node.getValue().getName());
configuration.setCriticOptimizer(optimizerSymbol);
addToScopeAndLinkWithNode(optimizerSymbol, node);
}
@Override
public void endVisit(ASTDiscriminatorOptimizerEntry node) {
assert configuration.getCriticOptimizer().isPresent(): "Critic optimizer not present";
for (ASTEntry paramNode : node.getValue().getParamsList()) {
OptimizerParamSymbol param = new OptimizerParamSymbol();
OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol)paramNode.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getCriticOptimizer().get().getOptimizerParamMap().put(paramNode.getName(), param);
}
}
@Override @Override
public void endVisit(ASTNumEpochEntry node) { public void endVisit(ASTNumEpochEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
...@@ -118,6 +136,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -118,6 +136,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void endVisit(ASTKValueEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTGeneratorLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTConditionalInputEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void endVisit(ASTBatchSizeEntry node) { public void endVisit(ASTBatchSizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
......
...@@ -52,4 +52,10 @@ public class ConfigEntryNameConstants { ...@@ -52,4 +52,10 @@ public class ConfigEntryNameConstants {
public static final String NOISE_DISTRIBUTION = "noise_distribution"; public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions"; public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions";
public static final String CONSTRAINT_LOSS = "constraint_losses"; public static final String CONSTRAINT_LOSS = "constraint_losses";
public static final String DISCRIMINATOR_OPTIMIZER = "discriminator_optimizer";
public static final String K_VALUE = "k_value";
public static final String GENERATOR_LOSS = "generator_loss";
public static final String CONDITIONAL_INPUT = "conditional_input";
public static final String NOISE_INPUT = "noise_input";
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment