Commit c12de894 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'AdaNet_Luis_Rickert_ba' into 'master'

[update]: AdaNet integration

See merge request !19
parents 50f764d3 205ad921
Pipeline #531540 passed with stage
in 2 minutes and 3 seconds
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId> <artifactId>cnnarch-generator</artifactId>
<version>0.4.8-SNAPSHOT</version> <version>0.4.9-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.4.6-SNAPSHOT</CNNArch.version> <CNNArch.version>0.4.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.4.5-SNAPSHOT</CNNTrain.version> <CNNTrain.version>0.4.5-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-generator>0.4.8</embedded-montiarc-math-generator> <embedded-montiarc-math-generator>0.4.7</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>25.1-jre</guava.version> <guava.version>25.1-jre</guava.version>
<junit.version>4.13.1</junit.version> <junit.version>4.13.1</junit.version>
<logback.version>1.2.0</logback.version> <logback.version>1.1.2</logback.version>
<jscience.version>4.3.1</jscience.version> <jscience.version>4.3.1</jscience.version>
<!-- .. Plugins ....................................................... --> <!-- .. Plugins ....................................................... -->
......
...@@ -103,6 +103,9 @@ public abstract class CNNArchTemplateController { ...@@ -103,6 +103,9 @@ public abstract class CNNArchTemplateController {
this.nameManager = new LayerNameCreator(architecture); this.nameManager = new LayerNameCreator(architecture);
} }
public boolean containsAdaNet(){
return this.architecture.containsAdaNet();
}
public String getName(ArchitectureElementSymbol layer){ public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer); return nameManager.getName(layer);
} }
...@@ -127,7 +130,9 @@ public abstract class CNNArchTemplateController { ...@@ -127,7 +130,9 @@ public abstract class CNNArchTemplateController {
List<String> inputNames = new ArrayList<>(); List<String> inputNames = new ArrayList<>();
for (ArchitectureElementSymbol input : layer.getPrevious()) { for (ArchitectureElementSymbol input : layer.getPrevious()) {
if (input.getOutputTypes().size() == 1) { if(input.isArtificial()){
inputNames.add(getName(input));
}else if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input)); inputNames.add(getName(input));
} else { } else {
for (int i = 0; i < input.getOutputTypes().size(); i++) { for (int i = 0; i < input.getOutputTypes().size(); i++) {
......
...@@ -7,7 +7,8 @@ import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected; ...@@ -7,7 +7,8 @@ import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling; import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import de.monticore.lang.monticar.cnnarch.predefined.LargeMemory; import de.monticore.lang.monticar.cnnarch.predefined.LargeMemory;
import de.monticore.lang.monticar.cnnarch.predefined.EpisodicMemory; import de.monticore.lang.monticar.cnnarch.predefined.EpisodicMemory;
import de.monticore.lang.monticar.cnnarch.predefined.AdaNet;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import java.util.*; import java.util.*;
public class LayerNameCreator { public class LayerNameCreator {
...@@ -40,26 +41,60 @@ public class LayerNameCreator { ...@@ -40,26 +41,60 @@ public class LayerNameCreator {
} else if (architectureElement instanceof ParallelCompositeElementSymbol) { } else if (architectureElement instanceof ParallelCompositeElementSymbol) {
return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices); return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
} else { } else {
if (architectureElement.isAtomic()) { boolean noAdaNet = !architectureElement.containsAdaNet();
if (architectureElement.getMaxSerialLength().get() > 0){ // flag which is true if there is no AdaNet inside the architecture
if (architectureElement.isAtomic() && (!architectureElement.isArtificial() || noAdaNet)) {
if (architectureElement.getMaxSerialLength().get() > 0) {
return add(architectureElement, stage, streamIndices); return add(architectureElement, stage, streamIndices);
} else { } else {
return stage; return stage;
} }
} else { } else {
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) architectureElement.getResolvedThis().get(); ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices); int final_stage = name(resolvedElement, stage, streamIndices);
if (architectureElement.isArtificial() && !noAdaNet) {
// if the element is artificial the name needs to be added only if an adaNet layer is present
final_stage = add(architectureElement, final_stage, streamIndices);
}
return final_stage;
} }
} }
} }
protected int nameAdaNetBlock(String target,ArchitectureElementSymbol subElement,int endStage, List<Integer> streamIndices){
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
Optional<ArchitectureElementSymbol> currentBlock = ((AdaNet) ((LayerSymbol) subElement).getDeclaration()).getBlock(target);
if(currentBlock.isPresent()) {
if (currentBlock.get().isArtificial() || target.equals(AllPredefinedLayers.Block)) {
boolean oldState = currentBlock.get().containsAdaNet();
currentBlock.get().setAdaNet(true);
endStage = name(currentBlock.get(), endStage, streamIndices);
currentBlock.get().setAdaNet(oldState);
}
}
return endStage;
}
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices) {
int endStage = stage; int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ for (ArchitectureElementSymbol subElement : compositeElement.getElements()) {
endStage = name(subElement, endStage, streamIndices); if (subElement.isArtificial() && compositeElement.containsAdaNet()) {
endStage = name(subElement, endStage, streamIndices);
} else if (subElement.getName().equals(AllPredefinedLayers.AdaNet_Name)) {
// name outBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Out,subElement,endStage,streamIndices);
// name inBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.In,subElement,endStage,streamIndices);
// name buildBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Block,subElement,endStage,streamIndices);
endStage = name(subElement, endStage, streamIndices);
} else {
endStage = name(subElement, endStage, streamIndices);
}
} }
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()){
for (ArchitectureElementSymbol subElement : subNetwork){ for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()) {
for (ArchitectureElementSymbol subElement : subNetwork) {
endStage = name(subElement, endStage, streamIndices); endStage = name(subElement, endStage, streamIndices);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment