Commit 7a403446 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by Thomas Michael Timmermanns

Improved Generator.

Fixed port adapter bug.
parent 99234812
......@@ -42,13 +42,22 @@ public class CheckIOName implements CNNArchASTIOLayerCoCo {
"The input or output '" + node.getName() + "' does not exist"
, node.get_SourcePositionStart());
}
else if (ioDeclarations.size() > 1){
else {
IODeclarationSymbol ioDeclaration = ioDeclarations.iterator().next();
if (!checkedIODeclarations.contains(ioDeclaration)){
Log.error("0" + ErrorCodes.DUPLICATED_NAME + " Duplicated IO name. " +
"The name '" + ioDeclaration.getName() + "' is already used."
, ioDeclaration.getSourcePosition());
checkedIODeclarations.addAll(ioDeclarations);
if (ioDeclarations.size() > 1) {
if (!checkedIODeclarations.contains(ioDeclaration)) {
Log.error("0" + ErrorCodes.DUPLICATED_NAME + " Duplicated IO name. " +
"The name '" + ioDeclaration.getName() + "' is already used."
, ioDeclaration.getSourcePosition());
checkedIODeclarations.addAll(ioDeclarations);
}
}
else {
if (ioDeclaration.getName().endsWith("_")){
Log.error("0" + ErrorCodes.ILLEGAL_NAME + " Illegal IO name. " +
"Input and output names cannot end with \"_\"",
ioDeclaration.getSourcePosition());
}
}
}
}
......
......@@ -25,6 +25,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import javax.annotation.Nullable;
import java.util.*;
public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
......@@ -34,6 +35,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
private LayerSymbol body;
private List<IOLayerSymbol> inputs = new ArrayList<>();
private List<IOLayerSymbol> outputs = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
public ArchitectureSymbol() {
super("", KIND);
......@@ -55,15 +57,23 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return outputs;
}
public Set<IODeclarationSymbol> getIODeclarations(){
Set<IODeclarationSymbol> ioDeclarations = new HashSet<>();
for (IOLayerSymbol input : getInputs()){
ioDeclarations.add(input.getDefinition());
//called in IOLayer to get IODeclaration; only null if error; will be checked in coco CheckIOName
@Nullable
protected IODeclarationSymbol resolveIODeclaration(String name){
IODeclarationSymbol ioDeclaration = ioDeclarationMap.get(name);
if (ioDeclaration == null){
Collection<IODeclarationSymbol> ioDefCollection = getEnclosingScope().resolveMany(name, IODeclarationSymbol.KIND);
if (!ioDefCollection.isEmpty()){
ioDeclaration = ioDefCollection.iterator().next();
ioDeclarationMap.put(name, ioDeclaration);
ioDeclaration.setArchitecture(this);
}
}
for (IOLayerSymbol output : getOutputs()){
ioDeclarations.add(output.getDefinition());
}
return ioDeclarations;
return ioDeclaration;
}
public Collection<IODeclarationSymbol> getIODeclarations(){
return ioDeclarationMap.values();
}
public Collection<MethodDeclarationSymbol> getMethodDeclarations(){
......
......@@ -25,8 +25,9 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import java.util.HashSet;
import java.util.Set;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
public class IODeclarationSymbol extends CommonSymbol {
......@@ -35,8 +36,7 @@ public class IODeclarationSymbol extends CommonSymbol {
private ArchTypeSymbol type;
private boolean input; //true->input, false->output
private int arrayLength = 1;
private Set<IOLayerSymbol> connectedLayers = new HashSet<>();
private ArchitectureSymbol architecture = null; // set by ArchitectureSymbol
protected IODeclarationSymbol(String name) {
super(name, KIND);
......@@ -50,8 +50,21 @@ public class IODeclarationSymbol extends CommonSymbol {
this.type = type;
}
public Set<IOLayerSymbol> getConnectedLayers() {
return connectedLayers;
public List<IOLayerSymbol> getConnectedLayers() {
if (getArchitecture() == null){
return new ArrayList<>();
}
else {
List<IOLayerSymbol> completeList;
if (input) {
completeList = getArchitecture().getInputs();
} else {
completeList = getArchitecture().getOutputs();
}
return completeList.stream()
.filter(e -> e.getName().equals(getName()))
.collect(Collectors.toList());
}
}
public boolean isOutput(){
......@@ -74,7 +87,13 @@ public class IODeclarationSymbol extends CommonSymbol {
this.arrayLength = arrayLength;
}
public ArchitectureSymbol getArchitecture() {
return architecture;
}
public void setArchitecture(ArchitectureSymbol architecture) {
this.architecture = architecture;
}
public static class Builder{
private ArchTypeSymbol type;
......
......@@ -50,12 +50,10 @@ public class IOLayerSymbol extends LayerSymbol {
this.arrayAccess = ArchSimpleExpressionSymbol.of(arrayAccess);
}
//returns null if IODeclaration does not exist. This is checked in coco CheckIOName.
public IODeclarationSymbol getDefinition() {
if (definition == null){
Collection<IODeclarationSymbol> ioDefCollection = getEnclosingScope().resolveMany(getName(), IODeclarationSymbol.KIND);
if (!ioDefCollection.isEmpty()){
setDefinition(ioDefCollection.iterator().next());
}
this.definition = getArchitecture().resolveIODeclaration(getName());
}
return definition;
}
......@@ -65,11 +63,6 @@ public class IOLayerSymbol extends LayerSymbol {
return super.isResolvable() && getDefinition() != null;
}
private void setDefinition(IODeclarationSymbol definition) {
this.definition = definition;
definition.getConnectedLayers().add(this);
}
@Override
public boolean isInput(){
return getDefinition().isInput();
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.Joiners;
import java.util.*;
......@@ -96,6 +97,16 @@ public abstract class LayerSymbol extends CommonScopeSpanningSymbol {
return false;
}
public ArchitectureSymbol getArchitecture(){
Symbol sym = getEnclosingScope().getSpanningSymbol().get();
if (sym instanceof ArchitectureSymbol){
return (ArchitectureSymbol) sym;
}
else {
return ((LayerSymbol) sym).getArchitecture();
}
}
/**
* only call after resolve():
* @return returns the non-empty atomic layers which have the output of this layer as input.
......
......@@ -33,37 +33,27 @@ import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
public class CNNArchGenerator {
private Target targetLanguage;
private String generationTargetPath;
public CNNArchGenerator() {
setTargetLanguage(Target.CPP);
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public Target getTargetLanguage() {
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
else {
this.generationTargetPath = generationTargetPath;
}
this.generationTargetPath = generationTargetPath;
}
public void generate(Path modelsDirPath, String rootModelName){
......@@ -83,44 +73,53 @@ public class CNNArchGenerator {
try{
ArchitectureSymbol architecture = compilationUnit.get().getArchitecture();
generateNetworkFile(architecture);
generateFiles(architecture);
}
catch (IOException e){
Log.error(e.toString());
}
}
public String generateNetworkString(ArchitectureSymbol architecture){
TemplateController archTc = new TemplateController(architecture, targetLanguage);
return archTc.process();
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map.Entry<String, String> temp;
public void generateNetworkFile(ArchitectureSymbol architecture) throws IOException{
File f = new File(getGenerationTargetPath() + getFileName(architecture));
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("Network", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
FileWriter writer = new FileWriter(f);
TemplateController archTc = new TemplateController(architecture, targetLanguage);
archTc.process(writer);
writer.close();
temp = archTc.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
return fileContentMap;
}
public String getFileName(ArchitectureSymbol architecture){
String name = architecture.getEnclosingScope().getSpanningSymbol().get().getFullName();
name = name.replaceAll("\\.", "_").replaceAll("\\[", "_").replaceAll("\\]", "_");
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map<String, String> fileContentMap = generateStrings(architecture);
for (String fileName : fileContentMap.keySet()){
File f = new File(getGenerationTargetPath() + fileName);
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
String fileEnding = getTargetLanguage().toString();
if (getTargetLanguage() == Target.CPP){
fileEnding = ".h";
FileWriter writer = new FileWriter(f);
writer.write(fileContentMap.get(fileName));
writer.close();
}
return name + "__network" + fileEnding;
}
}
......@@ -48,12 +48,12 @@ public class CNNArchGeneratorCli {
.required(false)
.build();
public static final Option OPTION_TARGET_LANG = Option.builder("t")
/*public static final Option OPTION_TARGET_LANG = Option.builder("t")
.longOpt("target-language")
.desc("target language of network e.g. c++ or python")
.hasArg(true)
.required(false)
.build();
.build();*/
private CNNArchGeneratorCli() {
}
......@@ -72,7 +72,7 @@ public class CNNArchGeneratorCli {
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
options.addOption(OPTION_TARGET_LANG);
//options.addOption(OPTION_TARGET_LANG);
return options;
}
......@@ -92,14 +92,11 @@ public class CNNArchGeneratorCli {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
String targetLanguage = cliArgs.getOptionValue(OPTION_TARGET_LANG.getOpt());
//String targetLanguage = cliArgs.getOptionValue(OPTION_TARGET_LANG.getOpt());
CNNArchGenerator generator = new CNNArchGenerator();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
if (targetLanguage != null){
generator.setTargetLanguage(Target.fromString(targetLanguage));
}
generator.generate(modelsDirPath, rootModelName);
}
}
......@@ -36,7 +36,7 @@ import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
public class TemplateController {
public class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_LAYER_DIR_PATH = "layers/";
......@@ -44,18 +44,27 @@ public class TemplateController {
private LayerNameCreator nameManager;
private Configuration freemarkerConfig = TemplateConfiguration.get();
private ArchitectureSymbol architecture;
private Writer writer;
private String mainTemplateNameWithoutEnding;
private Target targetLanguage;
private LayerSymbol currentLayer;
private Target target;
public TemplateController(ArchitectureSymbol architecture, Target target) {
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
this.target = target;
}
public String getTarget(){
return target.toString();
public String getFileNameWithoutEnding() {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
}
public Target getTargetLanguage(){
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public LayerSymbol getCurrentLayer() {
......@@ -75,23 +84,6 @@ public class TemplateController {
this.nameManager = new LayerNameCreator(architecture);
}
public String getCurrentOutputShape(){
return getOutputShape(getCurrentLayer());
}
public String getOutputShape(LayerSymbol layer){
if (layer.getOutputTypes().size() == 1){
return shapeToString(layer.getOutputTypes().get(0));
}
else {
List<String> strings = new ArrayList<>();
for (ArchTypeSymbol shape : layer.getOutputTypes()){
strings.add(shapeToString(shape));
}
return "{" + Joiners.COMMA.join(strings) + "}";
}
}
private String shapeToString(ArchTypeSymbol shape){
return "[" + Joiners.COMMA.join(shape.getDimensions()) + "]";
}
......@@ -104,15 +96,23 @@ public class TemplateController {
return nameManager.getName(layer);
}
public String getArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName();
}
public String getFullArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName();
}
public List<String> getCurrentInputs(){
return getInputs(getCurrentLayer());
return getLayerInputs(getCurrentLayer());
}
public List<String> getInputs(LayerSymbol layer){
public List<String> getLayerInputs(LayerSymbol layer){
List<String> inputNames = new ArrayList<>();
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
inputNames = getInputs(layer.getInputLayer().get());
inputNames = getLayerInputs(layer.getInputLayer().get());
}
else {
for (LayerSymbol input : layer.getPrevious()) {
......@@ -145,20 +145,24 @@ public class TemplateController {
return list;
}
public void include(String relativePath, String templateWithoutFileEnding, Writer stringWriter){
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
try {
Template template = freemarkerConfig.getTemplate(templatePath);
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
template.process(ftlContext, stringWriter);
Map<String, Object> ftlContext = Collections.singletonMap(TEMPLATE_CONTROLLER_KEY, this);
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
}
......@@ -166,7 +170,6 @@ public class TemplateController {
LayerSymbol previousLayer = getCurrentLayer();
setCurrentLayer(layer);
String result;
if (layer.isInput()){
include(TEMPLATE_LAYER_DIR_PATH, "Input", writer);
}
......@@ -217,20 +220,30 @@ public class TemplateController {
}
}
public String include(LayerSymbol layer){
StringWriter writer = new StringWriter();
public void include(LayerSymbol layer){
if (writer == null){
throw new IllegalStateException("missing writer");
}
include(layer, writer);
return writer.toString();
}
public void process(Writer writer) throws IOException{
include("", "Network", writer);
}
public String process(){
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
StringWriter writer = new StringWriter();
include("", "Network", writer);
return writer.toString();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString();
if (targetLanguage == Target.CPP){
fileEnding = ".h";
}
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
return fileContent;
}
public String join(Iterable iterable, String separator){
......@@ -399,5 +412,4 @@ public class TemplateController {
return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
}
}
......@@ -113,12 +113,12 @@ public class LayerNameCreator {
IOLayerSymbol ioLayer = (IOLayerSymbol) layer;
if (ioLayer.getArrayAccess().isPresent()){
int arrayAccess = ioLayer.getArrayAccess().get().getIntValue().get();
name = name + arrayAccess;
name = name + "_" + arrayAccess + "_";
}
return name;
}
else {
return createBaseName(layer) + stage + createStreamPostfix(streamIndices);
return createBaseName(layer) + stage + createStreamPostfix(streamIndices) + "_";
}
}
......
......@@ -25,11 +25,25 @@ import freemarker.template.TemplateExceptionHandler;
public class TemplateConfiguration {
private static TemplateConfiguration instance;
private Configuration configuration;
private TemplateConfiguration() {
configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/");
configuration.setDefaultEncoding("UTF-8");
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
}
public Configuration getConfiguration() {
return configuration;
}
public static Configuration get(){
Configuration cfg = new Configuration(Configuration.VERSION_2_3_23);
cfg.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/");
cfg.setDefaultEncoding("UTF-8");
cfg.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
return cfg;
if (instance == null){
instance = new TemplateConfiguration();
}
return instance.getConfiguration();
}
}
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#include <stdio.h>
#include <iostream>
#include <fstream>
// Read file to buffer
class BufferFile {
public :
std::string file_path_;
int length_;
char* buffer_;
explicit BufferFile(std::string file_path)
:file_path_(file_path) {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
length_ = 0;
buffer_ = NULL;
return;
}
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
buffer_ = new char[sizeof(char) * length_];
ifs.read(buffer_, length_);
ifs.close();
}
int GetLength() {
return length_;
}
char* GetBuffer() {
return buffer_;
}