Commit 849f0e36 authored by Sebastian Nickels's avatar Sebastian Nickels

Changed OneHot layer and added support for constants

parent 844877c0
Pipeline #155896 failed with stages
in 2 minutes and 57 seconds
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
......@@ -90,12 +91,11 @@ public class ArchitectureElementData {
return getTemplateController().isSoftmaxOutput(getElement());
}
public boolean isOneHotOutput(){
return getTemplateController().isOneHotOutput(getElement());
public int getConstValue() {
ConstantSymbol constant = (ConstantSymbol) getElement();
return constant.getExpression().getIntValue().get();
}
public List<Integer> getKernel(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
......@@ -162,13 +162,8 @@ public class ArchitectureElementData {
}
public int getSize(){
if(getElement().isOutput()) {
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
}else{
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
}
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
@Nullable
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
public class ArchitectureSupportChecker {
public ArchitectureSupportChecker() {}
......@@ -57,10 +63,44 @@ public class ArchitectureSupportChecker {
return true;
}
protected boolean hasConstant(ArchitectureElementSymbol element) {
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) {
List<ArchitectureElementSymbol> constructedElements = ((CompositeElementSymbol) resolvedElement).getElements();
for (ArchitectureElementSymbol constructedElement : constructedElements) {
if (hasConstant(constructedElement)) {
return true;
}
}
}
else if (resolvedElement instanceof ConstantSymbol) {
return true;
}
return false;
}
protected boolean checkConstants(ArchitectureSymbol architecture) {
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getElements()) {
if (hasConstant(element)) {
Log.error("This cnn architecture has a constant, which is currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
}
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture);
&& checkMultiDimensionalOutput(architecture)
&& checkConstants(architecture);
}
}
......@@ -36,7 +36,7 @@ public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement) && !isOneHotOutput(nextElement)){
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
......@@ -61,7 +61,7 @@ public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
} else if (architectureElement instanceof LayerSymbol){
} else if (architectureElement instanceof LayerSymbol) {
include((LayerSymbol) architectureElement, writer);
} else {
include((IOSymbol) architectureElement, writer);
......
......@@ -23,7 +23,6 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.monticore.lang.monticar.cnnarch.predefined.OneHot;
import java.io.StringWriter;
import java.io.Writer;
......@@ -140,7 +139,7 @@ public abstract class CNNArchTemplateController {
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer) || isOneHotOutput(layer)){
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
} else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
......@@ -229,12 +228,7 @@ public abstract class CNNArchTemplateController {
public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
return architectureElement.isOutput()
&& !isLogisticRegressionOutput(architectureElement)
&& !isSoftmaxOutput(architectureElement)
&& !isOneHotOutput(architectureElement);
}
public boolean isOneHotOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(OneHot.class, architectureElement);
&& !isSoftmaxOutput(architectureElement);
}
public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
......
......@@ -4,6 +4,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.se_rwth.commons.logging.Log;
......@@ -34,6 +35,11 @@ public abstract class LayerSupportChecker {
return true;
}
// Support for constants is checked in ArchitectureSupportChecker
if (resolvedElement instanceof ConstantSymbol) {
return true;
}
// Support all layer declarations
if (resolvedElement instanceof LayerSymbol) {
if (!((LayerSymbol) resolvedElement).getDeclaration().isPredefined()) {
......
<#assign size = element.size>
${element.name} = mx.symbol.one_hot(data=${element.inputs[0]},
indices=mx.symbol.argmax(data=${element.inputs[0]}, axis=1), depth=${size})
<#include "OutputShape.ftl">
\ No newline at end of file
......@@ -8,7 +8,4 @@
<#elseif element.linearRegressionOutput>
${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]},
name="${element.name}")
<#elseif element.oneHotOutput>
${element.name} = mx.symbol.SoftmaxOutput(data=${element.inputs[0]},
name="${element.name}")
</#if>
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment