Commit f37dd998 authored by Christian Fuß's avatar Christian Fuß
Browse files

reworked OneHotLayer to take argument size. Added Stack layer for unrolling

parent c5ca1033
Pipeline #146187 failed with stages
in 3 minutes and 6 seconds
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -79,7 +79,7 @@ public class ArchitectureElementData {
boolean result = getTemplateController().isLinearRegressionOutput(getElement());
if (result){
Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" +
" because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " +
" because the previous architecture element is not a softmax (cross-entropy loss), one_hot or sigmoid (logistic regression loss) activation. " +
"Other loss functions are currently not supported. "
, getElement().getSourcePosition());
}
......@@ -91,6 +91,9 @@ public class ArchitectureElementData {
}
public boolean isOneHotOutput(){
return getTemplateController().isOneHotOutput(getElement());
}
public List<Integer> getKernel(){
......@@ -158,20 +161,16 @@ public class ArchitectureElementData {
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
}
@Nullable
public String getPoolType(){
public double getSize(){
return ((LayerSymbol) getElement())
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE).get();
}
public int getOneHotIndex(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_INDEX_NAME).get();
}
public int getOneHotSize(){
@Nullable
public String getPoolType(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE).get();
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
}
@Nullable
......
......@@ -37,7 +37,7 @@ public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement) && !isOneHotOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.OneHot;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
......@@ -139,7 +140,7 @@ public abstract class CNNArchTemplateController {
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer) || isOneHotOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
} else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
......@@ -228,7 +229,8 @@ public abstract class CNNArchTemplateController {
public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
return architectureElement.isOutput()
&& !isLogisticRegressionOutput(architectureElement)
&& !isSoftmaxOutput(architectureElement);
&& !isSoftmaxOutput(architectureElement)
&& !isOneHotOutput(architectureElement);
}
......@@ -236,6 +238,11 @@ public abstract class CNNArchTemplateController {
return isTOutput(Softmax.class, architectureElement);
}
public boolean isOneHotOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(OneHot.class, architectureElement);
}
private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
if (architectureElement.isOutput()
&& architectureElement.getInputElement().isPresent()
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment