Commit 8f498657 authored by Nicola Gatto's avatar Nicola Gatto

Annotate train configuration with architecture

parent ea6325f7
......@@ -69,6 +69,8 @@ public class EMADLGenerator {
private String modelsPath;
private Map<String, ArchitectureSymbol> processedArchitecture;
public EMADLGenerator(Backend backend) {
......@@ -111,6 +113,7 @@ public class EMADLGenerator {
}
public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) throws IOException, TemplateException {
processedArchitecture = new HashMap<>();
setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
......@@ -132,6 +135,7 @@ public class EMADLGenerator {
if (doCompile) {
compile();
}
processedArchitecture = null;
}
public void compile() throws IOException {
......@@ -379,6 +383,9 @@ public class EMADLGenerator {
architecture.get().setDataPath(dPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
processedArchitecture.put(architecture.get().getComponentName(), architecture.get());
}
}
else if (mathStatements.isPresent()){
generateMathComponent(fileContents, taggingResolver, componentInstanceSymbol, mathStatements.get());
......@@ -507,7 +514,13 @@ public class EMADLGenerator {
trainConfigFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename);
configuration.setTrainedArchitecture(new ArchitectureAdapter(architecture.get()));
// Annotate train configuration with architecture
final String fullConfigName = String.join(".", names);
ArchitectureSymbol correspondingArchitecture = this.processedArchitecture.get(fullConfigName);
assert correspondingArchitecture != null : "No architecture found for train " + fullConfigName + " configuration!";
configuration.setTrainedArchitecture(new ArchitectureAdapter(correspondingArchitecture));
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
for (String fileName : fileContentMap.keySet()){
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment