Commit 5dc30390 authored by Sebastian Nickels's avatar Sebastian Nickels

OneHot clean up

parent 1bf276d6
Pipeline #172295 failed with stages
......@@ -189,16 +189,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().get() instanceof VariableSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
......@@ -34,8 +34,6 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
int size;
@Override
public boolean isTrainable() {
return false;
......@@ -51,10 +49,19 @@ public class OneHot extends PredefinedLayerDeclaration {
.build());
}
private static void inferSizeFromOutput(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().isPresent() && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
computeOneHotOutputSize(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
inferSizeFromOutput(layer);
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
......@@ -62,14 +69,16 @@ public class OneHot extends PredefinedLayerDeclaration {
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
// Check range of input
ASTElementType domain = inputTypes.get(0).getDomain();
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0) {
if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}
ASTElementType domain = inputTypes.get(0).getDomain();
if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
, layer.getSourcePosition());
......@@ -96,8 +105,6 @@ public class OneHot extends PredefinedLayerDeclaration {
, layer.getSourcePosition());
}
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
// Check if maximum < size
if (max >= size) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: "
......@@ -107,7 +114,6 @@ public class OneHot extends PredefinedLayerDeclaration {
}
}
}
}
public static OneHot create(){
......@@ -116,7 +122,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(declaration.size)
.defaultValue(0)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment