Commit f2e2fb86 authored by Julian Treiber

added load_pretrained kwarg

parent 4e8f206e
@@ -15,7 +15,7 @@ class ${tc.fileNameWithoutEnding}:
         self.weight_initializer = mx.init.Normal()
         self.networks = {}

-    def load(self, context):
+    def load(self, context, load_pretrained=False, pretrained_files=None):
         earliestLastEpoch = None
         for i, network in self.networks.items():
@@ -222,6 +222,7 @@ class ${tc.fileNameWithoutEnding}:
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
               checkpoint_period=5,
+              load_pretrained=False,
               log_period=50,
               context='gpu',
               save_attention_image=False,
@@ -265,7 +266,7 @@ class ${tc.fileNameWithoutEnding}:
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self._net_creator.load(mx_context)
+            begin_epoch = self._net_creator.load(mx_context, load_pretrained=load_pretrained)
         else:
             if os.path.isdir(self._net_creator._model_dir_):
                 shutil.rmtree(self._net_creator._model_dir_)
@@ -296,8 +297,8 @@ class ${tc.fileNameWithoutEnding}:
         elif loss == 'cross_entropy':
             loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'dice_loss':
-            dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
-            loss_function = DiceLoss(axis=loss_axis, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
+            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
+            loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'l2':
             loss_function = mx.gluon.loss.L2Loss()
         elif loss == 'l1':
@@ -37,6 +37,9 @@ if __name__ == "__main__":
 <#if (config.logPeriod)??>
         log_period=${config.logPeriod},
 </#if>
+<#if (config.loadPretrained)??>
+        load_pretrained=${config.loadPretrained?string("True","False")},
+</#if>
 <#if (config.context)??>
         context='${config.context}',
 </#if>
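For readers of the generated code, the change simply threads a new keyword argument from the trainer's public train(...) signature through to the net creator's load(...). The snippet below is a minimal, self-contained sketch of that forwarding pattern; the class names (NetCreator, Trainer) are illustrative placeholders rather than the identifiers produced from ${tc.fileNameWithoutEnding}, and the body of load() only hints at the assumed pretrained-weight handling.

```python
# Minimal sketch of the forwarding pattern introduced in this commit.
# Names are placeholders; the generated templates use ${tc.fileNameWithoutEnding}-derived classes.

class NetCreator:
    def load(self, context, load_pretrained=False, pretrained_files=None):
        # In the generated code this restores checkpoint weights; with
        # load_pretrained=True it would additionally initialize from
        # pretrained parameter files (assumed behaviour for illustration).
        begin_epoch = 0
        if load_pretrained:
            print("would load pretrained weights from", pretrained_files, "on", context)
        return begin_epoch


class Trainer:
    def __init__(self, net_creator):
        self._net_creator = net_creator

    def train(self, load_checkpoint=True, load_pretrained=False, context='cpu'):
        begin_epoch = 0
        if load_checkpoint:
            # The new kwarg is forwarded unchanged, as in the hunk at line 265 above.
            begin_epoch = self._net_creator.load(context, load_pretrained=load_pretrained)
        return begin_epoch


if __name__ == "__main__":
    Trainer(NetCreator()).train(load_pretrained=True, context='gpu')
```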