Commit e7992cdb authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Removed isTrainable since it is not needed anymore and added some cocos to...

Removed isTrainable since it is not needed anymore and added some cocos to ensure correct functionality of timed construct
parent 75ffe205
......@@ -68,11 +68,11 @@ public class CNNArchCocos {
.addCoCo(new CheckElementInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckArchitectureFinished())
.addCoCo(new CheckNetworkStreamMissing())
.addCoCo(new CheckVariableMember())
.addCoCo(new CheckLayerVariableDeclarationLayerType())
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants());
.addCoCo(new CheckConstants())
.addCoCo(new CheckUnrollInputsOutputsTooMany());
}
//checks cocos based on symbols before the resolve method of the ArchitectureSymbol is called
......
......@@ -42,6 +42,10 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
else {
checkIOArray(ioDeclaration);
}
if (ioDeclaration.isOutput()) {
checkOutputWrittenToOnce(ioDeclaration);
}
}
}
......@@ -89,4 +93,63 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
}
}
private void checkOutputWrittenToOnce(IODeclarationSymbol ioDeclaration) {
List<Integer> written = new ArrayList<>();
for (NetworkInstructionSymbol networkInstruction : ioDeclaration.getArchitecture().getNetworkInstructions()) {
if (networkInstruction.isStream()) {
SerialCompositeElementSymbol body = networkInstruction.getBody();
List<ArchitectureElementSymbol> outputs = body.getLastAtomicElements();
for (ArchitectureElementSymbol output : outputs) {
if (output instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) output;
if (variable.getName().equals(ioDeclaration.getName())) {
int arrayAccess = 0;
if (variable.getArrayAccess().isPresent()) {
arrayAccess = variable.getArrayAccess().get().getIntValue().orElse(0);
}
if (!written.contains(arrayAccess)) {
written.add(arrayAccess);
} else {
Log.error("0" + ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES + " " + variable.getName() + "["
+ arrayAccess + "] is written to multiple times, this is currently not allowed."
, networkInstruction.getSourcePosition());
}
}
}
}
} else if (networkInstruction.isUnroll()) {
for (SerialCompositeElementSymbol body : networkInstruction.toUnrollInstruction().getResolvedBodies()) {
List<ArchitectureElementSymbol> outputs = body.getLastAtomicElements();
for (ArchitectureElementSymbol output : outputs) {
if (output instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) output;
if (variable.getName().equals(ioDeclaration.getName())) {
int arrayAccess = 0;
if (variable.getArrayAccess().isPresent()) {
arrayAccess = variable.getArrayAccess().get().getIntValue().orElse(0);
}
if (!written.contains(arrayAccess)) {
written.add(arrayAccess);
} else {
Log.error("0" + ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES + " " + variable.getName() + "["
+ arrayAccess + "] is written to multiple times, this is currently not allowed."
, networkInstruction.getSourcePosition());
}
}
}
}
}
}
}
}
}
......@@ -20,25 +20,50 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.math3.geometry.spherical.oned.Arc;
public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
import java.util.List;
import java.util.Optional;
public class CheckUnrollInputsOutputsTooMany extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
boolean hasTrainableStream = false;
public void check(UnrollInstructionSymbol sym) {
int countInputs = 0;
for (ArchitectureElementSymbol input : sym.getBody().getFirstAtomicElements()) {
if (input instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) input;
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
hasTrainableStream |= networkInstruction.getBody().isTrainable();
if (variable.isOutput() && variable.getType() == VariableSymbol.Type.IO) {
++countInputs;
}
}
}
if (!hasTrainableStream) {
Log.error("0" + ErrorCodes.MISSING_TRAINABLE_STREAM + " The architecture has no trainable stream. "
, architecture.getSourcePosition());
if (countInputs > 1) {
Log.error("0" + ErrorCodes.UNROLL_INPUTS_TOO_MANY + " Only one input is allowed for timed constructs."
, sym.getSourcePosition());
}
}
int countOutputs = 0;
for (ArchitectureElementSymbol output : sym.getBody().getLastAtomicElements()) {
if (output instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) output;
if (variable.isOutput() && variable.getType() == VariableSymbol.Type.IO) {
++countOutputs;
}
}
}
if (countOutputs > 1) {
Log.error("0" + ErrorCodes.UNROLL_OUTPUTS_TOO_MANY + " Only one output is allowed for timed constructs."
, sym.getSourcePosition());
}
}
}
......@@ -40,35 +40,6 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
public boolean isTrainable() {
boolean isTrainable = false;
for (ArchitectureElementSymbol element : elements) {
if (element instanceof CompositeElementSymbol) {
isTrainable |= ((CompositeElementSymbol) element).isTrainable();
}
else if (element instanceof LayerSymbol) {
isTrainable |= ((LayerSymbol) element).getDeclaration().isTrainable();
}
else if (element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.LAYER) {
LayerDeclarationSymbol layerDeclaration = ((LayerVariableDeclarationSymbol) variable.getDeclaration()).getLayer().getDeclaration();
if (layerDeclaration.isPredefined()) {
isTrainable |= ((PredefinedLayerDeclaration) layerDeclaration).isTrainable(variable.getMember());
}
else {
isTrainable |= layerDeclaration.isTrainable();
}
}
}
}
return isTrainable;
}
@Override
public boolean isAtomic() {
return getElements().isEmpty();
......
......@@ -79,10 +79,6 @@ public class LayerDeclarationSymbol extends CommonScopeSpanningSymbol {
return body;
}
public boolean isTrainable() {
return body.isTrainable();
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
......
......@@ -52,16 +52,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
return true;
}
/**
* This method is used to distinguish between neural networks like "source -> FullyConnected() -> target" and
* basic assignments like "1 -> OneHot() -> target". The generators use this to avoid creating an own
* network for each assignment. Override by predefined layers which are trainable.
*/
@Override
public boolean isTrainable() {
return isTrainable(VariableSymbol.Member.NONE);
}
public boolean isTrainable(VariableSymbol.Member member) {
if(member == VariableSymbol.Member.STATE || member == VariableSymbol.Member.OUTPUT){
return false;
......
......@@ -49,5 +49,8 @@ public class ErrorCodes {
public static final String ILLEGAL_LAYER_USE = "x04845";
public static final String UNUSED_LAYER = "x04847";
public static final String INVALID_CONSTANT = "x04856";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
}
......@@ -36,11 +36,6 @@ public class Add extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ADD_NAME);
}
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<String> range = computeStartAndEndValue(layer.getInputTypes(), Rational::plus, Rational::plus);
......
......@@ -32,11 +32,6 @@ public class ArgMax extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.ARG_MAX_NAME);
}
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
......
......@@ -35,11 +35,6 @@ public class Get extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.GET_NAME);
}
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int index = layer.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
......
......@@ -68,15 +68,4 @@ public class SymtabTest extends AbstractSymtabTest {
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
@Test
public void testMultipleOutputs(){
Scope symTab = createSymTab("src/test/resources/valid_tests");
CNNArchCompilationUnitSymbol a = symTab.<CNNArchCompilationUnitSymbol>resolve(
"MultipleOutputs",
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
}
......@@ -64,7 +64,6 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "SimpleNetworkTanh");
checkValid("valid_tests", "ResNeXt50_alt");
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
checkValid("valid_tests", "RNNencdec");
checkValid("valid_tests", "RNNsearch");
......@@ -307,4 +306,21 @@ public class AllCoCoTest extends AbstractCoCoTest {
new ExpectedErrorInfo(2, ErrorCodes.MISSING_MERGE));
}
@Test
public void testOutputWrittenToMultipleTimes() {
checkInvalid(new CNNArchCoCoChecker(),
new CNNArchSymbolCoCoChecker(),
new CNNArchSymbolCoCoChecker().addCoCo(new CheckIOAccessAndIOMissing()),
"invalid_tests", "OutputWrittenToMultipleTimes",
new ExpectedErrorInfo(1, ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES));
}
@Test
public void testUnrollInputsTooMany() {
checkInvalid(new CNNArchCoCoChecker(),
new CNNArchSymbolCoCoChecker(),
new CNNArchSymbolCoCoChecker().addCoCo(new CheckUnrollInputsOutputsTooMany()),
"invalid_tests", "UnrollInputsTooMany",
new ExpectedErrorInfo(1, ErrorCodes.UNROLL_INPUTS_TOO_MANY));
}
}
\ No newline at end of file
architecture OutputWrittenToMultipleTimes{
def input Q(-oo:+oo)^{10} data[2]
def output Q(0:1)^{4} pred
data[0] ->
FullyConnected(units=4, no_bias=true) ->
Softmax() ->
pred;
data[1] ->
FullyConnected(units=4, no_bias=true) ->
Softmax() ->
pred;
}
\ No newline at end of file
architecture UnrollInputsTooMany{
def input Q(0:1)^{4} in
def output Q(0:1)^{4} out[2]
in -> Softmax() -> out[0];
timed<t> GreedySearch(max_length=2) {
(out[0] | out[t-1]) ->
Concatenate() ->
FullyConnected(units=4) ->
Softmax() ->
out[t]
};
}
\ No newline at end of file
architecture MultipleOutputs{
def input Q(-oo:+oo)^{10} data
def output Q(0:1)^{4} pred[2]
data ->
FullyConnected(units=128, no_bias=true) ->
Tanh() ->
(
FullyConnected(units=16, no_bias=true) ->
Tanh() ->
FullyConnected(units=4, no_bias=true) ->
Softmax()
|
FullyConnected(units=16, no_bias=true) ->
Tanh() ->
FullyConnected(units=4, no_bias=true) ->
Softmax()
) ->
pred;
}
\ No newline at end of file
......@@ -14,7 +14,7 @@ architecture RNNencdec{
encoder.state -> decoder.state;
timed<t> GreedySearch(max_length=30) {
timed<t> BeamSearch(max_length=30, width=3) {
target[t-1] ->
Embedding(output_dim=620) ->
decoder ->
......
......@@ -12,7 +12,7 @@ architecture RNNsearch{
layer GRU(units=1000) decoder;
encoder.state -> Split(n=2) -> [1] -> decoder.state;
timed<t> GreedySearch(max_length=30) {
timed<t> BeamSearch(max_length=30, width=3) {
(
(
(
......
architecture RNNtest(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[5]
architecture RNNtest{
def input Q(0:1)^{30000} source
def output Q(0:1)^{30000} target[5]
source -> Softmax() -> target[0];
source -> Softmax() -> target[0];
timed <t> BeamSearch(max_length=5){
(target[0] | target[t-1]) ->
Concatenate() ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
timed <t> BeamSearch(max_length=5){
target[t-1] ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment