Commit 2c927e9a authored by Christian Fuß's avatar Christian Fuß

re-added ability to leave out size parameter in OneHot-layer.

parent 4a2199ac
Pipeline #158621 passed with stages
in 18 minutes and 59 seconds
......@@ -392,7 +392,14 @@ All predefined methods start with a capital letter and all constructed methods h
Opposite of *Concatenate*. Handles a single input stream and splits it into *n* output streams.
The output streams have the same height and width as the input stream and a number channels which is in general `input_channels / n`.
The last output stream will have a higher number of channels than the other if `input_channels` is not divisible by `n`.
* **n** (integer > 0, required): The number of output streams. Cannot be higher than the number of input channels.
* **OneHot(size)**
Creates a OneHot vector of a given size, given a scalar in the previous layer that determines the OneHot-Index (the index at which the *1* in the vector will be placed).
* **size** (integer > 0, optional): The OneHot-vector's size. Can be omitted to automatically use the output size of the architecture.
......@@ -163,6 +163,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
......@@ -23,7 +23,6 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.lang.monticar.ranges._ast.ASTRangeStepResolution;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.se_rwth.commons.logging.Log;
......@@ -35,26 +34,13 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
int size;
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
// TODO: Execute this code somewhere before checkInput(), for now size parameter is required
/*if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
int outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}*/
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
/*if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}*/
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(size)
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1)
.width(1)
.elementType("0", "1")
......@@ -63,6 +49,10 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
computeOneHotOutputSize(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
errorIfInputHeightIsInvalid(inputTypes, layer, 1);
......@@ -71,6 +61,12 @@ public class OneHot extends PredefinedLayerDeclaration {
// Check range of input
ASTElementType domain = inputTypes.get(0).getDomain();
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}
if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
, layer.getSourcePosition());
......@@ -117,6 +113,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new VariableSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(declaration.size)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -2,15 +2,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} data
def output Q(0:1)^{classes} predictions
unroll<t=5> beamSearchStart (width=5, max_length=50){
FullyConnected(units=4096) ->
Relu() ->
Dropout()
}
def split1(i){
[i] ->
Convolution(kernel=(5,5), channels=128) ->
......@@ -33,8 +24,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
}
data ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment