Commit 52ba313d authored by Julian Dierkes's avatar Julian Dierkes
Browse files

introduce GAN features

parent 5b194010
......@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.monticar.generator.cpp.SimulatorIntegrationHelper;
import de.monticore.lang.monticar.generator.cpp.TypesGeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapper;
import de.monticore.lang.monticar.generator.cpp.converter.TypeConverter;
import de.monticore.lang.tagging._symboltable.TagSymbol;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
......@@ -51,6 +52,7 @@ public class EMADLGenerator {
private GeneratorEMAMOpt2CPP emamGen;
private CNNArchGenerator cnnArchGenerator;
private CNNTrainGenerator cnnTrainGenerator;
private GeneratorPythonWrapper pythonWrapper;
private Backend backend;
private String modelsPath;
......@@ -62,6 +64,7 @@ public class EMADLGenerator {
emamGen = new GeneratorEMAMOpt2CPP();
emamGen.useArmadilloBackend();
emamGen.setGenerationTargetPath("./target/generated-sources-emadl/");
pythonWrapper.setGenerationTargetPath("./target/");
cnnArchGenerator = backend.getCNNArchGenerator();
cnnTrainGenerator = backend.getCNNTrainGenerator();
}
......@@ -123,7 +126,10 @@ public class EMADLGenerator {
System.exit(1);
}
return component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
Scope c1 = component.getEnclosingScope();
Optional<EMAComponentInstanceSymbol> c2 = c1.<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND);
EMAComponentInstanceSymbol c3 = c2.get();
return c3;
}
public void compile() throws IOException {
......@@ -589,6 +595,41 @@ public class EMADLGenerator {
CNNTrainCocos.checkCriticCocos(configuration);
}
// Resolve discriminator network if discriminator is present
if (configuration.getDiscriminatorName().isPresent()) {
String fullDiscriminatorName = configuration.getDiscriminatorName().get();
int indexOfFirstNameCharacter = fullDiscriminatorName.lastIndexOf('.') + 1;
fullDiscriminatorName = fullDiscriminatorName.substring(0, indexOfFirstNameCharacter)
+ fullDiscriminatorName.substring(indexOfFirstNameCharacter, indexOfFirstNameCharacter + 1).toUpperCase()
+ fullDiscriminatorName.substring(indexOfFirstNameCharacter + 1);
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instanceSymbol = resolveComponentInstanceSymbol(fullDiscriminatorName, symtab);
EMADLCocos.checkAll(instanceSymbol);
Optional<ArchitectureSymbol> discriminator = instanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (!discriminator.isPresent()) {
Log.error("During the resolving of critic component: Critic component "
+ fullDiscriminatorName + " does not have a CNN implementation but is required to have one");
System.exit(-1);
}
discriminator.get().setComponentName(fullDiscriminatorName);
configuration.setDiscriminatorNetwork(new ArchitectureAdapter(fullDiscriminatorName, discriminator.get()));
//CNNTrainCocos.checkCriticCocos(configuration);
}
if (configuration.hasPreprocessor()) {
String preprocessor_name = configuration.getPreprocessingName().get();
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instance = resolveComponentInstanceSymbol(preprocessor_name, symtab);
generateComponent(fileContents, allInstances, symtab, instance, symtab);
try {
pythonWrapper.generateFiles(instance);
} catch (IOException e) {
// todo: add fancy error message here
e.printStackTrace();
}
}
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
......
......@@ -33,6 +33,7 @@ import java.nio.file.Paths;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
@Ignore
public class IntegrationTensorflowTest extends IntegrationTest {
private Path multipleStreamsHashFile = Paths.get("./target/generated-sources-emadl/MultipleStreams.training_hash");
......@@ -41,6 +42,7 @@ public class IntegrationTensorflowTest extends IntegrationTest {
super("TENSORFLOW", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A");
}
@Ignore
@Test
public void testMultipleStreams() {
Log.getFindings().clear();
......
Data loading failure. File '/home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/EMADL2CPP/resources/training_data/train.h5' does not exist.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment