Commit 931b3386 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Merge develop into master

See merge request !8
parents b14b8540 7f595808
Pipeline #246263 canceled with stages
......@@ -19,7 +19,7 @@
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
......@@ -167,6 +167,10 @@ public class ArchitectureElementData {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
}
public double getDropout() {
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.RNN_DROPOUT_NAME).get();
}
public boolean getFlatten() {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.FLATTEN_PARAMETER_NAME).get();
}
......@@ -178,7 +182,7 @@ public class ArchitectureElementData {
@Nullable
public List<Integer> getPadding(){
String pad = ((LayerSymbol) getElement()).getStringValue(AllPredefinedLayers.PADDING_NAME).get();
if(pad.equals("same")){
......@@ -213,4 +217,48 @@ public class ArchitectureElementData {
return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
}
@Nullable
public List<Integer> getTransPadding(){
String pad = ((LayerSymbol) getElement()).getStringValue(AllPredefinedLayers.TRANSPADDING_NAME).get();
if(pad.equals("same")){
return getTransPadding(getLayerSymbol()); //The padding calculated here is only used in the gluon/ mxnet backend, in the tensorlflow one it is interpreted as "same"
}else if(pad.equals("valid")){
return Arrays.asList(0,0);
}else{ //"no loss"
return Arrays.asList(0,0,-1,0,0,0,0,0);
}
}
@Nullable
public List<Integer> getTransPadding(LayerSymbol layer) {
List<Integer> kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
List<Integer> stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
ArchTypeSymbol outputType = layer.getOutputTypes().get(0);
int heightPad = kernel.get(0) - stride.get(0);
int widthPad = kernel.get(1) - stride.get(1);
int topPad = (int) Math.ceil(heightPad / 2.0);
int bottomPad = (int) Math.floor(heightPad / 2.0);
int leftPad = (int) Math.ceil(widthPad / 2.0);
int rightPad = (int) Math.floor(widthPad / 2.0);
/*if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){
return null;
}*/
return Arrays.asList(bottomPad, rightPad);
}
/*public boolean getStart() {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.START_NAME).get();
}
public boolean getEnd() {
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.END_NAME).get();
}*/
}
\ No newline at end of file
......@@ -2,6 +2,8 @@
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import jline.internal.Log;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.ArrayList;
......@@ -48,6 +50,20 @@ public class ConfigurationData {
return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue();
}
public String getCheckpointPeriod() {
if (!getConfiguration().getEntryMap().containsKey("checkpoint_period")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("checkpoint_period").getValue());
}
public String getLogPeriod() {
if (!getConfiguration().getEntryMap().containsKey("log_period")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("log_period").getValue());
}
public Boolean getNormalize() {
if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
......@@ -55,6 +71,27 @@ public class ConfigurationData {
return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue();
}
public Boolean getShuffleData() {
if (!getConfiguration().getEntryMap().containsKey("shuffle_data")) {
return null;
}
return (Boolean) getConfiguration().getEntry("shuffle_data").getValue().getValue();
}
public String getClipGlobalGradNorm() {
if (!getConfiguration().getEntryMap().containsKey("clip_global_grad_norm")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("clip_global_grad_norm").getValue());
}
public String getPreprocessingName() {
if (!getConfiguration().getEntryMap().containsKey("preprocessing_name")) {
return null;
}
return (String) getConfiguration().getEntry("preprocessing_name").getValue().toString();
}
public String getContext() {
if (!getConfiguration().getEntryMap().containsKey("context")) {
return null;
......@@ -143,6 +180,34 @@ public class ConfigurationData {
return (Boolean) getConfiguration().getEntry("use_teacher_forcing").getValue().getValue();
}
public Boolean getEvalTrain() {
if (!getConfiguration().getEntryMap().containsKey("eval_train")) {
return null;
}
return (Boolean) getConfiguration().getEntry("eval_train").getValue().getValue();
}
protected Map<String, Map<String, Object>> getMultiParamMapEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
}
Map<String, Map<String,Object>> resultView = new HashMap<>();
ValueSymbol value = this.getConfiguration().getEntryMap().get(key).getValue();
if (value instanceof MultiParamValueMapSymbol) {
MultiParamValueMapSymbol multiParamValueMap = (MultiParamValueMapSymbol) value;
resultView.putAll(multiParamValueMap.getParameters());
Map<String,String> names = multiParamValueMap.getMultiParamValueNames();
for(String distrName : names.keySet()) {
resultView.get(distrName).put(valueName, names.get(distrName));
}
}
return resultView;
}
protected Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment