Commit 3e611cc4 authored by Malte Heithoff's avatar Malte Heithoff
Browse files

Merge branch 'GeneratorInterface' into 'master'

Generator interface

See merge request !31
parents 68ee368d 2f7ea16d
Pipeline #336027 passed with stage
in 1 minute and 49 seconds
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.12-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,10 +17,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.7-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<CNNArch2X.version>0.4.0-SNAPSHOT</CNNArch2X.version>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
......@@ -69,40 +66,6 @@
<version>${CNNArch2X.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-math-opt-generator</artifactId>
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-pythonwrapper-generator</artifactId>
......
......@@ -9,7 +9,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.se_rwth.commons.logging.Log;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import java.util.*;
......@@ -20,99 +20,105 @@ public class CNNArch2Gluon extends CNNArchGenerator {
layerSupportChecker = new CNNArch2GluonLayerSupportChecker();
}
@Override
public CMakeConfig getCmakeConfig() {
return null;
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
@Override
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = compileFileContentMap(architecture);
return fileContentMap;
public List<FileContent> generateStrings(TaggingResolver taggingResolver, ArchitectureSymbol architecture){
List<FileContent> fileContents = compileFileContents(architecture);
return fileContents;
}
public Map<String, String> generateStringsAllowMultipleIO(ArchitectureSymbol architecture, Boolean pythonFilesOnly) {
Map<String, String> fileContentMap;
public List<FileContent> generateStringsAllowMultipleIO(ArchitectureSymbol architecture, Boolean pythonFilesOnly) {
List<FileContent> fileContents;
if (pythonFilesOnly) {
fileContentMap = compilePythonFilesOnlyContentMap(architecture);
fileContents = compilePythonFilesOnlyContents(architecture);
} else {
fileContentMap = compileFileContentMap(architecture);
fileContents = compileFileContents(architecture);
}
return fileContentMap;
return fileContents;
}
private Map<String, String> compilePythonFiles(CNNArch2GluonTemplateController controller, ArchitectureSymbol architecture) {
Map<String, String> fileContentMap = new HashMap<>();
Map.Entry<String, String> temp;
private List<FileContent> compilePythonFiles(CNNArch2GluonTemplateController controller, ArchitectureSymbol architecture) {
List<FileContent> fileContents = new ArrayList<>();
FileContent temp;
temp = controller.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
if (architecture.getDataPath() != null) {
temp = controller.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
}
temp = controller.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
return fileContentMap;
return fileContents;
}
private Map<String, String> compileCppFiles(CNNArch2GluonTemplateController controller) {
Map<String, String> fileContentMap = new HashMap<>();
Map.Entry<String, String> temp;
private List<FileContent> compileCppFiles(CNNArch2GluonTemplateController controller) {
List<FileContent> fileContents = new ArrayList<>();
FileContent temp;
temp = controller.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
temp = controller.process("CNNSupervisedTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
temp = controller.process("CNNGanTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = new FileContent(temp.getFileContent(), temp.getFileName().replace(".h",""));
fileContents.add(temp);
temp = controller.process("CNNModelLoader", Target.CPP);
fileContentMap.put("CNNModelLoader.h", temp.getValue());
fileContents.add(new FileContent(temp.getFileContent(), "CNNModelLoader.h"));
return fileContentMap;
return fileContents;
}
private Map<String, String> compileFileContentMap(ArchitectureSymbol architecture) {
private List<FileContent> compileFileContents(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
architecture.processForEpisodicReplayMemory();
Map<String, String> fileContentMap = new HashMap<>();
List<FileContent> fileContents = new ArrayList<>();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
fileContents.addAll(compilePythonFiles(archTc, architecture));
fileContents.addAll(compileCppFiles(archTc));
return fileContentMap;
return fileContents;
}
private Map<String, String> compilePythonFilesOnlyContentMap(ArchitectureSymbol architecture) {
private List<FileContent> compilePythonFilesOnlyContents(ArchitectureSymbol architecture) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
return compilePythonFiles(archTc, architecture);
}
public Map<String, String> generateCMakeContent(String rootModelName) {
public List<FileContent> generateCMakeContent(String rootModelName) {
// model name should start with a lower case letter. If it is a component, replace dot . by _
rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_');
rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1);
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)");
cMakeConfig.addCmakeLibraryLinkage("mxnet");
Map<String,String> fileContentMap = new HashMap<>();
List<FileContent> fileContents = new ArrayList<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent());
fileContents.add(fileContent);
}
return fileContentMap;
return fileContents;
}
}
......@@ -2,12 +2,9 @@
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingPortChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -18,7 +15,6 @@ import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
......@@ -74,16 +70,13 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
private void generateFilesFromConfigurationSymbol(ConfigurationSymbol configuration) {
Map<String, String> fileContents = this.generateStrings(configuration);
List<FileContent> fileContents = this.generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(this.getGenerationTargetPath());
try {
Iterator var6 = fileContents.keySet().iterator();
while (var6.hasNext()) {
String fileName = (String) var6.next();
genCPP.generateFile(new FileContent((String) fileContents.get(fileName), fileName));
for (FileContent fileContent : fileContents) {
genCPP.generateFile(fileContent);
}
} catch (IOException var8) {
Log.error("CNNTrainer file could not be generated" + var8.getMessage());
......@@ -119,7 +112,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
public List<FileContent> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
GluonConfigurationData configData = new GluonConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
......@@ -128,11 +121,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
Map<String, Object> ftlContext = Maps.newHashMap();
ftlContext.put("configurations", configDataList);
Map<String, String> fileContentMap = new HashMap<>();
List<FileContent> fileContents = new ArrayList<>();
//Context Information and Optimizer for local adaption during prediction for replay memory layer (the second only applicaple for supervised learning)
String cnnTrainLAOptimizerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNLAOptimizer.ftl");
fileContentMap.put("CNNLAOptimizer_" + getInstanceName() + ".h", cnnTrainLAOptimizerTemplateContent);
fileContents.add(new FileContent(cnnTrainLAOptimizerTemplateContent, "CNNLAOptimizer_" + getInstanceName() + ".h"));
//AdamW optimizer if used for training
if(configuration.getOptimizer() != null) {
......@@ -144,13 +137,13 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")) {
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
fileContents.add(new FileContent(adamWContent, "AdamW.py"));
}
}
if (configData.isSupervisedLearning()) {
String cnnTrainTrainerTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTrainerTemplateContent);
fileContents.add(new FileContent(cnnTrainTrainerTemplateContent, "CNNTrainer_" + getInstanceName() + ".py"));
} else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if (!configuration.getDiscriminatorNetwork().isPresent()) {
......@@ -166,34 +159,32 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), GAN_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> disArchitectureFileContentMap
List<FileContent> disArchitectureFileContents
= gluonGenerator.generateStringsAllowMultipleIO(disArchitectureSymbol, true);
final String disCreatorName = disArchitectureFileContentMap.keySet().iterator().next();
final String disCreatorName = disArchitectureFileContents.get(0).getFileName();
final String discriminatorInstanceName = disCreatorName.substring(
disCreatorName.indexOf('_') + 1, disCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(disArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
fileContents.addAll(disArchitectureFileContents.stream()
.map(k -> new FileContent(k.getFileContent(), GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getFileName()))
.collect(Collectors.toList()));
if (configuration.hasQNetwork()) {
NNArchitectureSymbol genericQArchitectureSymbol = configuration.getQNetwork().get();
ArchitectureSymbol qArchitectureSymbol
= ((ArchitectureAdapter) genericQArchitectureSymbol).getArchitectureSymbol();
Map<String, String> qArchitectureFileContentMap
List<FileContent> qArchitectureFileContents
= gluonGenerator.generateStringsAllowMultipleIO(qArchitectureSymbol, true);
final String qCreatorName = qArchitectureFileContentMap.keySet().iterator().next();
final String qCreatorName = qArchitectureFileContents.get(0).getFileName();
final String qNetworkInstanceName = qCreatorName.substring(
qCreatorName.indexOf('_') + 1, qCreatorName.lastIndexOf(".py"));
fileContentMap.putAll(qArchitectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
fileContents.addAll(qArchitectureFileContents.stream()
.map(k -> new FileContent(k.getFileContent(), GAN_LEARNING_FRAMEWORK_MODULE + "/" + k.getFileName()))
.collect(Collectors.toList()));
ftlContext.put("qNetworkInstanceName", qNetworkInstanceName);
}
......@@ -203,10 +194,10 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
ftlContext.put("trainerName", trainerName);
final String initContent = "";
fileContentMap.put(GAN_LEARNING_FRAMEWORK_MODULE + "/__init__.py", initContent);
fileContents.add(new FileContent(initContent, GAN_LEARNING_FRAMEWORK_MODULE + "/__init__.py"));
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
fileContents.add(new FileContent(ganTrainerContent, trainerName + ".py"));
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......@@ -225,18 +216,17 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
List<FileContent> architectureFileContents
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
final String creatorName = architectureFileContentMap.keySet().iterator().next();
final String creatorName = architectureFileContents.get(0).getFileName();
final String criticInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
fileContents.addAll(architectureFileContents.stream()
.map(k -> new FileContent(k.getFileContent(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getFileName()))
.collect(Collectors.toList()));
ftlContext.put("criticInstanceName", criticInstanceName);
}
......@@ -254,16 +244,16 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContentMap.putAll(rlFrameworkContentMap);
List<FileContent> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContents.addAll(rlFrameworkContentMap);
final String reinforcementTrainerContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/Trainer.ftl");
fileContentMap.put(trainerName + ".py", reinforcementTrainerContent);
fileContents.add(new FileContent(reinforcementTrainerContent, trainerName + ".py"));
final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/StartTrainer.ftl");
fileContentMap.put("start_training.sh", startTrainerScriptContent);
fileContents.add(new FileContent(startTrainerScriptContent, "start_training.sh"));
}
return fileContentMap;
return fileContents;
}
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
......@@ -318,41 +308,41 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
}
private Map<String, String> constructReinforcementLearningFramework(
private List<FileContent> constructReinforcementLearningFramework(
final TemplateConfiguration templateConfiguration,
final Map<String, Object> ftlContext,
RLAlgorithm rlAlgorithm) {
Map<String, String> fileContentMap = Maps.newHashMap();
List<FileContent> fileContents = new ArrayList<>();
ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);
final String loggerContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/util/Logger.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/cnnarch_logger.py", loggerContent);
fileContents.add(new FileContent(loggerContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/cnnarch_logger.py"));
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext,
"reinforcement/agent/Agent.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
fileContents.add(new FileContent(reinforcementAgentContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py"));
final String reinforcementStrategyContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/Strategy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py", reinforcementStrategyContent);
fileContents.add(new FileContent(reinforcementStrategyContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py"));
final String replayMemoryContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
fileContents.add(new FileContent(replayMemoryContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py"));
final String environmentContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
fileContents.add(new FileContent(environmentContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py"));
final String utilContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
fileContents.add(new FileContent(utilContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py"));
final String initContent = "";
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/__init__.py", initContent);
fileContents.add(new FileContent(initContent, REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/__init__.py"));
return fileContentMap;
return fileContents;
}
}
......@@ -42,8 +42,8 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Paths.get("target/generated-sources-cnnarch"),
Paths.get("src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_CifarClassifierNetwork.py",
"CNNNet_CifarClassifierNetwork.py",
......@@ -435,8 +435,8 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code/default-gan"),
Paths.get("target/generated-sources-cnnarch"),
Paths.get("src/test/resources/target_code/default-gan"),
Arrays.asList(
"gan/CNNCreator_Discriminator.py",
"gan/CNNNet_Discriminator.py",
......@@ -462,8 +462,8 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code/info-gan"),
Paths.get("target/generated-sources-cnnarch"),
Paths.get("src/test/resources/target_code/info-gan"),
Arrays.asList(
"gan/CNNCreator_InfoDiscriminator.py",
"gan/CNNNet_InfoDiscriminator.py",
......
......@@ -11,9 +11,12 @@ find_package(Armadillo REQUIRED)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Armadillo_LIBRARIES})
# additional commands
# additional library linkage
set(LIBS ${LIBS} mxnet)
# additional commands
# create static library
include_directories(${INCLUDE_DIRS})
add_library(alexnet alexnet.cpp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment