Commit 7712a02f authored by Christian Fuß's avatar Christian Fuß
Browse files

some small fixes to UnrollSymbol

parent 8b12e410
Pipeline #175296 passed with stages
in 18 minutes and 35 seconds
......@@ -62,7 +62,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Stream = elements:(ArchitectureElement || "->")+;
Unroll = "timed" "<" timeParameter:ArchitectureParameter ">"
Unroll = "timed" "<" timeParameter:LayerParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
......
......@@ -52,63 +52,29 @@ public class CheckLayer implements CNNArchASTLayerCoCo{
nameSet.add(name);
}
}
if(node.getSymbolOpt().get() instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (ParameterSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (ParameterSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
}else{
UnrollDeclarationSymbol unrollDeclaration = ((UnrollSymbol) node.getSymbolOpt().get()).getDeclaration();
if (unrollDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (ParameterSymbol param : unrollDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
for(LayerSymbol sublayer: unrollDeclaration.getLayers()){
check((ASTLayer) sublayer.getAstNode().get());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
......
......@@ -104,6 +104,7 @@ public class ArgumentSymbol extends CommonSymbol {
public void set(){
if (getRhs().isResolved() && getRhs().isSimpleValue()){
getParameter().setExpression((ArchSimpleExpressionSymbol) getRhs());
getUnrollParameter().setExpression((ArchSimpleExpressionSymbol) getRhs());
}
else {
throw new IllegalStateException("The value of the parameter is set to a sequence or the expression is not resolved. This should never happen.");
......
......@@ -31,6 +31,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.*;
import de.se_rwth.commons.logging.Log;
import java.lang.reflect.Array;
import java.util.*;
public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSymbolTableCreator
......@@ -346,26 +347,15 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
public void endVisit(ASTUnroll ast) {
UnrollSymbol layer = (UnrollSymbol) ast.getSymbolOpt().get();
layer.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
//layer.getDeclaration().setBody(sces);
List<ArgumentSymbol> arguments = new ArrayList<>(6);
//ast.getArgumentsList().add(ast.getTimeParameter());
for (ASTArchArgument astArgument : ast.getArgumentsList()){
Optional<ArgumentSymbol> optArgument = astArgument.getSymbolOpt().map(e -> (ArgumentSymbol)e);
optArgument.ifPresent(arguments::add);
}
layer.setArguments(arguments);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
*/
removeCurrentScope();
}
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchitectureParameter;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
......@@ -39,7 +40,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
private UnrollDeclarationSymbol declaration = null;
private List<ArgumentSymbol> arguments;
private Set<ParameterSymbol> unresolvableParameters = null;
private ParameterSymbol timeParameter;
private UnrollSymbol resolvedThis = null;
private SerialCompositeElementSymbol body;
......@@ -86,6 +87,14 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this.arguments = arguments;
}
public ParameterSymbol getTimeParameter(){
return timeParameter;
}
protected void setTimeParameter(ParameterSymbol timeParameter){
this.timeParameter = timeParameter;
}
public ArchExpressionSymbol getIfExpression(){
Optional<ArgumentSymbol> argument = getArgument(AllPredefinedVariables.CONDITIONAL_ARG_NAME);
if (argument.isPresent()){
......@@ -222,6 +231,50 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
}
}
public void setIntValue(String parameterName, int value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setIntTupleValue(String parameterName, List<Object> tupleValues) {
setTValue(parameterName, tupleValues, ArchSimpleExpressionSymbol::of);
}
public void setBooleanValue(String parameterName, boolean value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setStringValue(String parameterName, String value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setDoubleValue(String parameterName, double value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setValue(String parameterName, Object value) {
ArchSimpleExpressionSymbol res = new ArchSimpleExpressionSymbol();
res.setValue(value);
setTValue(parameterName, res, Function.identity());
}
public <T> void setTValue(String parameterName, T value, Function<T, ArchSimpleExpressionSymbol> of) {
Optional<ParameterSymbol> param = getDeclaration().getParameter(parameterName);
if (param.isPresent()) {
Optional<ArgumentSymbol> arg = getArgument(parameterName);
ArchSimpleExpressionSymbol expression = of.apply(value);
if (arg.isPresent()) {
arg.get().setRhs(expression);
}
else {
arg = Optional.of(new ArgumentSymbol(parameterName));
arg.get().setRhs(expression);
arguments.add(arg.get());
}
}
}
......
......@@ -74,6 +74,7 @@ public class AllPredefinedLayers {
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String BEAMSEARCH_T_NAME = "t";
//possible String values
public static final String PADDING_VALID = "valid";
......
......@@ -104,17 +104,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
.name(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_T_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_WIDTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
declaration.setLayers(declaration.layers);
for(LayerSymbol layer: declaration.layers){
for(ArgumentSymbol a: layer.getArguments()) {
//layer.setIntValue(a.getName(), 10);
}
}
return declaration;
}
}
......@@ -45,9 +45,9 @@ public class AllCoCoTest extends AbstractCoCoTest {
@Test
public void testValidCoCos(){
checkValid("valid_tests", "RNNencdec");
checkValid("architectures", "ResNeXt50");
checkValid("architectures", "ResNet152");
checkValid("architectures", "Alexnet");
checkValid("architectures", "ResNeXt50");
checkValid("architectures", "ResNet34");
checkValid("architectures", "SequentialAlexnet");
checkValid("architectures", "ThreeInputCNN_M14");
......
......@@ -14,7 +14,7 @@ architecture RNNsearch(max_length=50, vocabulary_size=30001, embedding_size=620,
1 -> OneHot(n=vocabulary_size) -> target[0];
encoder.state[1] -> decoder.state;
unroll<t> BeamSearchStart(width=5, max_length=50) {
unroll<t=0> BeamSearchStart(width=5, max_length=50) {
(
(
(
......
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
def output Q(0:1)^{vocabulary_size} target[3]
timed<t> BeamSearchStart(max_length=max_length) {
source ->
FullyConnected(units=17) ->
Softmax() ->
target
};
source -> Softmax() -> target[0];
timed <t=2> BeamSearchStart(max_length=3){
target[t-1] ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment