Commit bc061801 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Merge branch 'rnn' into develop

parents 195c5712 cd828ae3
Pipeline #204094 passed with stages
in 18 minutes and 36 seconds
......@@ -68,11 +68,11 @@ public class CNNArchCocos {
.addCoCo(new CheckElementInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckArchitectureFinished())
.addCoCo(new CheckNetworkStreamMissing())
.addCoCo(new CheckVariableMember())
.addCoCo(new CheckLayerVariableDeclarationLayerType())
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants());
.addCoCo(new CheckConstants())
.addCoCo(new CheckUnrollInputsOutputsTooMany());
}
//checks cocos based on symbols before the resolve method of the ArchitectureSymbol is called
......
......@@ -42,6 +42,10 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
else {
checkIOArray(ioDeclaration);
}
if (ioDeclaration.isOutput()) {
checkOutputWrittenToOnce(ioDeclaration);
}
}
}
......@@ -89,4 +93,63 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
}
}
/**
 * Verifies that no element of the given output declaration is written to more than once.
 *
 * Walks every network instruction of the declaration's architecture. For a stream the
 * instruction body is inspected; for an unroll instruction every resolved body is inspected.
 * Each additional write to an already written index is reported through {@code Log.error}
 * with {@code ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES}.
 *
 * @param ioDeclaration the IO declaration whose writes are checked (the caller only passes outputs)
 */
private void checkOutputWrittenToOnce(IODeclarationSymbol ioDeclaration) {
    // Array indices of this output that have already been written by an earlier body.
    // Shared across all instructions so that writes in different streams/unrolls collide.
    List<Integer> written = new ArrayList<>();

    for (NetworkInstructionSymbol networkInstruction : ioDeclaration.getArchitecture().getNetworkInstructions()) {
        if (networkInstruction.isStream()) {
            SerialCompositeElementSymbol body = networkInstruction.getBody();
            recordOutputWrites(ioDeclaration, body, networkInstruction, written);
        } else if (networkInstruction.isUnroll()) {
            for (SerialCompositeElementSymbol body : networkInstruction.toUnrollInstruction().getResolvedBodies()) {
                recordOutputWrites(ioDeclaration, body, networkInstruction, written);
            }
        }
    }
}

/**
 * Records the writes that one body performs on the given output and reports duplicates.
 *
 * Only the last atomic elements of the body are considered, and of those only variables
 * whose name equals the declaration's name.
 *
 * @param ioDeclaration      the output declaration being checked
 * @param body               the stream body or resolved unroll body to scan
 * @param networkInstruction the instruction owning {@code body}; its source position is used for errors
 * @param written            indices written so far; updated in place
 */
private void recordOutputWrites(IODeclarationSymbol ioDeclaration,
                                SerialCompositeElementSymbol body,
                                NetworkInstructionSymbol networkInstruction,
                                List<Integer> written) {
    List<ArchitectureElementSymbol> outputs = body.getLastAtomicElements();

    for (ArchitectureElementSymbol output : outputs) {
        if (!(output instanceof VariableSymbol)) {
            continue;
        }

        VariableSymbol variable = (VariableSymbol) output;
        if (!variable.getName().equals(ioDeclaration.getName())) {
            continue;
        }

        // A write without an array access, or with an access that has no integer value,
        // is treated as a write to index 0.
        int arrayAccess = 0;
        if (variable.getArrayAccess().isPresent()) {
            arrayAccess = variable.getArrayAccess().get().getIntValue().orElse(0);
        }

        if (!written.contains(arrayAccess)) {
            written.add(arrayAccess);
        } else {
            Log.error("0" + ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES + " " + variable.getName() + "["
                            + arrayAccess + "] is written to multiple times, this is currently not allowed."
                    , networkInstruction.getSourcePosition());
        }
    }
}
}
......@@ -20,25 +20,50 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.math3.geometry.spherical.oned.Arc;
public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
import java.util.List;
import java.util.Optional;
public class CheckUnrollInputsOutputsTooMany extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
boolean hasTrainableStream = false;
public void check(UnrollInstructionSymbol sym) {
int countInputs = 0;
for (ArchitectureElementSymbol input : sym.getBody().getFirstAtomicElements()) {
if (input instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) input;
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
hasTrainableStream |= networkInstruction.getBody().isTrainable();
if (variable.isOutput() && variable.getType() == VariableSymbol.Type.IO) {
++countInputs;
}
}
}
if (!hasTrainableStream) {
Log.error("0" + ErrorCodes.MISSING_TRAINABLE_STREAM + " The architecture has no trainable stream. "
, architecture.getSourcePosition());
if (countInputs > 1) {
Log.error("0" + ErrorCodes.UNROLL_INPUTS_TOO_MANY + " Only one input is allowed for timed constructs."
, sym.getSourcePosition());
}
}
int countOutputs = 0;
for (ArchitectureElementSymbol output : sym.getBody().getLastAtomicElements()) {
if (output instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) output;
if (variable.isOutput() && variable.getType() == VariableSymbol.Type.IO) {
++countOutputs;
}
}
}
if (countOutputs > 1) {
Log.error("0" + ErrorCodes.UNROLL_OUTPUTS_TOO_MANY + " Only one output is allowed for timed constructs."
, sym.getSourcePosition());
}
}
}
......@@ -40,35 +40,6 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
/**
 * Reports whether at least one element of this composite is trainable.
 *
 * Every element is queried, even after a trainable one has been found.
 *
 * @return {@code true} if any contained element is trainable, {@code false} otherwise
 */
public boolean isTrainable() {
    boolean anyTrainable = false;
    for (ArchitectureElementSymbol child : elements) {
        // Compound assignment on booleans does not short-circuit, so each child is always evaluated.
        anyTrainable |= isElementTrainable(child);
    }
    return anyTrainable;
}

/**
 * Decides whether a single element counts as trainable.
 *
 * Nested composites and layers answer for themselves. A variable only counts when it is a
 * layer variable; its layer declaration is then asked, passing the accessed member along
 * for predefined layers. Anything else is not trainable.
 */
private static boolean isElementTrainable(ArchitectureElementSymbol child) {
    if (child instanceof CompositeElementSymbol) {
        return ((CompositeElementSymbol) child).isTrainable();
    }
    if (child instanceof LayerSymbol) {
        return ((LayerSymbol) child).getDeclaration().isTrainable();
    }
    if (!(child instanceof VariableSymbol)) {
        return false;
    }

    VariableSymbol layerVariable = (VariableSymbol) child;
    if (layerVariable.getType() != VariableSymbol.Type.LAYER) {
        return false;
    }

    LayerVariableDeclarationSymbol variableDeclaration = (LayerVariableDeclarationSymbol) layerVariable.getDeclaration();
    LayerDeclarationSymbol declaredLayer = variableDeclaration.getLayer().getDeclaration();

    return declaredLayer.isPredefined()
            ? ((PredefinedLayerDeclaration) declaredLayer).isTrainable(layerVariable.getMember())
            : declaredLayer.isTrainable();
}
@Override
public boolean isAtomic() {
return getElements().isEmpty();
......
......@@ -79,10 +79,6 @@ public class LayerDeclarationSymbol extends CommonScopeSpanningSymbol {
return body;
}
/**
 * Reports whether this layer declaration is trainable by delegating to its body.
 *
 * @return the result of {@code body.isTrainable()}
 */
public boolean isTrainable() {
return body.isTrainable();
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
......
......@@ -52,16 +52,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
return true;
}
/**
* This method is used to distinguish between neural networks like "source -> FullyConnected() -> target" and
* basic assignments like "1 -> OneHot() -> target". The generators use this to avoid creating a separate
* network for each assignment. Delegates to {@code isTrainable(VariableSymbol.Member.NONE)}.
* Overridden by individual predefined layers.
*
* NOTE(review): the original comment said this is overridden by layers "which are trainable", but the
* overrides in Add, ArgMax and Get return false, so "not trainable" was probably intended - confirm.
*/
@Override
public boolean isTrainable() {
return isTrainable(VariableSymbol.Member.NONE);
}
public boolean isTrainable(VariableSymbol.Member member) {
if(member == VariableSymbol.Member.STATE || member == VariableSymbol.Member.OUTPUT){
return false;
......
......@@ -49,5 +49,8 @@ public class ErrorCodes {
public static final String ILLEGAL_LAYER_USE = "x04845";
public static final String UNUSED_LAYER = "x04847";
public static final String INVALID_CONSTANT = "x04856";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
}
......@@ -36,11 +36,6 @@ public class Add extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ADD_NAME);
}
/**
 * Always reports the Add layer as not trainable.
 *
 * @return {@code false}
 */
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<String> range = computeStartAndEndValue(layer.getInputTypes(), Rational::plus, Rational::plus);
......
......@@ -32,11 +32,6 @@ public class ArgMax extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ARG_MAX_NAME);
}
/**
 * Always reports the ArgMax layer as not trainable.
 *
 * @return {@code false}
 */
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
......
......@@ -35,11 +35,6 @@ public class Get extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.GET_NAME);
}
/**
 * Always reports the Get layer as not trainable.
 *
 * @return {@code false}
 */
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int index = layer.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
......
......@@ -68,15 +68,4 @@ public class SymtabTest extends AbstractSymtabTest {
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
/**
 * Resolves the "MultipleOutputs" model from the valid test resources and queries the
 * output types of its first stream.
 */
@Test
public void testMultipleOutputs(){
// Build the symbol table over the directory holding the valid example models.
Scope symTab = createSymTab("src/test/resources/valid_tests");
CNNArchCompilationUnitSymbol a = symTab.<CNNArchCompilationUnitSymbol>resolve(
"MultipleOutputs",
CNNArchCompilationUnitSymbol.KIND).orElse(null);
// The model must be found before it can be resolved.
assertNotNull(a);
a.resolve();
// The returned types are not asserted on; presumably the test only checks that
// computing them does not throw - confirm.
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
}
......@@ -64,7 +64,6 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "SimpleNetworkTanh");
checkValid("valid_tests", "ResNeXt50_alt");
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
checkValid("valid_tests", "RNNencdec");
checkValid("valid_tests", "RNNsearch");
......@@ -307,4 +306,21 @@ public class AllCoCoTest extends AbstractCoCoTest {
new ExpectedErrorInfo(2, ErrorCodes.MISSING_MERGE));
}
/**
 * Runs CheckIOAccessAndIOMissing against the invalid model "OutputWrittenToMultipleTimes"
 * and expects ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES.
 */
@Test
public void testOutputWrittenToMultipleTimes() {
// Only the third checker carries a CoCo; the first two are created empty.
// NOTE(review): the role of each checker argument is defined by checkInvalid, which is not visible here.
checkInvalid(new CNNArchCoCoChecker(),
new CNNArchSymbolCoCoChecker(),
new CNNArchSymbolCoCoChecker().addCoCo(new CheckIOAccessAndIOMissing()),
"invalid_tests", "OutputWrittenToMultipleTimes",
// presumably: exactly one error with this code is expected - confirm against ExpectedErrorInfo
new ExpectedErrorInfo(1, ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES));
}
/**
 * Runs CheckUnrollInputsOutputsTooMany against the invalid model "UnrollInputsTooMany"
 * and expects ErrorCodes.UNROLL_INPUTS_TOO_MANY.
 */
@Test
public void testUnrollInputsTooMany() {
// Only the third checker carries a CoCo; the first two are created empty.
// NOTE(review): the role of each checker argument is defined by checkInvalid, which is not visible here.
checkInvalid(new CNNArchCoCoChecker(),
new CNNArchSymbolCoCoChecker(),
new CNNArchSymbolCoCoChecker().addCoCo(new CheckUnrollInputsOutputsTooMany()),
"invalid_tests", "UnrollInputsTooMany",
// presumably: exactly one error with this code is expected - confirm against ExpectedErrorInfo
new ExpectedErrorInfo(1, ErrorCodes.UNROLL_INPUTS_TOO_MANY));
}
}
\ No newline at end of file
architecture OutputWrittenToMultipleTimes{
def input Q(-oo:+oo)^{10} data[2]
def output Q(0:1)^{4} pred
data[0] ->
FullyConnected(units=4, no_bias=true) ->
Softmax() ->
pred;
data[1] ->
FullyConnected(units=4, no_bias=true) ->
Softmax() ->
pred;
}
\ No newline at end of file
architecture UnrollInputsTooMany{
def input Q(0:1)^{4} in
def output Q(0:1)^{4} out[2]
in -> Softmax() -> out[0];
timed<t> GreedySearch(max_length=2) {
(out[0] | out[t-1]) ->
Concatenate() ->
FullyConnected(units=4) ->
Softmax() ->
out[t]
};
}
\ No newline at end of file
architecture MultipleOutputs{
def input Q(-oo:+oo)^{10} data
def output Q(0:1)^{4} pred[2]
data ->
FullyConnected(units=128, no_bias=true) ->
Tanh() ->
(
FullyConnected(units=16, no_bias=true) ->
Tanh() ->
FullyConnected(units=4, no_bias=true) ->
Softmax()
|
FullyConnected(units=16, no_bias=true) ->
Tanh() ->
FullyConnected(units=4, no_bias=true) ->
Softmax()
) ->
pred;
}
\ No newline at end of file
......@@ -14,7 +14,7 @@ architecture RNNencdec{
encoder.state -> decoder.state;
timed<t> GreedySearch(max_length=30) {
timed<t> BeamSearch(max_length=30, width=3) {
target[t-1] ->
Embedding(output_dim=620) ->
decoder ->
......
......@@ -12,7 +12,7 @@ architecture RNNsearch{
layer GRU(units=1000) decoder;
encoder.state -> Split(n=2) -> [1] -> decoder.state;
timed<t> GreedySearch(max_length=30) {
timed<t> BeamSearch(max_length=30, width=3) {
(
(
(
......
architecture RNNtest(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[5]
architecture RNNtest{
def input Q(0:1)^{30000} source
def output Q(0:1)^{30000} target[5]
source -> Softmax() -> target[0];
source -> Softmax() -> target[0];
timed <t> BeamSearch(max_length=5){
(target[0] | target[t-1]) ->
Concatenate() ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
timed <t> BeamSearch(max_length=5){
target[t-1] ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment