fixes

parent 23865315
Pipeline #344500 passed with stage
in 7 minutes and 57 seconds
......@@ -33,7 +33,7 @@
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<SIUnit.version>0.0.11</SIUnit.version>
<Common-MontiCar.version>0.0.19-SNAPSHOT</Common-MontiCar.version>
<Math.version>0.0.20-SNAPSHOT</Math.version>
<Math.version>0.2.12-SNAPSHOT</Math.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
......@@ -32,9 +32,9 @@ public class CheckLargeMemoryLayer extends CNNArchSymbolCoCo {
Integer k = new Integer(0);
for (ArgumentSymbol arg : arguments) {
if (arg.getName().equals("subKeySize")) {
if (arg.getName().equals("subKeySize") && arg.getRhs().getIntValue().isPresent()) {
subKeySize = arg.getRhs().getIntValue().get();
} else if (arg.getName().equals("k")) {
} else if (arg.getName().equals("k") && arg.getRhs().getIntValue().isPresent()) {
k = arg.getRhs().getIntValue().get();
}
}
......
......@@ -229,6 +229,30 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<List<Integer>> getIntOrIntTupleValues(){
Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){
List<Integer> list = new ArrayList<>();
for (Object value : optValue.get()) {
if (value instanceof Integer){
list.add((Integer) value);
}
else {
return Optional.empty();
}
}
return Optional.of(list);
}else{
List<Integer> list = new ArrayList<>();
Optional<Integer> optValueInt = getIntValue();
if (optValueInt.isPresent()){
list.add(optValueInt.get());
return Optional.of(list);
}
}
return Optional.empty();
}
public Optional<List<Double>> getDoubleTupleValues() {
Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){
......
......@@ -260,31 +260,47 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
List<ArchitectureElementSymbol> elementsNew = new ArrayList<>();
List<List<ArchitectureElementSymbol>> episodicSubNetworks = new ArrayList<>(new ArrayList<>());
List<ArchitectureElementSymbol> currentEpisodicSubNetworkElements = new ArrayList<>();
boolean anyEpisodicLocalAdaption = false;
for (ArchitectureElementSymbol element : elements){
if (AllPredefinedLayers.EPISODIC_REPLAY_LAYER_NAMES.contains(element.getName())) {
boolean use_replay = false;
boolean use_local_adaption = false;
boolean use_replay_specified = false;
boolean use_local_adaption_specified = false;
for (ArgumentSymbol arg : ((LayerSymbol)element).getArguments()){
if (arg.getName().equals(AllPredefinedLayers.USE_REPLAY_NAME) && (boolean)arg.getRhs().getValue().get()){
use_replay = true;
break;
}else if (arg.getName().equals(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME) && (boolean)arg.getRhs().getValue().get()){
use_local_adaption = true;
break;
if (arg.getName().equals(AllPredefinedLayers.USE_REPLAY_NAME)){
use_replay_specified = true;
if ((boolean)arg.getRhs().getValue().get()) {
use_replay = true;
break;
}
}else if (arg.getName().equals(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME)){
use_local_adaption_specified = true;
if ((boolean)arg.getRhs().getValue().get()) {
use_local_adaption = true;
anyEpisodicLocalAdaption = true;
break;
}
}
}
if (!use_replay && !use_local_adaption) {
if (!use_replay_specified) {
for (ParameterSymbol param : ((LayerSymbol) element).getDeclaration().getParameters()) {
if (param.getName().equals(AllPredefinedLayers.USE_REPLAY_NAME) &&
(boolean) param.getDefaultExpression().get().getValue().get()) {
use_replay = true;
break;
} else if (param.getName().equals(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME) &&
}
}
}
if (!use_local_adaption_specified) {
for (ParameterSymbol param : ((LayerSymbol) element).getDeclaration().getParameters()) {
if (param.getName().equals(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME) &&
(boolean) param.getDefaultExpression().get().getValue().get()) {
use_local_adaption = true;
anyEpisodicLocalAdaption = true;
break;
}
}
......@@ -303,6 +319,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
episodicSubNetworks.add(currentEpisodicSubNetworkElements);
}
networkInstruction.getBody().setEpisodicSubNetworks(episodicSubNetworks);
networkInstruction.getBody().setAnyEpisodicLocalAdaption(anyEpisodicLocalAdaption);
}
}
}
......@@ -269,26 +269,6 @@ public enum Constraints {
+ AllPredefinedLayers.MEMORY_ACTIVATION_SOFTSIGN;
}
},
DIST_MEASURE_TYPE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
Optional<String> optString= exp.getStringValue();
if (optString.isPresent()){
if (optString.get().equals(AllPredefinedLayers.L2)
|| optString.get().equals(AllPredefinedLayers.INNER_PROD)){
return true;
}
}
return false;
}
@Override
protected String msgString() {
return AllPredefinedLayers.L2 + " or "
+ AllPredefinedLayers.INNER_PROD + "or"
+ AllPredefinedLayers.RANDOM;
}
},
MEMORY_REPLACEMENT_STRATEGY_TYPE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......
......@@ -283,6 +283,10 @@ public class LayerSymbol extends ArchitectureElementSymbol {
return getTValue(parameterName, ArchExpressionSymbol::getIntTupleValues);
}
public Optional<List<Integer>> getIntOrIntTupleValues(String parameterName){
return getTValue(parameterName, ArchExpressionSymbol::getIntOrIntTupleValues);
}
public Optional<Boolean> getBooleanValue(String parameterName){
return getTValue(parameterName, ArchExpressionSymbol::getBooleanValue);
}
......
......@@ -13,6 +13,7 @@ import java.util.*;
public class SerialCompositeElementSymbol extends CompositeElementSymbol {
protected List<List<ArchitectureElementSymbol>> episodicSubNetworks = new ArrayList<>(new ArrayList<>());
protected boolean anyEpisodicLocalAdaption = false;
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
......@@ -60,6 +61,10 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
return episodicSubNetworks;
}
protected void setAnyEpisodicLocalAdaption(boolean isUsed){ anyEpisodicLocalAdaption = isUsed; }
public boolean getAnyEpisodicLocalAdaption(){ return anyEpisodicLocalAdaption; }
@Override
public void setInputElement(ArchitectureElementSymbol inputElement) {
super.setInputElement(inputElement);
......
......@@ -31,6 +31,7 @@ public class AllPredefinedLayers {
public static final String GLOBAL_POOLING_NAME = "GlobalPooling";
public static final String LRN_NAME = "Lrn";
public static final String BATCHNORM_NAME = "BatchNorm";
public static final String LAYERNORM_NAME = "LayerNorm";
public static final String SPLIT_NAME = "Split";
public static final String GET_NAME = "Get";
public static final String ADD_NAME = "Add";
......@@ -126,7 +127,6 @@ public class AllPredefinedLayers {
public static final String QUERY_ACT_NAME = "queryAct";
public static final String K_NAME = "k";
public static final String NUM_HEADS_NAME = "numHeads";
public static final String STORE_DIST_MEASURE_NAME = "storeDistMeasure";
public static final String VALUES_DIM_NAME = "valuesDim";
public static final String MEMORY_REPLACEMENT_STRATEGY_NAME = "memoryReplacementStrategy";
......@@ -143,8 +143,6 @@ public class AllPredefinedLayers {
public static final String PADDING_NO_LOSS = "no_loss";
public static final String POOL_MAX = "max";
public static final String POOL_AVG = "avg";
public static final String L2 = "l2";
public static final String INNER_PROD = "inner_prod";
public static final String RANDOM = "random";
public static final String REPLACE_OLDEST = "replace_oldest";
public static final String NO_REPLACEMENT = "no_replacement";
......@@ -174,6 +172,7 @@ public class AllPredefinedLayers {
GlobalPooling.create(),
Lrn.create(),
BatchNorm.create(),
LayerNorm.create(),
Split.create(),
Get.create(),
Add.create(),
......
......@@ -50,7 +50,7 @@ public class DotProductSelfAttention extends PredefinedLayerDeclaration {
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SCALE_FACTOR_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.constraints(Constraints.POSITIVE)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
......
......@@ -80,7 +80,7 @@ public class EpisodicMemory extends PredefinedLayerDeclaration {
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME)
.constraints(Constraints.BOOLEAN, Constraints.POSITIVE)
.constraints(Constraints.BOOLEAN)
.defaultValue(true)
.build(),
new ParameterSymbol.Builder()
......
......@@ -60,11 +60,6 @@ public class LargeMemory extends PredefinedLayerDeclaration {
public static LargeMemory create(){
LargeMemory declaration = new LargeMemory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.STORE_DIST_MEASURE_NAME)
.constraints(Constraints.DIST_MEASURE_TYPE)
.defaultValue(AllPredefinedLayers.INNER_PROD)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SUB_KEY_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class LayerNorm extends PredefinedLayerDeclaration {
private LayerNorm() {
super(AllPredefinedLayers.LAYERNORM_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return layer.getInputTypes();
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static LayerNorm create(){
LayerNorm declaration = new LayerNorm();
declaration.setParameters(new ArrayList<>());
return declaration;
}
}
......@@ -24,7 +24,7 @@ public class LoadNetwork extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
Optional<List<Integer>> optValue = layer.getIntTupleValue(AllPredefinedLayers.OUTPUT_SHAPE_NAME);
Optional<List<Integer>> optValue = layer.getIntOrIntTupleValues(AllPredefinedLayers.OUTPUT_SHAPE_NAME);
List<Integer> shapeList = Arrays.asList(1, 1, 1);
......@@ -65,7 +65,7 @@ public class LoadNetwork extends PredefinedLayerDeclaration {
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.OUTPUT_SHAPE_NAME)
.constraints(Constraints.INTEGER_TUPLE)
.constraints(Constraints.INTEGER_OR_INTEGER_TUPLE)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment