Commit 635831fa authored by Sebastian Nickels's avatar Sebastian Nickels

Merge

parents aac22a6d 0d4c4530
Pipeline #170395 passed with stages
in 20 minutes and 46 seconds
......@@ -392,7 +392,14 @@ All predefined methods start with a capital letter and all constructed methods h
Opposite of *Concatenate*. Handles a single input stream and splits it into *n* output streams.
The output streams have the same height and width as the input stream and a number channels which is in general `input_channels / n`.
The last output stream will have a higher number of channels than the other if `input_channels` is not divisible by `n`.
* **n** (integer > 0, required): The number of output streams. Cannot be higher than the number of input channels.
* **OneHot(size)**
Creates a OneHot vector of a given size, given a scalar in the previous layer that determines the OneHot-Index (the index at which the *1* in the vector will be placed).
* **size** (integer > 0, optional): The OneHot-vector's size. Can be omitted to automatically use the output size of the architecture.
......@@ -189,6 +189,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().get() instanceof VariableSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
......@@ -34,6 +34,8 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
int size;
@Override
public boolean isTrainable() {
return false;
......@@ -41,24 +43,8 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
// TODO: Execute this code somewhere before checkInput(), for now size parameter is required
/*if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
int outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}*/
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
/*if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}*/
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(size)
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1)
.width(1)
.elementType("0", "1")
......@@ -67,6 +53,9 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
computeOneHotOutputSize(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
errorIfInputHeightIsInvalid(inputTypes, layer, 1);
......@@ -75,6 +64,12 @@ public class OneHot extends PredefinedLayerDeclaration {
// Check range of input
ASTElementType domain = inputTypes.get(0).getDomain();
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}
if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
, layer.getSourcePosition());
......@@ -121,6 +116,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(declaration.size)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment