Commit 40a1a42a authored by Julian Dierkes's avatar Julian Dierkes

added more features for GAN

parent 3889c267
......@@ -246,7 +246,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// GANs Extensions
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
ImgResizeEntry implements ConfigEntry = name:"img_resize" ":" value:IntegerTupelValue;
// Noise Distribution Creator
......
......@@ -115,6 +115,7 @@ class ParameterAlgorithmMapping {
private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList(
ASTDiscriminatorNetworkEntry.class,
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class,
ASTImgResizeEntry.class
);
......
......@@ -52,6 +52,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(compilationUnit.getName());
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
}
@Override
public void endVisit(ASTCNNTrainCompilationUnit ast) {
......@@ -463,6 +464,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTQNetworkEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTPreprocessingEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
......@@ -23,6 +23,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork;
private NNArchitectureSymbol qNetwork;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
......@@ -80,6 +81,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(discriminatorNetwork);
}
public Optional<NNArchitectureSymbol> getQNetwork() {
return Optional.ofNullable(qNetwork);
}
public void setCriticNetwork(NNArchitectureSymbol criticNetwork) {
this.criticNetwork = criticNetwork;
}
......@@ -88,6 +93,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.discriminatorNetwork = discriminatorNetwork;
}
public void setQNetwork(NNArchitectureSymbol qNetwork) {
this.qNetwork = qNetwork;
}
public Map<String, EntrySymbol> getEntryMap() {
return entryMap;
}
......@@ -121,6 +130,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getEntryMap().containsKey(DISCRIMINATOR_NAME);
}
public boolean hasQNetwork() {
return getEntryMap().containsKey(QNETWORK_NAME);
}
public Optional<String> getCriticName() {
if (!hasCritic()) {
return Optional.empty();
......@@ -150,4 +163,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
assert discriminatorNameValue instanceof String;
return Optional.of((String)discriminatorNameValue);
}
public Optional<String> getQNetworkName() {
if (!hasQNetwork()) {
return Optional.empty();
}
final Object qnetNameValue = getEntry(QNETWORK_NAME).getValue().getValue();
assert qnetNameValue instanceof String;
return Optional.of((String)qnetNameValue);
}
}
\ No newline at end of file
......@@ -47,6 +47,7 @@ public class ConfigEntryNameConstants {
public static final String CRITIC = "critic";
public static final String DISCRIMINATOR_NAME = "discriminator_name";
public static final String QNETWORK_NAME = "qnet_name";
public static final String PREPROCESSING_NAME = "preprocessing_name";
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String IMG_RESIZE = "img_resize";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment