Commit e333f3fd authored by Christian Fuß's avatar Christian Fuß
Browse files

added method to generate the additional layers generated by the unrolling procedure

parent 61dc6ff8
Pipeline #175999 passed with stages
in 20 minutes and 20 seconds
......@@ -127,6 +127,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
{
try {
unroll.resolve();
unroll = unroll.createUnrollForBackend();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
......
......@@ -43,6 +43,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
// Time parameter of this unroll (read elsewhere through getTimeParameter()).
private ParameterSymbol timeParameter;
// NOTE(review): presumably caches the resolved form of this symbol; starts as null — confirm against resolve().
private UnrollSymbol resolvedThis = null;
// Elements executed by this unroll; read via getBody() and replaced by createUnrollForBackend().
private SerialCompositeElementSymbol body;
// Set to true once createUnrollForBackend() has replaced the body, so the expansion runs only once.
private boolean isExtendedForBackend = false;
public SerialCompositeElementSymbol getBody() {
return body;
......@@ -52,6 +53,63 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this.body = body;
}
/**
 * Reports whether {@link #createUnrollForBackend()} has already replaced the
 * body of this unroll with its expanded form.
 *
 * @return the current value of the {@code isExtendedForBackend} flag
 */
public boolean isExtended() {
    return isExtendedForBackend;
}
/**
 * Updates the flag that records whether the body has been expanded for the backend.
 *
 * @param extended new value of the {@code isExtendedForBackend} flag
 */
private void setExtended(boolean extended) {
    isExtendedForBackend = extended;
}
/**
 * Expands the body of this unroll into one flat element list covering every
 * time step, and installs that list as the new body.
 *
 * <p>The step counter starts at the integer value of
 * {@code AllPredefinedLayers.BEAMSEARCH_T_NAME} and runs while it is smaller than
 * the integer value of {@code AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH}. For each
 * step a pre-resolve deep copy of the current body is made; every element of the
 * copy is resolved with the time parameter set to the step and appended to the
 * new body. When the step is greater than zero, the first element of the copy is
 * left out.
 *
 * <p>Afterwards the symbol is marked as extended and the architecture's unroll
 * list is replaced by a list containing only this symbol. Calling the method
 * again on an already extended symbol is a no-op.
 *
 * @return this symbol
 */
public UnrollSymbol createUnrollForBackend() {
    // Already expanded: nothing to do.
    if (isExtendedForBackend) {
        return this;
    }

    List<ArchitectureElementSymbol> unrolledElements = new ArrayList<>();

    // NOTE(review): both Optional.get() calls are unchecked, as in the original;
    // the upper bound is deliberately re-read on every iteration.
    for (int step = getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get();
         step < getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get();
         step++) {

        SerialCompositeElementSymbol stepBody = getBody().preResolveDeepCopy();
        stepBody.putInScope(getBody().getSpannedScope());

        boolean atFirstElement = true;
        for (ArchitectureElementSymbol element : stepBody.getElements()) {
            // The leading element of the copy is dropped for every step > 0.
            // NOTE(review): the comparison is against 0, not against the start value of t — confirm intent.
            boolean skipElement = atFirstElement && step > 0;
            atFirstElement = false;
            if (skipElement) {
                continue;
            }

            if (element.getEnclosingScope() == null) {
                element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
            }

            // Write the current step into the time parameter before resolving.
            ArrayList<Symbol> timeSymbols = (ArrayList<Symbol>) getSpannedScope()
                    .getLocalSymbols()
                    .get(AllPredefinedLayers.BEAMSEARCH_T_NAME);
            ParameterSymbol timeSymbol = (ParameterSymbol) timeSymbols.get(0);
            timeSymbol.getExpression().setValue(step);

            try {
                resolveExpressions();
            } catch (ArchResolveException e) {
                // NOTE(review): failure is only printed; processing continues with this element.
                e.printStackTrace();
            }

            // Make sure every declared parameter lives in this unroll's spanned scope.
            for (ParameterSymbol parameter : declaration.getParameters()) {
                if (parameter.getEnclosingScope() == null) {
                    parameter.putInScope(getSpannedScope());
                }
            }

            try {
                element.resolve();
            } catch (ArchResolveException e) {
                // NOTE(review): the element is still appended below even if resolving failed.
                e.printStackTrace();
            }

            unrolledElements.add(element);
        }
    }

    // Install the flattened element list as the new body.
    SerialCompositeElementSymbol expandedBody = new SerialCompositeElementSymbol();
    expandedBody.putInScope(getBody().getSpannedScope());
    expandedBody.setElements(unrolledElements);
    setBody(expandedBody);
    setExtended(true);

    // The architecture now references this expanded unroll only.
    List<UnrollSymbol> onlyThis = new ArrayList<>();
    onlyThis.add(this);
    getBody().getArchitecture().setUnrolls(onlyThis);

    return this;
}
/**
 * Reports whether this unroll is trainable, as decided by its body.
 *
 * @return the result of {@code isTrainable()} on the current body
 */
public boolean isTrainable() {
    return this.body.isTrainable();
}
......@@ -306,6 +364,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
System.err.println("t_value: " + ((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get("t")).get(0)).getValue().get().toString());
copy.getTimeParameter().putInScope(getSpannedScope());
copy.setBody(getBody().preResolveDeepCopy());
copy.getBody().putInScope(copy.getSpannedScope());
//TODO: Find a nicer way to put the timeParameter into the unroll elements' scope
......
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[3]
def output Q(0:1)^{vocabulary_size} target[5]
source -> Softmax() -> target[0];
timed <t=2> BeamSearchStart(max_length=3){
target[t-1] ->
timed <t=0> BeamSearchStart(max_length=4){
target[t] ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
target[t+1]
};
}
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment