Inherit get-functions for Training parameters

parent 23238012
Pipeline #66682 failed with stages
in 1 minute and 24 seconds
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.2-SNAPSHOT</CNNArch.version> <CNNArch.version>0.2.2-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.1-SNAPSHOT</CNNTrain.version> <CNNTrain.version>0.2.2-SNAPSHOT</CNNTrain.version>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
......
...@@ -26,26 +26,51 @@ public class ConfigurationData { ...@@ -26,26 +26,51 @@ public class ConfigurationData {
} }
public String getNumEpoch() { public String getNumEpoch() {
return String.valueOf(getConfiguration().getNumEpoch().getValue()); if (!getConfiguration().getEntryMap().containsKey("num_epoch")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("num_epoch").getValue());
} }
public String getBatchSize() { public String getBatchSize() {
return String.valueOf(getConfiguration().getBatchSize().getValue()); if (!getConfiguration().getEntryMap().containsKey("batch_size")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("batch_size") .getValue());
} }
public LoadCheckpointSymbol getLoadCheckpoint() { public Boolean getLoadCheckpoint() {
return getConfiguration().getLoadCheckpoint(); if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) {
return null;
}
return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue();
} }
public NormalizeSymbol getNormalize() { public Boolean getNormalize() {
return getConfiguration().getNormalize(); if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
}
return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue();
} }
public TrainContextSymbol getContext() { public String getContext() {
return getConfiguration().getTrainContext(); if (!getConfiguration().getEntryMap().containsKey("context")) {
return null;
}
return getConfiguration().getEntry("context").getValue().toString();
}
public String getEvalMetric() {
if (!getConfiguration().getEntryMap().containsKey("eval_metric")) {
return null;
}
return getConfiguration().getEntry("eval_metric").getValue().toString();
} }
public String getOptimizerName() { public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
}
return getConfiguration().getOptimizer().getName(); return getConfiguration().getOptimizer().getName();
} }
......
...@@ -17,13 +17,13 @@ if __name__ == "__main__": ...@@ -17,13 +17,13 @@ if __name__ == "__main__":
batch_size = ${config.batchSize}, batch_size = ${config.batchSize},
</#if> </#if>
<#if (config.loadCheckpoint)??> <#if (config.loadCheckpoint)??>
load_checkpoint = ${config.loadCheckpoint.value?string("True","False")}, load_checkpoint = ${config.loadCheckpoint?string("True","False")},
</#if> </#if>
<#if (config.context)??> <#if (config.context)??>
context = '${config.context.value}', context = '${config.context}',
</#if> </#if>
<#if (config.normalize)??> <#if (config.normalize)??>
normalize = ${config.normalize.value?string("True","False")}, normalize = ${config.normalize?string("True","False")},
</#if> </#if>
<#if (config.configuration.optimizer)??> <#if (config.configuration.optimizer)??>
optimizer = '${config.optimizerName}', optimizer = '${config.optimizerName}',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment