Changes for ReplayMemory and minor changes

parent 7272fe25
Pipeline #276486 failed with stage
in 41 seconds
......@@ -11,8 +11,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.*;
public class ArchitectureElementData {
......@@ -188,12 +187,18 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntValue(AllPredefinedLayers.SUB_KEY_SIZE_NAME).get();
}
public int getQuerySize(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
public List<Integer> getQuerySize(){
if (getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).isPresent()){
List<Integer> list = new ArrayList<>();
list.add((Integer) getLayerSymbol().getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get());
return list;
}else{
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
}
}
public String getActQuery(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.ACT_QUERY_NAME).get();
public String getQueryAct(){
return getLayerSymbol().getStringValue(AllPredefinedLayers.QUERY_ACT_NAME).get();
}
public int getK(){
......@@ -204,6 +209,40 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_HEADS_NAME).get();
}
public int getReplayInterval(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_INTERVAL_NAME).get();
}
public int getReplayBatchSize(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_BATCH_SIZE_NAME).get();
}
public int getReplaySteps(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_STEPS_NAME).get();
}
public int getReplayGradientSteps(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME).get();
}
public double getStoreProb(){
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.STORE_PROB_NAME).get();
}
public int getMaxStoredSamples(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME).get();
}
public List<Integer> getValueShape(){
if (getLayerSymbol().getIntValue(AllPredefinedLayers.VALUE_SHAPE_NAME).isPresent()){
List<Integer> list = new ArrayList<>();
list.add((Integer) getLayerSymbol().getIntValue(AllPredefinedLayers.VALUE_SHAPE_NAME).get());
return list;
}else{
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.VALUE_SHAPE_NAME).get();
}
}
@Nullable
public List<Integer> getPadding(){
......
......@@ -56,6 +56,11 @@ public class LayerNameCreator {
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
}
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getReplaySubNetworks()){
for (ArchitectureElementSymbol subElement : subNetwork){
endStage = name(subElement, endStage, streamIndices);
}
}
return endStage;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment