Commit 6f26fc6a authored by Evgeny Kusmenko

Merge branch 'develop' into 'master'

Add Eyüp's changes from CNNArch2MXNet

See merge request !2
parents f35249bb 8e6562a8
Pipeline #161765 canceled
pom.xml

@@ -8,7 +8,7 @@
     <groupId>de.monticore.lang.monticar</groupId>
     <artifactId>cnnarch-generator</artifactId>
-    <version>0.0.1-SNAPSHOT</version>
+    <version>0.0.2-SNAPSHOT</version>
 
     <!-- == PROJECT DEPENDENCIES ============================================= -->
@@ -16,7 +16,7 @@
     <!-- .. SE-Libraries .................................................. -->
     <CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
-    <CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
+    <CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
     <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
 
     <!-- .. Libraries .................................................. -->
ArchitectureElementData.java

@@ -32,7 +32,7 @@ import java.util.Arrays;
 import java.util.List;
 
 public class ArchitectureElementData {
 
     private String name;
     private ArchitectureElementSymbol element;
     private CNNArchTemplateController templateController;
@@ -71,26 +71,6 @@ public class ArchitectureElementData {
         return getTemplateController().getLayerInputs(getElement());
     }
 
-    public boolean isLogisticRegressionOutput(){
-        return getTemplateController().isLogisticRegressionOutput(getElement());
-    }
-
-    public boolean isLinearRegressionOutput(){
-        boolean result = getTemplateController().isLinearRegressionOutput(getElement());
-        if (result){
-            Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" +
-                    " because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " +
-                    "Other loss functions are currently not supported. "
-                    , getElement().getSourcePosition());
-        }
-        return result;
-    }
-
-    public boolean isSoftmaxOutput(){
-        return getTemplateController().isSoftmaxOutput(getElement());
-    }
 
     public int getConstValue() {
         ConstantSymbol constant = (ConstantSymbol) getElement();
         return constant.getExpression().getIntValue().get();
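The methods removed above encoded an implicit loss-selection rule: the training loss was inferred from the activation element directly in front of an output (softmax gives cross-entropy, sigmoid gives logistic regression, anything else falls back to squared loss, with the warning shown above). A minimal sketch of that rule in isolation; the string labels are hypothetical stand-ins, not identifiers from the generator:

public class OldLossRuleSketch {
    // Maps the output's predecessor activation to the loss that the removed
    // methods implied for training.
    static String lossFor(String lastActivation) {
        switch (lastActivation) {
            case "softmax": return "cross_entropy";        // isSoftmaxOutput
            case "sigmoid": return "logistic_regression";  // isLogisticRegressionOutput
            default:        return "squared_loss";         // isLinearRegressionOutput fallback
        }
    }

    public static void main(String[] args) {
        System.out.println(lossFor("softmax")); // cross_entropy
        System.out.println(lossFor("tanh"));    // squared_loss (formerly logged a warning)
    }
}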
CNNArchTemplateController.java

@@ -139,16 +139,12 @@ public abstract class CNNArchTemplateController {
     public List<String> getLayerInputs(ArchitectureElementSymbol layer){
         List<String> inputNames = new ArrayList<>();
 
-        if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
-            inputNames = getLayerInputs(layer.getInputElement().get());
-        } else {
-            for (ArchitectureElementSymbol input : layer.getPrevious()) {
-                if (input.getOutputTypes().size() == 1) {
-                    inputNames.add(getName(input));
-                } else {
-                    for (int i = 0; i < input.getOutputTypes().size(); i++) {
-                        inputNames.add(getName(input) + "[" + i + "]");
-                    }
+        for (ArchitectureElementSymbol input : layer.getPrevious()) {
+            if (input.getOutputTypes().size() == 1) {
+                inputNames.add(getName(input));
+            } else {
+                for (int i = 0; i < input.getOutputTypes().size(); i++) {
+                    inputNames.add(getName(input) + "[" + i + "]");
+                }
             }
         }
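After this simplification, getLayerInputs names inputs uniformly: a single-output predecessor contributes its bare name, a multi-output predecessor contributes one indexed name per output. A self-contained sketch of just the naming scheme; the plain-string signature is an assumption for illustration, the real method operates on ArchitectureElementSymbol:

import java.util.ArrayList;
import java.util.List;

public class LayerInputNamingSketch {
    // prevNames[p] has prevOutputCounts[p] output types, mirroring
    // getName(input) and input.getOutputTypes().size() above.
    static List<String> inputNames(String[] prevNames, int[] prevOutputCounts) {
        List<String> names = new ArrayList<>();
        for (int p = 0; p < prevNames.length; p++) {
            if (prevOutputCounts[p] == 1) {
                names.add(prevNames[p]);                     // single output: bare name
            } else {
                for (int i = 0; i < prevOutputCounts[p]; i++) {
                    names.add(prevNames[p] + "[" + i + "]"); // multi-output: indexed
                }
            }
        }
        return names;
    }

    public static void main(String[] args) {
        // conv1 has one output type, split1 has two:
        System.out.println(inputNames(new String[]{"conv1", "split1"}, new int[]{1, 2}));
        // -> [conv1, split1[0], split1[1]]
    }
}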
@@ -220,28 +216,4 @@ public abstract class CNNArchTemplateController {
         return stringBuilder.toString();
     }
 
-    public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){
-        return isTOutput(Sigmoid.class, architectureElement);
-    }
-
-    public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
-        return architectureElement.isOutput()
-                && !isLogisticRegressionOutput(architectureElement)
-                && !isSoftmaxOutput(architectureElement);
-    }
-
-    public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
-        return isTOutput(Softmax.class, architectureElement);
-    }
-
-    private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
-        if (architectureElement.isOutput()
-                && architectureElement.getInputElement().isPresent()
-                && architectureElement.getInputElement().get() instanceof LayerSymbol){
-            LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
-            return inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration());
-        }
-        return false;
-    }
 }
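The private isTOutput helper, removed with them, reduced loss inference to a reflective check: an output qualified when its input element was a layer whose declaration was an instance of the given predefined-layer class. A stripped-down sketch with stand-in types (the classes below are hypothetical placeholders for the CNNArch symbol classes):

class PredefinedDeclaration {}
class Softmax extends PredefinedDeclaration {}
class Sigmoid extends PredefinedDeclaration {}

public class IsTOutputSketch {
    // Core of the removed isTOutput: Class.isInstance against the declaration
    // of the output's predecessor layer.
    static boolean isTOutput(Class<?> predefinedLayerClass, PredefinedDeclaration declaration) {
        return predefinedLayerClass.isInstance(declaration);
    }

    public static void main(String[] args) {
        System.out.println(isTOutput(Softmax.class, new Softmax())); // true  -> softmax output
        System.out.println(isTOutput(Softmax.class, new Sigmoid())); // false
    }
}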
ConfigurationData.java

@@ -67,6 +67,32 @@ public class ConfigurationData {
         return getConfiguration().getEntry("eval_metric").getValue().toString();
     }
 
+    public String getLossName() {
+        if (getConfiguration().getLoss() == null) {
+            return null;
+        }
+        return getConfiguration().getLoss().getName();
+    }
+
+    public Map<String, String> getLossParams() {
+        Map<String, String> mapToStrings = new HashMap<>();
+        Map<String, LossParamSymbol> lossParams = getConfiguration().getLoss().getLossParamMap();
+        for (Map.Entry<String, LossParamSymbol> entry : lossParams.entrySet()) {
+            String paramName = entry.getKey();
+            String valueAsString = entry.getValue().toString();
+            Class realClass = entry.getValue().getValue().getValue().getClass();
+            if (realClass == Boolean.class) {
+                valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
+            }
+            mapToStrings.put(paramName, valueAsString);
+        }
+        if (mapToStrings.isEmpty()) {
+            return null;
+        } else {
+            return mapToStrings;
+        }
+    }
+
     public String getOptimizerName() {
         if (getConfiguration().getOptimizer() == null) {
             return null;
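These two accessors replace the inference removed above: the loss now comes explicitly from the CNNTrain configuration. getLossName guards against a missing loss entry, while getLossParams does not (it would throw a NullPointerException if no loss is configured), and it pre-converts Boolean values to the Python literals True/False so templates can splice them verbatim into generated code. A hedged sketch of such a consumer; renderKwargs and the parameter names are illustrative, not part of the generator:

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class LossParamRenderingSketch {
    // Joins the pre-stringified parameters into a Python keyword-argument list.
    static String renderKwargs(Map<String, String> lossParams) {
        return lossParams.entrySet().stream()
                .map(e -> e.getKey() + "=" + e.getValue())
                .collect(Collectors.joining(", "));
    }

    public static void main(String[] args) {
        Map<String, String> params = new LinkedHashMap<>();
        params.put("sparse_label", "True");  // Boolean already mapped to a Python literal
        params.put("margin", "1.0");
        System.out.println("loss_fn(" + renderKwargs(params) + ")");
        // -> loss_fn(sparse_label=True, margin=1.0)
    }
}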