Commit 6ff123d1 authored by Julian Dierkes

introduced mechanics to specify constraint distributions in CNNTrain

parent 68c9be95
Pipeline #225687 failed with stages in 28 seconds
......@@ -20,7 +20,7 @@
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.5</embedded-montiarc-math-opt-generator>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
......
......@@ -169,19 +169,10 @@ public class GluonConfigurationData extends ConfigurationData {
return getMultiParamEntry(NOISE_DISTRIBUTION, "name");
}
public String getImgResizeWidth() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_width")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_width").getValue());
public Map<String, Map<String, Object>> getConstraintDistributions() {
return getMultiParamMapEntry(CONSTRAINT_DISTRIBUTION, "name");
}
public String getImgResizeHeight() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_height")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_height").getValue());
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " +
" non reinforcement learning context";
......
......@@ -78,6 +78,8 @@ class ${tc.fileNameWithoutEnding}:
img_resize=(64,64),
noise_distribution='gaussian',
noise_distribution_params=(('mean_value', 0),('spread_value', 1),),
constraint_distributions={},
constraint_losses={},
preprocessing = False):
if context == 'gpu':
......@@ -164,7 +166,14 @@ class ${tc.fileNameWithoutEnding}:
if loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
activation_name = 'sigmoid'
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
loss_function = mx.gluon.loss.L1Loss()
elif loss == 'log_cosh':
loss_function = LogCoshLoss()
else:
logging.error("Invalid loss parameter.")
metric_dis = mx.metric.create(eval_metric)
metric_gen = mx.metric.create(eval_metric)
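The log_cosh branch relies on a LogCoshLoss class whose definition lies outside the hunks shown here. As a point of reference only, a minimal Gluon log-cosh loss (not necessarily the implementation used in this repository) could be sketched as follows, using the numerically stable identity log(cosh(x)) = x + softrelu(-2x) - log(2):

import math
import mxnet as mx

class LogCoshLoss(mx.gluon.loss.Loss):
    # Minimal sketch only; assumes the standard definition log(cosh(pred - label)).
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label):
        x = pred - label
        # log(cosh(x)) = x + log(1 + exp(-2x)) - log(2) = x + softrelu(-2x) - log(2)
        loss = x + F.Activation(-2 * x, act_type='softrelu') - math.log(2)
        return F.mean(loss, axis=self._batch_axis, exclude=True)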
......@@ -219,17 +228,15 @@ class ${tc.fileNameWithoutEnding}:
else:
if batch_i % speed_period == 0:
metric_dis = mx.metric.create(eval_metric)
discriminated = mx.nd.Concat(discriminated_real_dis.reshape((-1,1)), discriminated_fake_dis.reshape((-1,1)), dim=0)
labels = mx.nd.Concat(real_labels.reshape((-1,1)), fake_labels.reshape((-1,1)), dim=0)
discriminated = mx.ndarray.Activation(discriminated, activation_name)
discriminated = mx.nd.Concat(loss_resultD, loss_resultF, dim=0)
labels = mx.nd.Concat(real_labels, fake_labels, dim=0)
discriminated = mx.ndarray.floor(discriminated + 0.5)
metric_dis.update(preds=discriminated, labels=labels)
print("DisAcc: ", metric_dis.get()[1])
metric_gen = mx.metric.create(eval_metric)
discriminated = mx.ndarray.Activation(discriminated_fake_gen.reshape((-1,1)), activation_name)
discriminated = mx.ndarray.floor(discriminated + 0.5)
metric_gen.update(preds=discriminated, labels=real_labels.reshape((-1,1)))
discriminated = mx.ndarray.floor(loss_resultG + 0.5)
metric_gen.update(preds=discriminated, labels=real_labels)
print("GenAcc: ", metric_gen.get()[1])
try:
......
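The floor(x + 0.5) step rounds the concatenated discriminator outputs to hard 0/1 predictions before they are passed to the accuracy metric. A small standalone illustration with made-up values:

import mxnet as mx

outputs = mx.nd.array([0.10, 0.49, 0.51, 0.93])   # example discriminator outputs
preds = mx.nd.floor(outputs + 0.5)                # -> [0, 0, 1, 1]
labels = mx.nd.array([0, 0, 1, 1])

metric = mx.metric.create('accuracy')
metric.update(preds=preds, labels=labels)
print(metric.get())                               # ('accuracy', 1.0)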
......@@ -9,16 +9,33 @@
domain = gen_inputs[name]
min = domain[1]
max = domain[2]
if domain[0] == float:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.uniform(min,max,
if name[:-1] in constraint_distributions:
dist_dict = constraint_distributions[name[:-1]]
dist_name = dist_dict['name']
if dist_name == "gaussian":
generators[name] = lambda domain=domain, min=min, max=max, dist_dict=dist_dict: mx.nd.cast(mx.ndarray.random.normal(dist_dict["mean_value"],
dist_dict["spread_value"],
shape=(batch_size,)+domain[3], dtype=domain[0],
ctx=mx_context), dtype="float32")
else:
if domain[0] == float:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.uniform(min,max,
shape=(batch_size,)+domain[3],
dtype=domain[0], ctx=mx_context,), dtype="float32")
qnet_losses += [mx.gluon.loss.L2Loss()]
elif domain[0] == int:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.randint(low=int(min),
elif domain[0] == int:
generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.randint(low=int(min),
high=int(max)+1, shape=(batch_size,)+domain[3],
ctx=mx_context), dtype="float32")
qnet_losses += [lambda pred, labels: mx.gluon.loss.SoftmaxCrossEntropyLoss()(pred, labels.reshape(batch_size))]
if name[:-1] in constraint_losses:
pass  # placeholder: constraint-specific losses are not applied yet
else:
if domain[0] == float:
qnet_losses += [mx.gluon.loss.L2Loss()]
elif domain[0] == int:
qnet_losses += [lambda pred, labels: mx.gluon.loss.SoftmaxCrossEntropyLoss()(pred, labels.reshape(batch_size))]
for name in gen_inputs:
if not name in qnet_outputs:
......
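The lookup name[:-1] drops the last character of the generator input name (presumably a trailing index) so it can be matched against the constraint names configured in constraint_distributions. A reduced, self-contained sketch of the resulting sampling decision, where batch_size, mx_context, the example domain, and the distribution dict are placeholders chosen purely for illustration:

import mxnet as mx

batch_size = 8
mx_context = mx.cpu()
domain = (float, 0.0, 1.0, (1,))                  # (type, min, max, shape) placeholder
dist = {'name': 'gaussian', 'mean_value': 0, 'spread_value': 1}

if dist['name'] == 'gaussian':
    # constrained input with a configured distribution: sample from the gaussian
    sample = mx.nd.random.normal(dist['mean_value'], dist['spread_value'],
                                 shape=(batch_size,) + domain[3], ctx=mx_context)
else:
    # fallback: uniform sampling over the domain bounds, as before this commit
    sample = mx.nd.random.uniform(domain[1], domain[2],
                                  shape=(batch_size,) + domain[3], ctx=mx_context)

print(sample.shape)                               # (8, 1)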
......@@ -61,11 +61,6 @@ if __name__ == "__main__":
<#if (config.preprocessingName)??>
preprocessing=${config.preprocessingName???string("True","False")},
</#if>
<#if (config.imgResizeWidth)??>
<#if (config.imgResizeHeight)??>
img_resize=(${config.imgResizeWidth}, ${config.imgResizeHeight}),
</#if>
</#if>
<#if (config.evalMetric)??>
eval_metric='${config.evalMetric}',
</#if>
......@@ -87,6 +82,21 @@ if __name__ == "__main__":
</#list>
},
</#if>
<#if (config.constraintDistributions)??>
<#assign map = (config.constraintDistributions)>
constraint_distributions = {
<#list map?keys as nameKey>
'${nameKey}' : { 'name': '${map[nameKey].name}',
<#if (map[nameKey].mean_value)??>
'mean_value': ${map[nameKey].mean_value},
</#if>
<#if (map[nameKey].spread_value)??>
'spread_value': ${map[nameKey].spread_value}
</#if>
},
</#list>
},
</#if>
<#if (config.noiseDistribution)??>
noise_distribution = '${config.noiseDistribution.name}',
noise_distribution_params = {
......
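For a configuration that attaches a single gaussian constraint distribution (the key 'c1' is only an illustrative placeholder), the template block above would render roughly the following keyword argument into the generated training call:

constraint_distributions = {
    'c1' : { 'name': 'gaussian',
        'mean_value': 0,
        'spread_value': 1
    },
},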