Commit de826675 authored by Julian Treiber's avatar Julian Treiber

added weightsPath

parent b9a79c99
Pipeline #267262 failed with stage
in 5 minutes and 44 seconds
......@@ -14,6 +14,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.WeightsPathConfigParser;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
......@@ -246,7 +247,7 @@ public class EMADLGenerator {
String b = backend.getBackendString(backend);
String trainingDataHash = "";
String testDataHash = "";
if (architecture.get().getDataPath() != null) {
if (b.equals("CAFFE2")) {
trainingDataHash = getChecksumForLargerFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
......@@ -410,6 +411,21 @@ public class EMADLGenerator {
return dataPath;
}
protected String getWeightsPath(EMAComponentSymbol component, EMAComponentInstanceSymbol instance){
String weightsPath;
Path weightsPathDefinition = Paths.get(getModelsPath(), "weights_paths.txt");
if (weightsPathDefinition.toFile().exists()) {
WeightsPathConfigParser newParserConfig = new WeightsPathConfigParser(getModelsPath() + "weights_paths.txt");
weightsPath = newParserConfig.getWeightsPath(component.getFullName());
} else {
Log.warn("No weights path definition found in " + weightsPathDefinition + " found: "
+ "No pretrained weights will be loaded.");
weightsPath = null;
}
return weightsPath;
}
protected void generateComponent(List<FileContent> fileContents,
Set<EMAComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
......@@ -431,7 +447,9 @@ public class EMADLGenerator {
if (architecture.isPresent()){
cnnArchGenerator.check(architecture.get());
String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
String wPath = getWeightsPath(EMAComponentSymbol, componentInstanceSymbol);
architecture.get().setDataPath(dPath);
architecture.get().setWeightsPath(wPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment