Commit 0b65559e authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen Committed by Evgeny Kusmenko

Tensorflow group

parent 140646a3
......@@ -168,7 +168,16 @@ public class ArchitectureElementData {
@Nullable
public List<Integer> getPadding(){
return getPadding(getLayerSymbol());
String pad = ((LayerSymbol) getElement()).getStringValue(AllPredefinedLayers.PADDING_NAME).get();
if(pad.equals("same")){
return getPadding(getLayerSymbol()); //The padding calculated here is only used in the gluon/ mxnet backend, in the tensorlflow one it is interpreted as "same"
}else if(pad.equals("valid")){
return Arrays.asList(0,-1,0,0,0,0,0,0);
}else{ //"no loss"
return Arrays.asList(0,0,-1,0,0,0,0,0);
}
}
@Nullable
......@@ -194,4 +203,4 @@ public class ArchitectureElementData {
return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
}
}
}
\ No newline at end of file
......@@ -26,7 +26,7 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
if (architecture.getInputs().size() > 1) {
Log.error("This cnn architecture has multiple inputs, " +
......@@ -35,7 +35,7 @@ public abstract class ArchitectureSupportChecker {
return false;
}
return true;
}
......
......@@ -10,7 +10,7 @@ import java.io.Writer;
import java.util.*;
public abstract class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
......@@ -116,7 +116,7 @@ public abstract class CNNArchTemplateController {
public String getDataPath(){
return getArchitecture().getDataPath();
}
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
......
......@@ -106,6 +106,7 @@ public abstract class CNNTrainGenerator {
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
supportCheck(compilationUnit.get().getConfiguration());
return compilationUnit.get().getConfiguration();
}
......
......@@ -93,6 +93,13 @@ public class ConfigurationData {
} else{
return mapToStrings;}
}
public String getLossWeights() {
if (!getConfiguration().getEntryMap().containsKey("loss_weights")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("loss_weights").getValue());
}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
......
......@@ -45,7 +45,7 @@ public abstract class TrainParamSupportChecker implements CNNTrainVisitor {
public void visit(ASTTrainContextEntry node){}
public void visit(ASTEvalMetricEntry node){}
public void visit(ASTSGDOptimizer node){}
public void visit(ASTAdamOptimizer node){}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment