Commit f04f2298 authored by Christian Fuß's avatar Christian Fuß
Browse files

adjusted ArgMax CoCo to allow ArgMax usage before (LSTM) states

parent 08aeee32
...@@ -35,7 +35,7 @@ public class CheckArgmaxLayer extends CNNArchSymbolCoCo { ...@@ -35,7 +35,7 @@ public class CheckArgmaxLayer extends CNNArchSymbolCoCo {
} }
public void checkArgmaxBeforeOutput(LayerSymbol layer) { public void checkArgmaxBeforeOutput(LayerSymbol layer) {
if(!(layer.getOutputElement().get() instanceof VariableSymbol && ((VariableSymbol) layer.getOutputElement().get()).getType() == VariableSymbol.Type.IO)){ if(!(layer.getOutputElement().get() instanceof VariableSymbol)){
Log.error("0" + ErrorCodes.ILLEGAL_LAYER_USE + " ArgMax Layer must be applied directly before an output symbol."); Log.error("0" + ErrorCodes.ILLEGAL_LAYER_USE + " ArgMax Layer must be applied directly before an output symbol.");
} }
} }
......
...@@ -90,10 +90,10 @@ public class Embedding extends PredefinedLayerDeclaration { ...@@ -90,10 +90,10 @@ public class Embedding extends PredefinedLayerDeclaration {
ASTElementType domain = layer.getInputTypes().get(0).getDomain(); ASTElementType domain = layer.getInputTypes().get(0).getDomain();
if (!domain.isWholeNumber() && !domain.isNaturalNumber()) { /*if (!domain.isWholeNumber() && !domain.isNaturalNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input must be natural. ", layer.getSourcePosition()); Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input must be natural. ", layer.getSourcePosition());
} }
else if (!domain.isPresentRange() || domain.getRange().hasNoUpperLimit()) { else */if (!domain.isPresentRange() || domain.getRange().hasNoUpperLimit()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input range must have an upper limit. ", layer.getSourcePosition()); Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input range must have an upper limit. ", layer.getSourcePosition());
} }
else if (domain.getRange().getStartValue().intValue() < 0 || domain.getRange().getEndValue().intValue() >= inputDim) { else if (domain.getRange().getStartValue().intValue() < 0 || domain.getRange().getEndValue().intValue() >= inputDim) {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment