Commit 5dc30390 authored by Sebastian Nickels's avatar Sebastian Nickels

OneHot clean up

parent 1bf276d6
Pipeline #172295 failed with stages
...@@ -189,16 +189,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol ...@@ -189,16 +189,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
} }
} }
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().get() instanceof VariableSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding //padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){ private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0); int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
...@@ -34,8 +34,6 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -34,8 +34,6 @@ public class OneHot extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ONE_HOT_NAME); super(AllPredefinedLayers.ONE_HOT_NAME);
} }
int size;
@Override @Override
public boolean isTrainable() { public boolean isTrainable() {
return false; return false;
...@@ -51,10 +49,19 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -51,10 +49,19 @@ public class OneHot extends PredefinedLayerDeclaration {
.build()); .build());
} }
private static void inferSizeFromOutput(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().isPresent() && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
@Override @Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) { public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
computeOneHotOutputSize(layer); inferSizeFromOutput(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer); errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1); errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
...@@ -62,14 +69,16 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -62,14 +69,16 @@ public class OneHot extends PredefinedLayerDeclaration {
errorIfInputWidthIsInvalid(inputTypes, layer, 1); errorIfInputWidthIsInvalid(inputTypes, layer, 1);
// Check range of input // Check range of input
ASTElementType domain = inputTypes.get(0).getDomain(); int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0) { if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. " Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition()); , layer.getSourcePosition());
} }
ASTElementType domain = inputTypes.get(0).getDomain();
if (!domain.isNaturalNumber() && !domain.isWholeNumber()) { if (!domain.isNaturalNumber() && !domain.isWholeNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. " Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input needs to be natural or whole. "
, layer.getSourcePosition()); , layer.getSourcePosition());
...@@ -96,8 +105,6 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -96,8 +105,6 @@ public class OneHot extends PredefinedLayerDeclaration {
, layer.getSourcePosition()); , layer.getSourcePosition());
} }
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
// Check if maximum < size // Check if maximum < size
if (max >= size) { if (max >= size) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: " Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: "
...@@ -107,7 +114,6 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -107,7 +114,6 @@ public class OneHot extends PredefinedLayerDeclaration {
} }
} }
} }
} }
public static OneHot create(){ public static OneHot create(){
...@@ -116,7 +122,7 @@ public class OneHot extends PredefinedLayerDeclaration { ...@@ -116,7 +122,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new ParameterSymbol.Builder() new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME) .name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER) .constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(declaration.size) .defaultValue(0)
.build())); .build()));
declaration.setParameters(parameters); declaration.setParameters(parameters);
return declaration; return declaration;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment