Commit 79cfdeba authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Fixed dot layer

parent 5636a8dd
......@@ -121,14 +121,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected void errorIfInputNotFeasibleForDotProduct(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if(!(layer.getInputTypes().get(1).getHeight() == layer.getInputTypes().get(0).getWidth())){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Dot Product cannot be applied to input 1 with height " +
layer.getInputTypes().get(1).getHeight() + " and input 0 with width " + layer.getInputTypes().get(0).getWidth()
, layer.getSourcePosition());
}
}
protected void errorIfInputChannelSizeIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int channels) {
for (ArchTypeSymbol inputType : inputTypes) {
if (inputType.getChannels() != channels) {
......
......@@ -21,6 +21,8 @@
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
......@@ -35,29 +37,24 @@ public class Dot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
if(layer.getInputTypes().get(0).getWidth() == 1 || layer.getInputTypes().get(0).getHeight() == 1) {
// if dot product between vectors
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}else {
// if dot product between matrices
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(0).getHeight())
.width(layer.getInputTypes().get(0).getWidth())
.elementType("-oo", "oo")
.build());
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(1).getHeight())
.width(1)
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
errorIfInputNotFeasibleForDotProduct(inputTypes, layer);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
if (layer.getInputTypes().get(0).getHeight().intValue() != layer.getInputTypes().get(1).getChannels().intValue()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Dot cannot be applied to input 0 with height " +
layer.getInputTypes().get(0).getHeight() + " and input 1 with channel size " + layer.getInputTypes().get(1).getChannels()
, layer.getSourcePosition());
}
}
public static Dot create(){
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment