@@ -337,7 +343,10 @@ class ${tc.fileNameWithoutEnding}:
if not os.path.isdir(self._net_creator._model_dir_):
raise
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
if optimizer == "adamw":
trainers = [mx.gluon.Trainer(network.collect_params(), AdamW.AdamW(**optimizer_params)) for network in self._networks.values() if len(network.collect_params().values()) != 0]
else:
trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
...
...
@@ -408,7 +417,7 @@ class ${tc.fileNameWithoutEnding}:
layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]