Commit cafac521 authored by Julian Dierkes's avatar Julian Dierkes

added new parameters for GAN

parent 7574b6e6
......@@ -256,13 +256,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions
KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue;
ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue;
NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue;
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
DiscriminatorOptimizerEntry implements ConfigEntry = name:"discriminator_optimizer" ":" value:OptimizerValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
......
......@@ -123,7 +123,13 @@ class ParameterAlgorithmMapping {
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class,
ASTConstraintDistributionEntry.class,
ASTConstraintLossEntry.class
ASTConstraintLossEntry.class,
ASTDiscriminatorOptimizerEntry.class,
ASTKValueEntry.class,
ASTGeneratorLossEntry.class,
ASTConditionalInputEntry.class,
ASTNoiseInputEntry.class
);
ParameterAlgorithmMapping() {
......
......@@ -110,6 +110,24 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
}
@Override
public void visit(ASTDiscriminatorOptimizerEntry node) {
OptimizerSymbol optimizerSymbol = new OptimizerSymbol(node.getValue().getName());
configuration.setCriticOptimizer(optimizerSymbol);
addToScopeAndLinkWithNode(optimizerSymbol, node);
}
@Override
public void endVisit(ASTDiscriminatorOptimizerEntry node) {
assert configuration.getCriticOptimizer().isPresent(): "Critic optimizer not present";
for (ASTEntry paramNode : node.getValue().getParamsList()) {
OptimizerParamSymbol param = new OptimizerParamSymbol();
OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol)paramNode.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getCriticOptimizer().get().getOptimizerParamMap().put(paramNode.getName(), param);
}
}
@Override
public void endVisit(ASTNumEpochEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -118,6 +136,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTKValueEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTGeneratorLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTConditionalInputEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTBatchSizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
......@@ -52,4 +52,10 @@ public class ConfigEntryNameConstants {
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions";
public static final String CONSTRAINT_LOSS = "constraint_losses";
public static final String DISCRIMINATOR_OPTIMIZER = "discriminator_optimizer";
public static final String K_VALUE = "k_value";
public static final String GENERATOR_LOSS = "generator_loss";
public static final String CONDITIONAL_INPUT = "conditional_input";
public static final String NOISE_INPUT = "noise_input";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment