Commit c12de894 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'AdaNet_Luis_Rickert_ba' into 'master'

[update]: AdaNet integration

See merge request !19
parents 50f764d3 205ad921
Pipeline #531540 passed with stage
in 2 minutes and 3 seconds
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.4.8-SNAPSHOT</version>
<version>0.4.9-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,14 +17,14 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.4.6-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.4.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.4.5-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-generator>0.4.8</embedded-montiarc-math-generator>
<embedded-montiarc-math-generator>0.4.7</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
<guava.version>25.1-jre</guava.version>
<junit.version>4.13.1</junit.version>
<logback.version>1.2.0</logback.version>
<logback.version>1.1.2</logback.version>
<jscience.version>4.3.1</jscience.version>
<!-- .. Plugins ....................................................... -->
......
......@@ -103,6 +103,9 @@ public abstract class CNNArchTemplateController {
this.nameManager = new LayerNameCreator(architecture);
}
public boolean containsAdaNet(){
return this.architecture.containsAdaNet();
}
public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer);
}
......@@ -127,7 +130,9 @@ public abstract class CNNArchTemplateController {
List<String> inputNames = new ArrayList<>();
for (ArchitectureElementSymbol input : layer.getPrevious()) {
if (input.getOutputTypes().size() == 1) {
if(input.isArtificial()){
inputNames.add(getName(input));
}else if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input));
} else {
for (int i = 0; i < input.getOutputTypes().size(); i++) {
......
......@@ -7,7 +7,8 @@ import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import de.monticore.lang.monticar.cnnarch.predefined.LargeMemory;
import de.monticore.lang.monticar.cnnarch.predefined.EpisodicMemory;
import de.monticore.lang.monticar.cnnarch.predefined.AdaNet;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import java.util.*;
public class LayerNameCreator {
......@@ -40,26 +41,60 @@ public class LayerNameCreator {
} else if (architectureElement instanceof ParallelCompositeElementSymbol) {
return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
} else {
if (architectureElement.isAtomic()) {
if (architectureElement.getMaxSerialLength().get() > 0){
boolean noAdaNet = !architectureElement.containsAdaNet();
// flag which is true if there is no AdaNet inside the architecture
if (architectureElement.isAtomic() && (!architectureElement.isArtificial() || noAdaNet)) {
if (architectureElement.getMaxSerialLength().get() > 0) {
return add(architectureElement, stage, streamIndices);
} else {
return stage;
}
} else {
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
int final_stage = name(resolvedElement, stage, streamIndices);
if (architectureElement.isArtificial() && !noAdaNet) {
// if the element is artificial the name needs to be added only if an adaNet layer is present
final_stage = add(architectureElement, final_stage, streamIndices);
}
return final_stage;
}
}
}
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
protected int nameAdaNetBlock(String target,ArchitectureElementSymbol subElement,int endStage, List<Integer> streamIndices){
Optional<ArchitectureElementSymbol> currentBlock = ((AdaNet) ((LayerSymbol) subElement).getDeclaration()).getBlock(target);
if(currentBlock.isPresent()) {
if (currentBlock.get().isArtificial() || target.equals(AllPredefinedLayers.Block)) {
boolean oldState = currentBlock.get().containsAdaNet();
currentBlock.get().setAdaNet(true);
endStage = name(currentBlock.get(), endStage, streamIndices);
currentBlock.get().setAdaNet(oldState);
}
}
return endStage;
}
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices) {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
for (ArchitectureElementSymbol subElement : compositeElement.getElements()) {
if (subElement.isArtificial() && compositeElement.containsAdaNet()) {
endStage = name(subElement, endStage, streamIndices);
} else if (subElement.getName().equals(AllPredefinedLayers.AdaNet_Name)) {
// name outBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Out,subElement,endStage,streamIndices);
// name inBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.In,subElement,endStage,streamIndices);
// name buildBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Block,subElement,endStage,streamIndices);
endStage = name(subElement, endStage, streamIndices);
} else {
endStage = name(subElement, endStage, streamIndices);
}
}
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()){
for (ArchitectureElementSymbol subElement : subNetwork){
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()) {
for (ArchitectureElementSymbol subElement : subNetwork) {
endStage = name(subElement, endStage, streamIndices);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment