Commit 635831fa authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Merge

parents aac22a6d 0d4c4530
Pipeline #170395 passed with stages
in 20 minutes and 46 seconds
...@@ -392,7 +392,14 @@ All predefined methods start with a capital letter and all constructed methods h ...@@ -392,7 +392,14 @@ All predefined methods start with a capital letter and all constructed methods h
Opposite of *Concatenate*. Handles a single input stream and splits it into *n* output streams. Opposite of *Concatenate*. Handles a single input stream and splits it into *n* output streams.
The output streams have the same height and width as the input stream and a number channels which is in general `input_channels / n`. The output streams have the same height and width as the input stream and a number channels which is in general `input_channels / n`.
The last output stream will have a higher number of channels than the other if `input_channels` is not divisible by `n`. The last output stream will have a higher number of channels than the other if `input_channels` is not divisible by `n`.
* **n** (integer > 0, required): The number of output streams. Cannot be higher than the number of input channels. * **n** (integer > 0, required): The number of output streams. Cannot be higher than the number of input channels.
* **OneHot(size)**
Creates a OneHot vector of a given size, given a scalar in the previous layer that determines the OneHot-Index (the index at which the *1* in the vector will be placed).
* **size** (integer > 0, optional): The OneHot-vector's size. Can be omitted to automatically use the output size of the architecture.
...@@ -189,6 +189,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol ...@@ -189,6 +189,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
} }
} }
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().get() instanceof VariableSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding //padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){ private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0); int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
...@@ -34,6 +34,8 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -34,6 +34,8 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME); super(AllPredefinedLayers.ONE_HOT_NAME);
} }
int size;
@Override @Override
public boolean isTrainable() { public boolean isTrainable() {
return false; return false;
...@@ -41,24 +43,8 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -41,24 +43,8 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override @Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) { public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
// TODO: Execute this code somewhere before checkInput(), for now size parameter is required
/*if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
int outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}*/
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
/*if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}*/
return Collections.singletonList(new ArchTypeSymbol.Builder() return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(size) .channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1) .height(1)
.width(1) .width(1)
.elementType("0", "1") .elementType("0", "1")
...@@ -67,6 +53,9 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -67,6 +53,9 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override @Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) { public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
computeOneHotOutputSize(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer); errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1); errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
errorIfInputHeightIsInvalid(inputTypes, layer, 1); errorIfInputHeightIsInvalid(inputTypes, layer, 1);
...@@ -75,6 +64,12 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -75,6 +64,12 @@ public class OneHot extends PredefinedLayerDeclaration {
// Check range of input // Check range of input
ASTElementType domain = inputTypes.get(0).getDomain(); ASTElementType domain = inputTypes.get(0).getDomain();
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}
if (!domain.isNaturalNumber() && !domain.isWholeNumber()) { if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. " Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
, layer.getSourcePosition()); , layer.getSourcePosition());
...@@ -121,6 +116,7 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -121,6 +116,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new ParameterSymbol.Builder() new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME) .name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER) .constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(declaration.size)
.build())); .build()));
declaration.setParameters(parameters); declaration.setParameters(parameters);
return declaration; return declaration;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment