Commit a5b546e2 authored by Sebastian Nickels

Changed OneHot layer and added support for constants

parent 6e5d8673
Pipeline #155895 passed with stages
in 18 minutes and 44 seconds
......@@ -69,6 +69,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
// An IO element: a name, optionally followed by a bracketed index expression.
IOElement implements ArchitectureElement = Name ("[" index:ArchSimpleExpression "]")?;
// A constant: a bare simple expression used directly as an architecture element.
Constant implements ArchitectureElement = ArchSimpleExpression;
// A layer: a name followed by a parenthesised, comma-separated (possibly empty) argument list.
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
......
......@@ -20,7 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.expressionsbasis._ast.ASTExpression;
import de.monticore.lang.math._symboltable.MathSymbolTableCreator;
import de.monticore.lang.math._symboltable.expression.*;
......@@ -427,6 +426,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
addToScopeAndLinkWithNode(argument, node);
}
/**
 * Creates a fresh {@link ConstantSymbol} for the visited constant node and
 * registers it via {@code addToScopeAndLinkWithNode}.
 *
 * @param node the constant AST node being entered
 */
public void visit(ASTConstant node) {
    addToScopeAndLinkWithNode(new ConstantSymbol(), node);
}
/**
 * Completes the {@link ConstantSymbol} created in {@code visit(ASTConstant)}:
 * attaches the symbol of the node's expression to it and then removes the
 * current scope.
 *
 * @param node the constant AST node being left
 */
public void endVisit(ASTConstant node) {
    // Both symbols are expected to have been linked to their nodes while descending.
    ConstantSymbol constantSymbol = (ConstantSymbol) node.getSymbolOpt().get();
    ArchSimpleExpressionSymbol expressionSymbol =
            (ArchSimpleExpressionSymbol) node.getArchSimpleExpression().getSymbolOpt().get();

    constantSymbol.setExpression(expressionSymbol);
    removeCurrentScope();
}
public void visit(ASTIOElement node) {
IOSymbol ioElement = new IOSymbol(node.getName());
addToScopeAndLinkWithNode(ioElement, node);
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.helper.Utils;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
public class ConstantSymbol extends ArchitectureElementSymbol {
private ArchSimpleExpressionSymbol expression = null;
protected ConstantSymbol() {
super("const");
}
public ArchSimpleExpressionSymbol getExpression() {
return expression;
}
protected void setExpression(ArchSimpleExpressionSymbol expression) {
this.expression = expression;
}
@Override
public boolean isResolvable() {
return super.isResolvable();
}
@Override
public boolean isAtomic() {
return getResolvedThis().isPresent() && getResolvedThis().get() == this;
}
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (getResolvedThis().isPresent() && getResolvedThis().get() != this) {
return getResolvedThis().get().getFirstAtomicElements();
}
else {
return Collections.singletonList(this);
}
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (getResolvedThis().isPresent() && getResolvedThis().get() != this) {
return getResolvedThis().get().getLastAtomicElements();
}
else {
return Collections.singletonList(this);
}
}
@Override
public Set<VariableSymbol> resolve() throws ArchResolveException {
if (!isResolved()) {
if (isResolvable()) {
resolveExpressions();
setResolvedThis(this);
}
}
return getUnresolvableVariables();
}
@Override
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
getExpression().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getExpression().getUnresolvableVariables());
}
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
List<ArchTypeSymbol> outputShapes;
if (isAtomic()) {
ArchTypeSymbol outputShape = new ArchTypeSymbol();
// Since symbol is resolved at this point, it is safe to assume that the expression is an int
int value = getExpression().getIntValue().get();
ASTRange range = new ASTRange();
range.setStartValue(String.valueOf(value));
range.setEndValue(String.valueOf(value));
ASTElementType domain = new ASTElementType("Z", Optional.of(range));
outputShape.setDomain(domain);
outputShapes = Collections.singletonList(outputShape);
}
else {
if (!getResolvedThis().isPresent()){
throw new IllegalStateException("The architecture resolve() method was never called");
}
outputShapes = getResolvedThis().get().computeOutputTypes();
}
return outputShapes;
}
@Override
public void checkInput() {
if (isAtomic()) {
if (!getInputTypes().isEmpty()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid number of input streams. "
, getSourcePosition());
}
}
else {
if (!getResolvedThis().isPresent()) {
throw new IllegalStateException("The architecture resolve() method was never called");
}
getResolvedThis().get().checkInput();
}
}
@Override
public Optional<Integer> getParallelLength() {
return Optional.of(1);
}
@Override
public Optional<List<Integer>> getSerialLengths() {
return Optional.of(Collections.nCopies(getParallelLength().get(), 1));
}
@Override
protected void putInScope(Scope scope) {
Collection<Symbol> symbolsInScope = scope.getLocalSymbols().get(getName());
if (symbolsInScope == null || !symbolsInScope.contains(this)) {
scope.getAsMutableScope().add(this);
getExpression().putInScope(getSpannedScope());
}
}
@Override
protected void resolveExpressions() throws ArchResolveException {
getExpression().resolveOrError();
if (!Constraints.INTEGER.check(getExpression(), getSourcePosition(), getName())) {
throw new ArchResolveException();
}
}
@Override
protected ArchitectureElementSymbol preResolveDeepCopy() {
ConstantSymbol copy = new ConstantSymbol();
if (getAstNode().isPresent()) {
copy.setAstNode(getAstNode().get());
}
copy.setExpression(getExpression());
return copy;
}
}
......@@ -91,6 +91,36 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
/**
 * Logs an error for every input type whose channel count differs from the
 * required one.
 *
 * @param inputTypes the input types to check
 * @param layer      the layer whose source position is attached to the error
 * @param channels   the required number of channels
 */
protected void errorIfInputChannelSizeIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int channels) {
    for (ArchTypeSymbol inputType : inputTypes) {
        int actualChannels = inputType.getChannels();
        if (actualChannels == channels) {
            continue;
        }
        String message = "0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input channel size is "
                + actualChannels + " but needs to be " + channels + ".";
        Log.error(message, layer.getSourcePosition());
    }
}
/**
 * Logs an error for every input type whose height differs from the required one.
 *
 * @param inputTypes the input types to check
 * @param layer      the layer whose source position is attached to the error
 * @param height     the required height
 */
protected void errorIfInputHeightIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int height) {
    for (ArchTypeSymbol inputType : inputTypes) {
        int actualHeight = inputType.getHeight();
        if (actualHeight == height) {
            continue;
        }
        String message = "0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input height is "
                + actualHeight + " but needs to be " + height + ".";
        Log.error(message, layer.getSourcePosition());
    }
}
/**
 * Logs an error for every input type whose width differs from the required one.
 *
 * @param inputTypes the input types to check
 * @param layer      the layer whose source position is attached to the error
 * @param width      the required width
 */
protected void errorIfInputWidthIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int width) {
    for (ArchTypeSymbol inputType : inputTypes) {
        int actualWidth = inputType.getWidth();
        if (actualWidth == width) {
            continue;
        }
        String message = "0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input width is "
                + actualWidth + " but needs to be " + width + ".";
        Log.error(message, layer.getSourcePosition());
    }
}
//check input for convolution and pooling
protected static void errorIfInputSmallerThanKernel(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (!inputTypes.isEmpty()) {
......@@ -116,22 +146,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
//check input for onehot layer
/**
 * Logs an error when the channel count of the first input type differs from
 * the layer's one-hot size. Nothing is checked when there are no input types
 * or when the one-hot size is 0.
 *
 * @param inputTypes the input types of the layer; only the first one is inspected
 * @param layer      the layer providing the one-hot size and the source position
 */
protected static void errorIfInputSizeUnequalToOnehotSize(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
    if (inputTypes.isEmpty()) {
        return;
    }

    // Read the option once instead of once for the guard and once for the comparison
    int onehotSize = layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
    if (onehotSize == 0) {
        return;
    }

    int inputChannels = inputTypes.get(0).getChannels();
    if (onehotSize != inputChannels){
        // Separate the error code from the text and the two sentences from each other,
        // matching the format of the other messages in this class
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE +
                        " The size of the onehot vector is not equal to the output size of the previous layer. " +
                        "This is usually not intended."
                , layer.getSourcePosition());
    }
}
//output type function for convolution and pooling
protected static List<ArchTypeSymbol> computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
String borderModeSetting = method.getStringValue(AllPredefinedLayers.PADDING_NAME).get();
......
......@@ -63,7 +63,7 @@ public class AllPredefinedLayers {
// Names of layer parameters; used as lookup keys, e.g. via layer.getStringValue(...) / layer.getIntValue(...).
public static final String BETA_NAME = "beta";
public static final String PADDING_NAME = "padding";
public static final String POOL_TYPE_NAME = "pool_type";
// NOTE(review): ONE_HOT_SIZE_NAME and SIZE_NAME hold the same literal "size" - confirm whether both are still needed.
public static final String ONE_HOT_SIZE_NAME = "size";
public static final String SIZE_NAME = "size";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
......
......@@ -21,11 +21,13 @@
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.lang.monticar.ranges._ast.ASTRangeStepResolution;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.*;
public class OneHot extends PredefinedLayerDeclaration {
......@@ -35,14 +37,12 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
channels=layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
channels = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get())
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1)
.width(1)
.elementType("0", "1")
......@@ -52,14 +52,58 @@ public class OneHot extends PredefinedLayerDeclaration {
/**
 * Validates the input of a OneHot layer.
 *
 * <p>The layer expects exactly one input stream with one channel, height 1
 * and width 1. Its domain must be natural or whole numbers with an explicit
 * range {@code [min, max]} where {@code min >= 0} and {@code max < size}.
 * Every violation is reported through {@code Log.error}.
 *
 * @param inputTypes the input types of the layer
 * @param layer      the layer being checked; provides the {@code size} argument and the source position
 */
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
    errorIfInputSizeIsNotOne(inputTypes, layer);
    errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
    errorIfInputHeightIsInvalid(inputTypes, layer, 1);
    errorIfInputWidthIsInvalid(inputTypes, layer, 1);

    // Without any input there is no domain to inspect; the wrong stream count is
    // presumably reported by errorIfInputSizeIsNotOne above. Returning here avoids
    // an IndexOutOfBoundsException when Log.error does not abort execution.
    if (inputTypes.isEmpty()) {
        return;
    }

    errorIfInputDomainIsInvalid(inputTypes.get(0).getDomain(), layer);
}

/**
 * Reports an error if the domain is not a bounded natural/whole number range
 * that lies inside {@code [0, size)}.
 */
private static void errorIfInputDomainIsInvalid(ASTElementType domain, LayerSymbol layer) {
    if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
                , layer.getSourcePosition());
        return;
    }

    if (!domain.getRangeOpt().isPresent()) {
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Range is missing. "
                , layer.getSourcePosition());
        return;
    }

    ASTRange range = domain.getRangeOpt().get();
    if (!range.getMin().getNumber().isPresent() || !range.getMax().getNumber().isPresent()) {
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Minimum or maximum is missing. "
                , layer.getSourcePosition());
        return;
    }

    double min = range.getMin().getNumber().get();
    double max = range.getMax().getNumber().get();

    // Index 0 is valid, so only negative minima are rejected
    if (min < 0) {
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Minimum needs to be at least 0. "
                , layer.getSourcePosition());
    }

    // Both bounds are checked independently so that a single run can report both problems
    int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
    if (max >= size) {
        Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: "
                        + "Maximum needs to be smaller than size " + size + ". "
                , layer.getSourcePosition());
    }
}
public static OneHot create(){
OneHot declaration = new OneHot();
List<VariableSymbol> parameters = new ArrayList<>(Arrays.asList(
new VariableSymbol.Builder()
.name(AllPredefinedLayers.ONE_HOT_SIZE_NAME)
.name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(channels)
.build()));
......
......@@ -55,6 +55,5 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Concatenate() ->
FullyConnected(units=10) ->
Softmax() ->
OneHot() ->
predictions;
}
\ No newline at end of file
......@@ -49,6 +49,5 @@ architecture Alexnet_alt_OneHotOutput(img_height=224, img_width=224, img_channel
Dropout() ->
FullyConnected(units=classes) ->
Softmax() ->
OneHot(size=classes) ->
predictions;
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment