From d28d94d358d104f5e75888f7a304a502070e0a94 Mon Sep 17 00:00:00 2001 From: Julian Dierkes Date: Sun, 5 Apr 2020 16:04:01 +0200 Subject: [PATCH] introduced tests for inter Architecture GAN CoCos --- .../cnntrain/_cocos/CNNTrainCocos.java | 6 +- ...ckGANDiscriminatorQNetworkDependency.java} | 20 +-- ...ckGANGeneratorDiscriminatorDependency.java | 30 ++++ .../_cocos/CheckGANGeneratorHasOneOutput.java | 29 +++ .../CheckGANGeneratorQNetworkDependency.java | 31 ++++ .../_cocos/CheckGANQNetworkhasOneInput.java | 29 +++ .../cnntrain/cocos/InterCocoTest.java | 149 +++++++++++++++- .../cocos/NNArchitecturerBuilder.java | 166 ++++++++++++++++++ 8 files changed, 439 insertions(+), 21 deletions(-) rename src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/{CheckGANNetworkPorts.java => CheckGANDiscriminatorQNetworkDependency.java} (61%) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 63ff00b..87a3746 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -59,7 +59,11 @@ public class CNNTrainCocos { public static void checkGANCocos(final ConfigurationSymbol configurationSymbol) { CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() - .addCoCo(new CheckGANNetworkPorts()); + .addCoCo(new CheckGANDiscriminatorQNetworkDependency()) + .addCoCo(new CheckGANGeneratorDiscriminatorDependency()) + .addCoCo(new CheckGANGeneratorHasOneOutput()) + .addCoCo(new CheckGANGeneratorQNetworkDependency()) + .addCoCo(new CheckGANQNetworkhasOneInput()); checker.checkAll(configurationSymbol); } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANDiscriminatorQNetworkDependency.java similarity index 61% rename from src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java rename to src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANDiscriminatorQNetworkDependency.java index ef9a024..9e2a8a8 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANDiscriminatorQNetworkDependency.java @@ -13,25 +13,16 @@ import de.se_rwth.commons.logging.Log; import java.util.Optional; -public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo { +public class CheckGANDiscriminatorQNetworkDependency implements CNNTrainConfigurationSymbolCoCo { public void CheckGANNetworkPorts() { } @Override public void check(ConfigurationSymbol configurationSymbol) { - NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get(); NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get(); Optional qnet = configurationSymbol.getQNetwork(); - if(gen.getOutputs().size() != 1) - Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " + - "but is supposed to only have one"); - - if(qnet.isPresent() && qnet.get().getInputs().size() != 1) - Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " + - "but is supposed to only have one"); - if(qnet.isPresent() && dis.getOutputs().size() != 2) Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 2 output " + "ports when q-network is given"); @@ -48,14 +39,5 @@ public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo { if(qnet.isPresent() && !qnet.get().getInputs().get(0).equals("features")) Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Input to q-network needs to be named features"); - - if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0))) - Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " + - "fit the first discriminators input name"); - - if(qnet.isPresent()) - if(gen.getInputs().contains(qnet.get().getOutputs())) - Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " + - "latent-codes outputted by q-network"); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java new file mode 100644 index 0000000..be94e3e --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorDiscriminatorDependency.java @@ -0,0 +1,30 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Optional; + +public class CheckGANGeneratorDiscriminatorDependency implements CNNTrainConfigurationSymbolCoCo { + + public void CheckGANNetworkPorts() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get(); + NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get(); + + if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0))) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " + + "fit the first discriminators input name"); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java new file mode 100644 index 0000000..97df616 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorHasOneOutput.java @@ -0,0 +1,29 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Optional; + +public class CheckGANGeneratorHasOneOutput implements CNNTrainConfigurationSymbolCoCo { + + public void CheckGANNetworkPorts() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get(); + + if(gen.getOutputs().size() != 1) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " + + "but is supposed to only have one"); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java new file mode 100644 index 0000000..d11d89d --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANGeneratorQNetworkDependency.java @@ -0,0 +1,31 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Optional; + +public class CheckGANGeneratorQNetworkDependency implements CNNTrainConfigurationSymbolCoCo { + + public void CheckGANNetworkPorts() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get(); + Optional qnet = configurationSymbol.getQNetwork(); + + if(qnet.isPresent()) + if(!gen.getInputs().containsAll(qnet.get().getOutputs())) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " + + "latent-codes outputted by q-network"); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java new file mode 100644 index 0000000..913a07e --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANQNetworkhasOneInput.java @@ -0,0 +1,29 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Optional; + +public class CheckGANQNetworkhasOneInput implements CNNTrainConfigurationSymbolCoCo { + + public void CheckGANNetworkPorts() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + Optional qnet = configurationSymbol.getQNetwork(); + + if(qnet.isPresent() && qnet.get().getInputs().size() != 1) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " + + "but is supposed to only have one"); + } +} diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java index 27552f2..9eb9833 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/InterCocoTest.java @@ -43,6 +43,31 @@ public class InterCocoTest extends AbstractCoCoTest { checkValidCriticArchitecture(configurationSymbol); } + @Test + public void testValidDefaultGAN() { + // given + final NNArchitectureSymbol validGenerator = NNBuilder.getValidGenerator(); + final NNArchitectureSymbol validDiscriminator = NNBuilder.getValidDiscriminator(); + ConfigurationSymbol configurationSymbol = getDefaultGANConfigurationSymbolFrom("valid_tests", "DefaultGANConfig", + validGenerator, validDiscriminator); + + // when + checkValidGANArchitecture(configurationSymbol); + } + + @Test + public void testValidInfoGAN() { + // given + final NNArchitectureSymbol validGenerator = NNBuilder.getValidInfoGANGenerator(); + final NNArchitectureSymbol validDiscriminator = NNBuilder.getValidDiscriminatorWithQNet(); + final NNArchitectureSymbol validQNetwork = NNBuilder.getValidQNetwork(); + ConfigurationSymbol configurationSymbol = getInfoGANConfigurationSymbolFrom("valid_tests", "InfoGANConfig", + validGenerator, validDiscriminator, validQNetwork); + + // when + checkValidGANArchitecture(configurationSymbol); + } + @Test public void testInvalidTrainingArchitectureWithTwoInputs() { // given @@ -223,6 +248,79 @@ public class InterCocoTest extends AbstractCoCoTest { new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR)); } + @Test + public void testInvalidDiscriminatorQNetworkDependency() { + //given + CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckGANDiscriminatorQNetworkDependency(); + NNArchitectureSymbol gen = NNBuilder.getValidGenerator(); + NNArchitectureSymbol dis = NNBuilder.getValidDiscriminator(); + NNArchitectureSymbol qnet = NNBuilder.getValidQNetwork(); + ConfigurationSymbol configurationSymbol = getInfoGANConfigurationSymbolFrom("valid_tests", "InfoGANConfig", + gen, dis, qnet); + + // when + checkInvalidGANArchitecture(configurationSymbol, cocoUUT, + new ExpectedErrorInfo(1, ErrorCodes.GAN_ARCHITECTURE_ERROR)); + } + + @Test + public void testInvalidGeneratorQNetworkDependency() { + //given + CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckGANGeneratorQNetworkDependency(); + NNArchitectureSymbol gen = NNBuilder.getValidGenerator(); + NNArchitectureSymbol dis = NNBuilder.getValidDiscriminatorWithQNet(); + NNArchitectureSymbol qnet = NNBuilder.getValidQNetwork(); + ConfigurationSymbol configurationSymbol = getInfoGANConfigurationSymbolFrom("valid_tests", "InfoGANConfig", + gen, dis, qnet); + + // when + checkInvalidGANArchitecture(configurationSymbol, cocoUUT, + new ExpectedErrorInfo(1, ErrorCodes.GAN_ARCHITECTURE_ERROR)); + } + + @Test + public void testInvalidGeneratorHasMultipleOutputs() { + //given + CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckGANGeneratorHasOneOutput(); + NNArchitectureSymbol gen = NNBuilder.getInvalidGeneratorMultipleOutputs(); + NNArchitectureSymbol dis = NNBuilder.getValidDiscriminator(); + ConfigurationSymbol configurationSymbol = getDefaultGANConfigurationSymbolFrom("valid_tests", "DefaultGANConfig", + gen, dis); + + // when + checkInvalidGANArchitecture(configurationSymbol, cocoUUT, + new ExpectedErrorInfo(1, ErrorCodes.GAN_ARCHITECTURE_ERROR)); + } + + @Test + public void testInvalidGeneratorDiscriminatorDependency() { + //given + CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckGANGeneratorDiscriminatorDependency(); + NNArchitectureSymbol gen = NNBuilder.getValidGenerator(); + NNArchitectureSymbol dis = NNBuilder.getValidDiscriminatorDifferentInput(); + ConfigurationSymbol configurationSymbol = getDefaultGANConfigurationSymbolFrom("valid_tests", "DefaultGANConfig", + gen, dis); + + // when + checkInvalidGANArchitecture(configurationSymbol, cocoUUT, + new ExpectedErrorInfo(1, ErrorCodes.GAN_ARCHITECTURE_ERROR)); + } + + @Test + public void testInvalidQNetworkMultipleInputs() { + //given + CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckGANQNetworkhasOneInput(); + NNArchitectureSymbol gen = NNBuilder.getValidGenerator(); + NNArchitectureSymbol dis = NNBuilder.getValidDiscriminatorWithQNet(); + NNArchitectureSymbol qnet = NNBuilder.getInvalidQNetworkMultipleInputs(); + ConfigurationSymbol configurationSymbol = getInfoGANConfigurationSymbolFrom("valid_tests", "InfoGANConfig", + gen, dis, qnet); + + // when + checkInvalidGANArchitecture(configurationSymbol, cocoUUT, + new ExpectedErrorInfo(1, ErrorCodes.GAN_ARCHITECTURE_ERROR)); + } + private ConfigurationSymbol getConfigurationSymbolFrom(final String modelPath, final String model, final NNArchitectureSymbol actorArchitecture, final NNArchitectureSymbol criticArchitecture) { final ConfigurationSymbol configurationSymbol = getConfigurationSymbolByPath( modelPath, model); @@ -231,6 +329,25 @@ public class InterCocoTest extends AbstractCoCoTest { return configurationSymbol; } + private ConfigurationSymbol getDefaultGANConfigurationSymbolFrom(final String modelPath, final String model, + final NNArchitectureSymbol genArchitecture, final NNArchitectureSymbol disArchitecture) { + final ConfigurationSymbol configurationSymbol = getConfigurationSymbolByPath( modelPath, model); + configurationSymbol.setTrainedArchitecture(genArchitecture); + configurationSymbol.setDiscriminatorNetwork(disArchitecture); + return configurationSymbol; + } + + private ConfigurationSymbol getInfoGANConfigurationSymbolFrom(final String modelPath, final String model, + final NNArchitectureSymbol genArchitecture, + final NNArchitectureSymbol disArchitecture, + final NNArchitectureSymbol qnetArchitecture) { + final ConfigurationSymbol configurationSymbol = getConfigurationSymbolByPath( modelPath, model); + configurationSymbol.setTrainedArchitecture(genArchitecture); + configurationSymbol.setDiscriminatorNetwork(disArchitecture); + configurationSymbol.setQNetwork(qnetArchitecture); + return configurationSymbol; + } + private ConfigurationSymbol getConfigurationSymbolByPath(final String modelPath, final String model) { return getCompilationUnitSymbol(modelPath, model).getConfiguration(); } @@ -238,6 +355,7 @@ public class InterCocoTest extends AbstractCoCoTest { private enum CheckOption { TRAINED_ARCHITECTURE_COCOS, CRITIC_ARCHITECTURE_COCOS, + GAN_ARCHITECTURE_COCOS, } private void checkInvalidArchitecture( @@ -249,10 +367,13 @@ public class InterCocoTest extends AbstractCoCoTest { if (checkOption.equals(CheckOption.TRAINED_ARCHITECTURE_COCOS)) { CNNTrainCocos.checkTrainedArchitectureCoCos(configurationSymbol); - } else { + } else if(checkOption.equals(CheckOption.CRITIC_ARCHITECTURE_COCOS)) { CNNTrainCocos.checkCriticCocos(configurationSymbol); + } else if(checkOption.equals(CheckOption.GAN_ARCHITECTURE_COCOS)) { + CNNTrainCocos.checkGANCocos(configurationSymbol); } + expectedErrors.checkExpectedPresent(Log.getFindings(), "Got no findings when checking all " + "cocos. Did you forget to add the new coco to MontiArcCocos?"); Log.getFindings().clear(); @@ -262,6 +383,19 @@ public class InterCocoTest extends AbstractCoCoTest { + "the given coco. Did you pass an empty coco checker?"); } + private void checkInvalidArchitectureOnlyCoCo( + final ConfigurationSymbol configurationSymbol, + final CNNTrainConfigurationSymbolCoCo cocoUUT, + final ExpectedErrorInfo expectedErrors) { + Log.getFindings().clear(); + + Log.getFindings().clear(); + CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker().addCoCo(cocoUUT); + checker.checkAll(configurationSymbol); + expectedErrors.checkOnlyExpectedPresent(Log.getFindings(), "Got no findings when checking only " + + "the given coco. Did you pass an empty coco checker?"); + } + private void checkInvalidTrainedArchitecture( final ConfigurationSymbol configurationSymbol, final CNNTrainConfigurationSymbolCoCo cocoUUT, @@ -287,4 +421,17 @@ public class InterCocoTest extends AbstractCoCoTest { CNNTrainCocos.checkCriticCocos(configurationSymbol); new ExpectedErrorInfo().checkOnlyExpectedPresent(Log.getFindings()); } + + private void checkValidGANArchitecture(final ConfigurationSymbol configurationSymbol) { + Log.getFindings().clear(); + CNNTrainCocos.checkGANCocos(configurationSymbol); + new ExpectedErrorInfo().checkOnlyExpectedPresent(Log.getFindings()); + } + + private void checkInvalidGANArchitecture( + final ConfigurationSymbol configurationSymbol, + final CNNTrainConfigurationSymbolCoCo cocoUUT, + ExpectedErrorInfo expectedErrors) { + checkInvalidArchitectureOnlyCoCo(configurationSymbol, cocoUUT, expectedErrors); + } } diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java index 6738fff..6498161 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/NNArchitecturerBuilder.java @@ -13,6 +13,8 @@ import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; import de.monticore.lang.monticar.cnntrain.annotations.Range; import java.awt.*; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -228,4 +230,168 @@ public class NNArchitecturerBuilder { dimensions, getValidCriticTypes(), getValidCriticRanges()); } + + public NNArchitectureSymbol getValidGenerator() { + ArrayList input = Lists.newArrayList("noise"); + ArrayList output = Lists.newArrayList("data"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("noise", Lists.newArrayList(100)) + .put("data", Lists.newArrayList(3,28,28)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("noise", "Q") + .put("data", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("noise", Range.withInfinityLimits() ) + .put("data", Range.withLimits(-1,1)) + .build()); + return getNNArchitectureSymbolFrom("GeneratorValid", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getValidInfoGANGenerator() { + ArrayList input = Lists.newArrayList("noise", "c1"); + ArrayList output = Lists.newArrayList("data"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("noise", Lists.newArrayList(100)) + .put("data", Lists.newArrayList(3,28,28)) + .put("c1", Lists.newArrayList(10)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("noise", "Q") + .put("data", "Q") + .put("c1", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("noise", Range.withInfinityLimits() ) + .put("data", Range.withLimits(-1,1)) + .put("c1", Range.withLimits(0,1)) + .build()); + return getNNArchitectureSymbolFrom("GeneratorValid", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getInvalidGeneratorMultipleOutputs() { + ArrayList input = Lists.newArrayList("noise"); + ArrayList output = Lists.newArrayList("data1", "data2"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("noise", Lists.newArrayList(100)) + .put("data1", Lists.newArrayList(3,28,28)) + .put("data2", Lists.newArrayList(10)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("noise", "Q") + .put("data1", "Q") + .put("data2", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("noise", Range.withInfinityLimits() ) + .put("data1", Range.withLimits(-1,1)) + .put("data2", Range.withLimits(0,1)) + .build()); + return getNNArchitectureSymbolFrom("GeneratorInvalidGeneratorMultipleOutputs", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getValidDiscriminator() { + ArrayList input = Lists.newArrayList("data"); + ArrayList output = Lists.newArrayList("dis"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("data", Lists.newArrayList(3,28,28)) + .put("dis", Lists.newArrayList(1)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("data", "Q") + .put("dis", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("data", Range.withInfinityLimits() ) + .put("dis", Range.withLimits(0,1)) + .build()); + return getNNArchitectureSymbolFrom("DiscriminatorValid", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getValidDiscriminatorWithQNet() { + ArrayList input = Lists.newArrayList("data"); + ArrayList output = Lists.newArrayList("dis", "features"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("data", Lists.newArrayList(3,28,28)) + .put("dis", Lists.newArrayList(1)) + .put("features", Lists.newArrayList(1024)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("data", "Q") + .put("dis", "Q") + .put("features", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("data", Range.withInfinityLimits() ) + .put("dis", Range.withLimits(0,1)) + .put("features", Range.withInfinityLimits()) + .build()); + return getNNArchitectureSymbolFrom("DiscriminatorValidQNet", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getValidDiscriminatorDifferentInput() { + ArrayList input = Lists.newArrayList("data2"); + ArrayList output = Lists.newArrayList("dis"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("dis", Lists.newArrayList(1)) + .put("data2", Lists.newArrayList(1024)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("dis", "Q") + .put("data2", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("dis", Range.withLimits(0,1)) + .put("data2", Range.withInfinityLimits()) + .build()); + return getNNArchitectureSymbolFrom("DiscriminatorValidDifferentInputs", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getValidQNetwork() { + ArrayList input = Lists.newArrayList("features"); + ArrayList output = Lists.newArrayList("c1"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("features", Lists.newArrayList(1024)) + .put("c1", Lists.newArrayList(10)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("features", "Q") + .put("c1", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("features", Range.withInfinityLimits() ) + .put("1", Range.withLimits(0,1)) + .build()); + return getNNArchitectureSymbolFrom("QNetworkValid", input, output, + dims, types, ranges); + } + + public NNArchitectureSymbol getInvalidQNetworkMultipleInputs() { + ArrayList input = Lists.newArrayList("features1", "features2"); + ArrayList output = Lists.newArrayList("c1"); + HashMap dims = Maps.newHashMap(ImmutableMap.>builder() + .put("features1", Lists.newArrayList(1024)) + .put("features2", Lists.newArrayList(1024)) + .put("c1", Lists.newArrayList(10)) + .build()); + HashMap types = Maps.newHashMap(ImmutableMap.builder() + .put("features1", "Q") + .put("features2", "Q") + .put("c1", "Q") + .build()); + HashMap ranges = Maps.newHashMap(ImmutableMap.builder() + .put("features1", Range.withInfinityLimits() ) + .put("features2", Range.withInfinityLimits() ) + .put("1", Range.withLimits(0,1)) + .build()); + return getNNArchitectureSymbolFrom("QNetworkInvalidMultipleInputs", input, output, + dims, types, ranges); + } } -- GitLab