Commit 66fd3199 authored by Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !11
parents 000ef532 d7018253
Pipeline #325050 passed in 4 minutes and 4 seconds
@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.0.6-SNAPSHOT</version>
<version>0.0.7-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
@@ -17,8 +17,8 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
@@ -11,8 +11,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.*;
public class ArchitectureElementData {
@@ -184,6 +183,132 @@ public class ArchitectureElementData {
return getLayerSymbol().getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
}
public String getNetworkDir(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.NETWORK_DIR_NAME).get();
}
public String getNetworkPrefix(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.NETWORK_PREFIX_NAME).get();
}
public int getNumInputs(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_INPUTS_NAME).get();
}
public List<Integer> getOutputShape(){
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.OUTPUT_SHAPE_NAME).get();
}
public int getScaleFactor(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.SCALE_FACTOR_NAME).get();
}
public int getDimKeys(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.DIM_KEYS_NAME).get();
}
public int getDimValues(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.DIM_VALUES_NAME).get();
}
public boolean getUseProjBias() {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.USE_PROJ_BIAS_NAME).get();
}
public boolean getUseMask() {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.USE_MASK_NAME).get();
}
public int getSubKeySize(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.SUB_KEY_SIZE_NAME).get();
}
public List<Integer> getQuerySize(){
if (getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).isPresent()){
List<Integer> list = new ArrayList<>();
list.add((Integer) getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get());
return list;
}else{
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
}
}
public String getQueryAct(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.QUERY_ACT_NAME).get();
}
public int getK(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.K_NAME).get();
}
public int getNumHeads(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_HEADS_NAME).get();
}
public String getStoreDistMeasure(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.STORE_DIST_MEASURE_NAME).get();
}
public int getReplayInterval(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_INTERVAL_NAME).get();
}
public int getReplayBatchSize(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_BATCH_SIZE_NAME).get();
}
public int getReplaySteps(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_STEPS_NAME).get();
}
public int getReplayGradientSteps(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME).get();
}
public double getReplayMemoryStoreProb(){
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.REPLAY_MEMORY_STORE_PROB_NAME).get();
}
public int getMaxStoredSamples(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME).get();
}
public String getMemoryReplacementStrategy(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.MEMORY_REPLACEMENT_STRATEGY_NAME).get();
}
public boolean getUseReplay(){
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.USE_REPLAY_NAME).get();
}
public boolean getUseLocalAdaption(){
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME).get();
}
public int getLocalAdaptionK(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LOCAL_ADAPTION_K_NAME).get();
}
public int getlocalAdaptionGradientSteps(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LOCAL_ADAPTION_GRADIENT_STEPS_NAME).get();
}
public String getQueryNetDir(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.QUERY_NET_DIR_NAME).get();
}
public String getQueryNetPrefix(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.QUERY_NET_PREFIX_NAME).get();
}
public int getQueryNetNumInputs(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_NET_NUM_INPUTS_NAME).get();
}
public int getValuesDim(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.VALUES_DIM_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
......
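
The new ArchitectureElementData getters read layer parameters (memory shape, attention heads, replay settings, query network location, and so on) through the symbol table's Optional-style accessors; getQuerySize() additionally accepts either a single integer or an integer tuple and always returns a list. The following is a minimal, self-contained sketch of that scalar-or-tuple pattern, assuming a hypothetical ParameterValue stand-in rather than the real LayerSymbol API:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Optional;

    // Illustrative stand-in for a layer parameter that may be declared either as a
    // single int or as an int tuple (not the actual CNNArch symbol-table classes).
    class ParameterValue {
        private final Optional<Integer> intValue;
        private final Optional<List<Integer>> intTupleValue;

        ParameterValue(Optional<Integer> intValue, Optional<List<Integer>> intTupleValue) {
            this.intValue = intValue;
            this.intTupleValue = intTupleValue;
        }

        // Same shape as getQuerySize(): if the scalar form is present, wrap it in a
        // one-element list; otherwise fall back to the tuple form.
        List<Integer> asIntList() {
            if (intValue.isPresent()) {
                List<Integer> list = new ArrayList<>();
                list.add(intValue.get());
                return list;
            }
            return intTupleValue.get();
        }

        public static void main(String[] args) {
            ParameterValue scalar = new ParameterValue(Optional.of(64), Optional.empty());
            ParameterValue tuple  = new ParameterValue(Optional.empty(), Optional.of(List.of(32, 32)));
            System.out.println(scalar.asIntList()); // [64]
            System.out.println(tuple.asIntList());  // [32, 32]
        }
    }
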
@@ -5,6 +5,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import de.monticore.lang.monticar.cnnarch.predefined.LargeMemory;
import de.monticore.lang.monticar.cnnarch.predefined.EpisodicMemory;
import java.util.*;
@@ -56,6 +58,11 @@ public class LayerNameCreator {
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
}
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()){
for (ArchitectureElementSymbol subElement : subNetwork){
endStage = name(subElement, endStage, streamIndices);
}
}
return endStage;
}
@@ -127,6 +134,8 @@
return "fc";
} else if (layerDeclaration instanceof Pooling) {
return "pool";
} else if (layerDeclaration instanceof LargeMemory || layerDeclaration instanceof EpisodicMemory) {
return "memory";
} else {
return layerDeclaration.getName().toLowerCase();
}
......
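
LayerNameCreator changes in two places: composite elements now also recurse into their episodic subnetworks when assigning layer names, and LargeMemory and EpisodicMemory layers both map to the shared "memory" prefix. The sketch below shows, under simplified assumptions (plain strings instead of ArchitectureElementSymbol, no stream indices), how the stage counter threads through nested element lists so numbering stays contiguous:

    import java.util.ArrayList;
    import java.util.List;

    class NamingSketch {
        // Every element consumes one stage index; episodic subnetworks are walked
        // after the regular elements, so numbering continues across the whole
        // architecture instead of restarting per subnetwork.
        static int name(List<String> elements, List<List<String>> episodicSubNetworks,
                        int stage, List<String> out) {
            for (String element : elements) {
                stage++;
                out.add(element + stage);
            }
            for (List<String> subNetwork : episodicSubNetworks) {
                for (String element : subNetwork) {
                    stage++;
                    out.add(element + stage);
                }
            }
            return stage;
        }

        public static void main(String[] args) {
            List<String> names = new ArrayList<>();
            int endStage = name(List.of("conv", "pool"),
                                List.of(List.of("memory", "fc")), 0, names);
            System.out.println(names + " endStage=" + endStage); // [conv1, pool2, memory3, fc4] endStage=4
        }
    }
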
@@ -52,6 +52,8 @@ public abstract class TrainParamSupportChecker implements CNNTrainVisitor {
public void visit(ASTAdamOptimizer node){}
public void visit(ASTAdamWOptimizer node){}
public void visit(ASTRmsPropOptimizer node){}
public void visit(ASTAdaGradOptimizer node){}
......
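
TrainParamSupportChecker follows the CNNTrain visitor interface, with one visit method per optimizer AST node; this commit adds an empty visit for ASTAdaGradOptimizer next to the Adam, AdamW and RMSProp ones. Below is a generic double-dispatch sketch of that visitor shape; OptimizerNode, AdamNode and AdaGradNode are illustrative stand-ins, not the generated CNNTrain AST classes.

    // One visit overload per optimizer node; the defaults do nothing, so a
    // concrete checker only overrides the optimizers it cares about.
    interface OptimizerVisitor {
        default void visit(AdamNode node) {}
        default void visit(AdaGradNode node) {}
    }

    interface OptimizerNode {
        void accept(OptimizerVisitor visitor);
    }

    class AdamNode implements OptimizerNode {
        public void accept(OptimizerVisitor visitor) { visitor.visit(this); }
    }

    class AdaGradNode implements OptimizerNode {
        public void accept(OptimizerVisitor visitor) { visitor.visit(this); }
    }

    class VisitorDemo {
        public static void main(String[] args) {
            OptimizerVisitor checker = new OptimizerVisitor() {
                @Override
                public void visit(AdaGradNode node) { System.out.println("AdaGrad visited"); }
            };
            new AdaGradNode().accept(checker); // prints "AdaGrad visited"
        }
    }
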