Commit 7aeb61ac authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Implemented multiple instructions, for now only multiple streams are supported

parent bcf425cb
Pipeline #138811 passed with stages
in 21 minutes and 54 seconds
......@@ -23,7 +23,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerDeclaration = "def"
Name "("
parameters:(LayerParameter || ",")* ")" "{"
body:ArchBody "}";
body:Stream "}";
IODeclaration = "def"
(in:"input" | out:"output")
......@@ -53,10 +53,12 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = methodDeclaration:LayerDeclaration*
body:ArchBody ;
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
scope ArchBody = elements:(ArchitectureElement || "->")*;
interface Instruction;
Stream implements Instruction = elements:(ArchitectureElement || "->")+;
interface ArchitectureElement;
......@@ -65,8 +67,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:ArchBody "|"
groups:(ArchBody || "|")+ ")";
groups:Stream "|"
groups:(Stream || "|")+ ")";
ArrayAccessLayer implements ArchitectureElement = "[" index:ArchSimpleExpression "]";
......@@ -160,4 +162,4 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
ast ArchArgument = method String getName(){}
method ASTArchExpression getRhs(){};
}
\ No newline at end of file
}
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -28,10 +29,12 @@ public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
if (!architecture.getBody().getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
for (CompositeElementSymbol stream : architecture.getStreams()) {
if (!stream.getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
......
......@@ -21,11 +21,14 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
architecture.getBody().checkInput();
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
}
}
}
......@@ -37,7 +37,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public static final ArchitectureKind KIND = new ArchitectureKind();
private ArchitectureElementSymbol body;
private List<CompositeElementSymbol> streams;
private List<IOSymbol> inputs = new ArrayList<>();
private List<IOSymbol> outputs = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
......@@ -48,12 +48,12 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
super("", KIND);
}
public ArchitectureElementSymbol getBody() {
return body;
public List<CompositeElementSymbol> getStreams() {
return streams;
}
protected void setBody(ArchitectureElementSymbol body) {
this.body = body;
public void setStreams(List<CompositeElementSymbol> streams) {
this.streams = streams;
}
public String getDataPath() {
......@@ -103,30 +103,44 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return getSpannedScope().resolveLocally(LayerDeclarationSymbol.KIND);
}
public void resolve() {
for (CompositeElementSymbol stream : streams) {
stream.checkIfResolvable();
public void resolve(){
getBody().checkIfResolvable();
try{
getBody().resolveOrError();
}
catch (ArchResolveException e){
//do nothing; error is already logged
try {
stream.resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
}
}
public List<ArchitectureElementSymbol> getFirstElements(){
/*public List<ArchitectureElementSymbol> getFirstElements() {
if (!getBody().isResolved()){
resolve();
}
return getBody().getFirstAtomicElements();
}
}*/
public boolean isResolved(){
return getBody().isResolved();
boolean resolved = true;
for (CompositeElementSymbol stream : streams) {
resolved &= stream.isResolved();
}
return resolved;
}
public boolean isResolvable(){
return getBody().isResolvable();
boolean resolvable = true;
for (CompositeElementSymbol stream : streams) {
resolvable &= stream.isResolvable();
}
return resolvable;
}
public void putInScope(Scope scope){
......@@ -145,22 +159,32 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
*/
public ArchitectureSymbol preResolveDeepCopy(Scope enclosingScopeOfCopy){
ArchitectureSymbol copy = new ArchitectureSymbol();
copy.setBody(getBody().preResolveDeepCopy());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
copy.getSpannedScope().getAsMutableScope().add(AllPredefinedVariables.createTrueConstant());
copy.getSpannedScope().getAsMutableScope().add(AllPredefinedVariables.createFalseConstant());
for (LayerDeclarationSymbol layerDeclaration : AllPredefinedLayers.createList()){
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration);
}
for (LayerDeclarationSymbol layerDeclaration : getSpannedScope().<LayerDeclarationSymbol>resolveLocally(LayerDeclarationSymbol.KIND)){
if (!layerDeclaration.isPredefined()) {
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration.deepCopy());
}
}
copy.getBody().putInScope(copy.getSpannedScope());
List<CompositeElementSymbol> copyStreams = new ArrayList<>();
for (CompositeElementSymbol stream : streams) {
CompositeElementSymbol copyStream = stream.preResolveDeepCopy();
copyStream.putInScope(copy.getSpannedScope());
copyStreams.add(copyStream);
}
copy.setStreams(copyStreams);
copy.putInScope(enclosingScopeOfCopy);
return copy;
}
......
......@@ -145,8 +145,12 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
public void endVisit(final ASTArchitecture node) {
//ArchitectureSymbol architecture = (ArchitectureSymbol) node.getSymbolOpt().get();
architecture.setBody((ArchitectureElementSymbol) node.getBody().getSymbolOpt().get());
List<CompositeElementSymbol> streams = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()){
ASTStream astStream = (ASTStream)astInstruction; // TODO: For now all instructions are streams
streams.add((CompositeElementSymbol) astStream.getSymbolOpt().get());
}
architecture.setStreams(streams);
removeCurrentScope();
}
......@@ -326,8 +330,8 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
CompositeElementSymbol compositeElement = (CompositeElementSymbol) node.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchBody astBody : node.getGroupsList()){
elements.add((CompositeElementSymbol) astBody.getSymbolOpt().get());
for (ASTStream astStream : node.getGroupsList()){
elements.add((CompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
......@@ -335,16 +339,15 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
@Override
public void visit(ASTArchBody ast) {
public void visit(ASTStream ast) {
CompositeElementSymbol compositeElement = new CompositeElementSymbol();
compositeElement.setParallel(false);
addToScopeAndLinkWithNode(compositeElement, ast);
}
@Override
public void endVisit(ASTArchBody ast) {
public void endVisit(ASTStream ast) {
CompositeElementSymbol compositeElement = (CompositeElementSymbol) ast.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchitectureElement astElement : ast.getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
......
......@@ -93,9 +93,9 @@ public class InstanceTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol compilationUnit2 = compilationUnitSymbol.preResolveDeepCopy();
compilationUnit2.setParameter("cardinality", 2);
ArchitectureSymbol instance2 = compilationUnit2.resolve();
ArchRangeExpressionSymbol range1 = (ArchRangeExpressionSymbol) ((LayerSymbol)(((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol) instance1.getBody()).getElements().get(5).getResolvedThis().get()).getElements().get(0)).getElements().get(0)).getElements().get(0)))
ArchRangeExpressionSymbol range1 = (ArchRangeExpressionSymbol) ((LayerSymbol)(((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol) instance1.getStreams().get(0)).getElements().get(5).getResolvedThis().get()).getElements().get(0)).getElements().get(0)).getElements().get(0)))
.getArgument(AllPredefinedVariables.PARALLEL_ARG_NAME).get().getRhs();
ArchRangeExpressionSymbol range2 = (ArchRangeExpressionSymbol) ((LayerSymbol)(((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol) instance2.getBody()).getElements().get(5).getResolvedThis().get()).getElements().get(0)).getElements().get(0)).getElements().get(0)))
ArchRangeExpressionSymbol range2 = (ArchRangeExpressionSymbol) ((LayerSymbol)(((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol)((CompositeElementSymbol) instance2.getStreams().get(0)).getElements().get(5).getResolvedThis().get()).getElements().get(0)).getElements().get(0)).getElements().get(0)))
.getArgument(AllPredefinedVariables.PARALLEL_ARG_NAME).get().getRhs();
assertEquals(32, range1.getElements().get().size());
......
......@@ -42,8 +42,10 @@ public class ParserTest {
public static final boolean ENABLE_FAIL_QUICK = false;
private static List<String> expectedParseErrorModels = Arrays.asList(
// incorrect argument name for a layer
"src/test/resources/architectures/RNNsearch.cnna",
"src/test/resources/invalid_tests/MissingParallelBrackets.cnna",
"src/test/resources/invalid_tests/MissingLayerOperator.cnna")
"src/test/resources/invalid_tests/MissingLayerOperator.cnna",
"src/test/resources/invalid_tests/MissingSemicolon.cnna")
.stream().map(s -> Paths.get(s).toString())
.collect(Collectors.toList());
......
......@@ -54,7 +54,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
@Test
......@@ -65,7 +65,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
@Test
......@@ -76,7 +76,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
}
......@@ -65,6 +65,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "ResNeXt50_alt");
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
}
@Test
......@@ -230,15 +231,6 @@ public class AllCoCoTest extends AbstractCoCoTest {
new ExpectedErrorInfo(2, ErrorCodes.ILLEGAL_NAME));
}
@Test
public void testUnfinishedArchitecture(){
checkInvalid(new CNNArchCoCoChecker(),
new CNNArchSymbolCoCoChecker(),
new CNNArchSymbolCoCoChecker().addCoCo(new CheckArchitectureFinished()),
"invalid_tests", "UnfinishedArchitecture",
new ExpectedErrorInfo(1, ErrorCodes.UNFINISHED_ARCHITECTURE));
}
@Test
public void testInvalidInputShape(){
checkInvalid(new CNNArchCoCoChecker(),
......
......@@ -39,5 +39,5 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
fc(->=2) ->
FullyConnected(units=10) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
package RNNsearch;
architecture RNNsearch(max_length=50, vocabulary_size=30001, embedding_size=620, hidden_size=1000){
def input Q^{max_length, vocabulary_size} source
def output Q^{max_length, vocabulary_size} target
......@@ -16,33 +14,33 @@ architecture RNNsearch(max_length=50, vocabulary_size=30001, embedding_size=620,
1 -> OneHot(n=vocabulary_size) -> target[0];
encoder.state[1] -> decoder.state;
unroll<t> BeamSearchStart(width=5, max_length=50) ->
(
unroll<t> BeamSearchStart(width=5, max_length=50) {
(
(
decoder.state ->
Repeat(n=max_length, dim=1)
(
decoder.state ->
Repeat(n=max_length, dim=1)
|
encoder.output
) ->
Concatenate(dim=2) ->
FullyConnected(units=hidden_size, flatten=false) ->
Tanh() ->
FullyConnected(units=max_length) ->
Softmax()
|
encoder.output
) ->
Concatenate(dim=2) ->
FullyConnected(units=hidden_size, flatten=false) ->
Tanh() ->
FullyConnected(units=max_length) ->
Softmax()
Dot()
|
encoder.output
target[t-1] ->
Embedding(input=vocabulary_size, output=hidden_size)
) ->
Dot()
|
target[t-1] ->
Embedding(input=vocabulary_size, output=hidden_size)
) ->
Concatenate() ->
decoder ->
FullyConnected(units=vocabulary_size) ->
ArgMax() ->
OneHot(n=vocabulary_size) ->
target[t] ->
BeamSearchEnd();
Concatenate() ->
decoder ->
FullyConnected(units=vocabulary_size) ->
ArgMax() ->
OneHot(n=vocabulary_size) ->
target[t]
};
}
......@@ -40,5 +40,5 @@ architecture ResNeXt50(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -33,5 +33,5 @@ architecture ResNet152(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -31,5 +31,5 @@ architecture ResNet34(img_height=224, img_width=224, img_channels=3, classes=100
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -25,5 +25,5 @@ architecture SequentialAlexnet(img_height=224, img_width=224, img_channels=3, cl
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -28,5 +28,5 @@ architecture ThreeInputCNN_M14(img_height=200, img_width=300, img_channels=3, cl
Relu() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -27,5 +27,5 @@ architecture VGG16(img_height=224, img_width=224, img_channels=3, classes=1000){
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -30,5 +30,5 @@ architecture ArgumentConstraintTest1(img_height=224, img_width=224, img_channels
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes, ->=true) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -30,5 +30,5 @@ architecture ArgumentConstraintTest2(img_height=224, img_width=224, img_channels
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -30,5 +30,5 @@ architecture ArgumentConstraintTest3(img_height=224, img_width=224, img_channels
GlobalPooling(pool_type="avg", ?=1) ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment