Commit f2e2fb86 authored by Julian Treiber

added load_pretrained kwarg

parent 4e8f206e
@@ -15,7 +15,7 @@ class ${tc.fileNameWithoutEnding}:
         self.weight_initializer = mx.init.Normal()
         self.networks = {}

-    def load(self, context):
+    def load(self, context, load_pretrained=False, pretrained_files=None):
         earliestLastEpoch = None
         for i, network in self.networks.items():
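The body of the generated load() is not shown in this hunk. The sketch below is one way the two new kwargs could be honored, assuming pretrained_files maps each network index to a .params file; only the signature is taken from the diff, the fallback logic is illustrative, not the generator's actual code.

# Illustrative only -- not the generated method body. Assumes Gluon networks
# and that pretrained_files[i] points at a .params file for network i.
import os

def load(self, context, load_pretrained=False, pretrained_files=None):
    earliestLastEpoch = None
    for i, network in self.networks.items():
        # ... existing checkpoint lookup sets earliestLastEpoch ...
        if load_pretrained and pretrained_files is not None:
            param_file = pretrained_files[i]
            if os.path.isfile(param_file):
                # allow_missing/ignore_extra tolerate layers that do not
                # line up between the pretrained net and the current one
                network.load_parameters(param_file, ctx=context,
                                        allow_missing=True, ignore_extra=True)
    return earliestLastEpoch if earliestLastEpoch is not None else 0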
@@ -222,6 +222,7 @@ class ${tc.fileNameWithoutEnding}:
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
               checkpoint_period=5,
+              load_pretrained=False,
               log_period=50,
               context='gpu',
               save_attention_image=False,
@@ -265,7 +266,7 @@ class ${tc.fileNameWithoutEnding}:
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self._net_creator.load(mx_context)
+            begin_epoch = self._net_creator.load(mx_context, load_pretrained=load_pretrained)
         else:
             if os.path.isdir(self._net_creator._model_dir_):
                 shutil.rmtree(self._net_creator._model_dir_)
@@ -296,8 +297,8 @@ class ${tc.fileNameWithoutEnding}:
         elif loss == 'cross_entropy':
             loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'dice_loss':
-            dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
-            loss_function = DiceLoss(axis=loss_axis, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
+            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
+            loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'l2':
             loss_function = mx.gluon.loss.L2Loss()
         elif loss == 'l1':
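This hunk renames the dice-specific key to the generic loss_weight, so the weighting is configured under the same key as for other losses. A hedged usage example with illustrative values (DiceLoss is the trainer's own class; its constructor signature is taken from the lines above):

# Caller side after the rename; the old key 'dice_weight' is no longer read.
loss_params = {'loss_weight': 0.7}

# Equivalent to the conditional expression in the diff:
loss_weight = loss_params.get('loss_weight')  # None if the key is absent
loss_function = DiceLoss(axis=-1, weight=loss_weight,
                         sparse_label=True, batch_axis=0)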
@@ -37,6 +37,9 @@ if __name__ == "__main__":
 <#if (config.logPeriod)??>
         log_period=${config.logPeriod},
 </#if>
+<#if (config.loadPretrained)??>
+        load_pretrained=${config.loadPretrained?string("True","False")},
+</#if>
 <#if (config.context)??>
         context='${config.context}',
 </#if>
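When a training configuration sets loadPretrained, the new <#if> block emits the flag into the generated trainer script; ?string("True","False") renders the boolean as an unquoted Python literal. A sketch of the emitted call, with the trainer variable name and the other arguments assumed for illustration:

# Sketch of generated output, not verbatim generator output:
if __name__ == "__main__":
    # ... data loading and trainer construction as before ...
    trainer.train(
        log_period=50,
        load_pretrained=True,   # from the new <#if (config.loadPretrained)??> block
        context='gpu',
    )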