Commit 0052260f authored by Christian Fuß's avatar Christian Fuß
Browse files

progress

parent 38bd40be
......@@ -57,8 +57,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = unrollDeclarations:UnrollDeclaration*
methodDeclaration:LayerDeclaration*
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
interface Instruction;
......@@ -73,9 +72,9 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
Unroll implements ArchitectureElement = "unroll" Name "(" arguments:(ArchArgument || ",")* ")" "{"
body:Stream
"}";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
......
......@@ -24,9 +24,12 @@ import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -50,29 +53,62 @@ public class CheckLayer implements CNNArchASTLayerCoCo{
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
if(node.getSymbolOpt().get() instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}else{
UnrollDeclarationSymbol unrollDeclaration = ((UnrollSymbol) node.getSymbolOpt().get()).getDeclaration();
if (unrollDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : unrollDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
for(LayerSymbol sublayer: unrollDeclaration.getLayers()){
check((ASTLayer) sublayer.getAstNode().get());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
......
......@@ -351,7 +351,14 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void endVisit(ASTUnroll ast) {
UnrollSymbol layer = (UnrollSymbol) ast.getSymbolOpt().get();
layer.getDeclaration().setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
SerialCompositeElementSymbol sces = new SerialCompositeElementSymbol();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchitectureElement astElement : ast.getBody().getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}
layer.getDeclaration().setBody(sces);
layer.getDeclaration().getBody().setElements(elements);
List<ArgumentSymbol> arguments = new ArrayList<>(6);
for (ASTArchArgument astArgument : ast.getArgumentsList()){
......@@ -362,11 +369,27 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
List<ArchitectureElementSymbol> elements = new ArrayList<>();
int elementNumber = 0;
for (ASTArchitectureElement astElement : ast.getBody().getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
//TODO: assign parameters to layers in Unroll
//sublayer.getDeclaration().setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<ArgumentSymbol> subarguments = new ArrayList<>(6);
if(astElement.getSymbolOpt().get() instanceof LayerSymbol) {
for (ArgumentSymbol argument : ((LayerSymbol) astElement.getSymbolOpt().get()).getArguments()) {
System.err.println("arg: " + argument);
subarguments.add(argument);
}
((LayerSymbol) astElement.getSymbolOpt().get()).setArguments(subarguments);
}
if(elementNumber == 0 && astElement.getSymbolOpt().get() instanceof IOSymbol){
layer.getDeclaration().getBody().setInputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}else if(elementNumber == (ast.getBody().getElementsList().size() - 1) && astElement.getSymbolOpt().get() instanceof IOSymbol){
......@@ -375,7 +398,6 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
elementNumber++;
}
layer.getDeclaration().getBody().setElements(elements);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
......
......@@ -55,6 +55,9 @@ public class IOSymbol extends ArchitectureElementSymbol {
//returns null if IODeclaration does not exist. This is checked in coco CheckIOName.
public IODeclarationSymbol getDefinition() {
if (definition == null){
System.err.println("#1112 " + getEnclosingScope().getSpanningSymbol().get().toString());
System.err.println("#11123 " + getEnclosingScope().getSpanningSymbol().get().getEnclosingScope().getSpanningSymbol().get().toString());
System.err.println("#111234 " + getEnclosingScope().getSpanningSymbol().get().getEnclosingScope().getSpanningSymbol().get().getEnclosingScope().getSpanningSymbol().get().toString());
this.definition = getArchitecture().resolveIODeclaration(getName());
}
return definition;
......@@ -183,6 +186,8 @@ public class IOSymbol extends ArchitectureElementSymbol {
getArrayAccess().get().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getArrayAccess().get().getUnresolvableVariables());
}
System.err.println("IO Definition: " + getDefinition());
System.err.println("IO Type: " + getDefinition().getType());
getDefinition().getType().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getDefinition().getType().getUnresolvableVariables());
}
......
......@@ -48,6 +48,13 @@ abstract public class PredefinedUnrollDeclaration extends UnrollDeclarationSymbo
}
}
protected void setLayers(List<LayerSymbol> layers) {
super.setLayers(layers);
for (LayerSymbol layer : layers){
layer.putInScope(getSpannedScope());
}
}
@Override
public boolean isPredefined() {
return true;
......
......@@ -35,6 +35,7 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<VariableSymbol> parameters;
private List<LayerSymbol> layers;
private SerialCompositeElementSymbol body;
......@@ -52,6 +53,14 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
return (UnrollDeclarationScope) super.getSpannedScope();
}
protected void setLayers(List<LayerSymbol> layers) {
this.layers = layers;
}
public List<LayerSymbol> getLayers() {
return layers;
}
public List<VariableSymbol> getParameters() {
return parameters;
}
......
......@@ -33,8 +33,48 @@ import java.util.function.Function;
public class UnrollSymbol extends ArchitectureElementSymbol {
protected List<ArchitectureElementSymbol> elements = new ArrayList<>();
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : elements){
if (previous != null){
current.setInputElement(previous);
previous.setOutputElement(current);
}
else {
if (getInputElement().isPresent()){
current.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()){
current.setOutputElement(getOutputElement().get());
}
}
previous = current;
}
this.elements = elements;
}
public List<ArchitectureElementSymbol> getElements() {
return elements;
}
private UnrollDeclarationSymbol declaration = null;
private List<ArgumentSymbol> arguments;
private SerialCompositeElementSymbol body;
public SerialCompositeElementSymbol getBody() {
return body;
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
public boolean isNetworkLayer() {
return body.isNetwork();
}
protected UnrollSymbol(String name) {
super(name);
......@@ -114,22 +154,12 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (isAtomic()){
return Collections.singletonList(this);
}
else {
return getResolvedThis().get().getFirstAtomicElements();
}
return getDeclaration().getBody().getElements().get(0).getFirstAtomicElements();
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (isAtomic()){
return Collections.singletonList(this);
}
else {
return getResolvedThis().get().getLastAtomicElements();
}
return getDeclaration().getBody().getElements().get(getDeclaration().getBody().getElements().size()-1).getLastAtomicElements();
}
@Override
......@@ -142,15 +172,18 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
int maxSerialLength = getMaxSerialLength().get();
if (!isActive() || maxSerialLength == 0) {
System.err.println("UnrollSymbol resolveSequences called!1");
//set resolvedThis to empty composite to remove the unroll.
setResolvedThis(new SerialCompositeElementSymbol());
}
else if (parallelLength == 1 && maxSerialLength == 1) {
System.err.println("UnrollSymbol resolveSequences called!2");
//resolve the unroll call
ArchitectureElementSymbol resolvedUnroll = getDeclaration().call(this);
setResolvedThis(resolvedUnroll);
}
else {
System.err.println("UnrollSymbol resolveSequences called!3");
//split the unroll if it contains an argument sequence
ArchitectureElementSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get());
setResolvedThis(splitComposite);
......
......@@ -20,6 +20,7 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._ast.*;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.se_rwth.commons.Joiners;
......@@ -34,19 +35,60 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
}
List<LayerSymbol> layers = new ArrayList<>(Arrays.asList());
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
for(ArchitectureElementSymbol item:layer.getDeclaration().getBody().getElements()){
try {
item.resolve();
} catch (ArchResolveException e) {
System.err.println("The following names could not be resolved: " + Joiners.COMMA.join(item.getUnresolvableVariables()));
List<ArchTypeSymbol> output = new ArrayList<ArchTypeSymbol>();
for(ASTArchitectureElement item: ((ASTUnroll) layer.getAstNode().get()).getBody().getElementsList()){
if(item instanceof ASTLayer) {
try {
//ArchitectureElementSymbol item = (ArchitectureElementSymbol) ASTitem;
LayerSymbol sublayer = (LayerSymbol) item.getSymbol();
sublayer.resolve();
System.err.println("inputElement: " + sublayer.getInputElement().get().toString());
System.err.println("resolved: " + sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes().get(0).getChannels().toString());
output = sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes();
System.err.println("inputTypes: " + sublayer.getInputTypes().toString());
layers.add((LayerSymbol) sublayer);
System.err.println("outputTypes before: " + ((ArchitectureElementSymbol) sublayer).getOutputTypes().get(0).getChannels());
System.err.println("LOL_NAME: " + ((LayerSymbol) sublayer).getArguments().get(0).getName());
System.err.println("LOL: " + ((LayerSymbol) sublayer).getIntValue((((LayerSymbol) sublayer).getArguments().get(0).getName())));
//item.setOutputTypes(item.getOutputTypes());
//System.err.println("outputTypes after: " + ((ArchitectureElementSymbol) astElement).getOutputTypes());
} catch (Exception e) {
LayerSymbol sublayer = (LayerSymbol) item.getSymbol();
//System.err.println("The following names could not be resolved: " + Joiners.COMMA.join(sublayer.getUnresolvableVariables()));
e.printStackTrace();
}
}else if(item instanceof ASTIOElement){
try {
//ArchitectureElementSymbol item = (ArchitectureElementSymbol) ASTitem;
IOSymbol sublayer = (IOSymbol) item.getSymbol();
sublayer.resolve();
System.err.println("resolved2: " + sublayer.getResolvedThis().get().getInputElement().get().toString());
System.err.println("isOutput?: " + sublayer.isOutput());
System.err.println("Definition: " + sublayer.getDefinition().toString());
System.err.println("Type: " + sublayer.getDefinition().getType().getChannels().toString());
output = sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes();
if(sublayer.isOutput()){
output = new ArrayList<ArchTypeSymbol>();
}
//item.setOutputTypes(item.getOutputTypes());
//System.err.println("outputTypes after: " + ((ArchitectureElementSymbol) astElement).getOutputTypes());
} catch (Exception e) {
IOSymbol sublayer = (IOSymbol) item.getSymbol();
//System.err.println("The following names could not be resolved2: " + Joiners.COMMA.join(sublayer.getUnresolvableVariables()));
e.printStackTrace();
}
}
}
return layer.getDeclaration().getBody().computeOutputTypes();
System.err.println("output: " + output);
return output;
}
@Override
......@@ -67,6 +109,12 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
declaration.setLayers(declaration.layers);
for(LayerSymbol layer: declaration.layers){
for(ArgumentSymbol a: layer.getArguments()) {
//layer.setIntValue(a.getName(), 10);
}
}
return declaration;
}
}
......@@ -39,11 +39,6 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
<<<<<<< HEAD
channels = inputTypes.get(0).getChannels();
=======
>>>>>>> develop
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1)
......@@ -59,9 +54,6 @@ public class OneHot extends PredefinedLayerDeclaration {
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer);
<<<<<<< HEAD
//errorIfInputSizeUnequalToOnehotSize(inputTypes, layer);
=======
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
errorIfInputHeightIsInvalid(inputTypes, layer, 1);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
......@@ -113,7 +105,6 @@ public class OneHot extends PredefinedLayerDeclaration {
}
}
>>>>>>> develop
}
public static OneHot create(){
......
......@@ -66,11 +66,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
<<<<<<< HEAD
checkValid("valid_tests", "Alexnet_alt_OneHotOutput");
checkValid("valid_tests", "RNNencdec");
=======
>>>>>>> develop
}
@Test
......
......@@ -2,10 +2,11 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
unroll<t> BeamSearchStart(max_length=max_length) {
unroll BeamSearchStart(max_length=max_length) {
source ->
FullyConnected(units=vocabulary_size) ->
FullyConnected(units=17) ->
Softmax() ->
FullyConnected(units=vocabulary_size) ->
target
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment