diff --git a/src/test/resources/target_code/CNNCreator_LeNet.py b/src/test/resources/target_code/CNNCreator_LeNet.py index 19b097b19728a182fe646fce43070037e263dd8e..453ed8c8dd622fe36898b87c7b4f85ae05cbe315 100644 --- a/src/test/resources/target_code/CNNCreator_LeNet.py +++ b/src/test/resources/target_code/CNNCreator_LeNet.py @@ -141,7 +141,7 @@ class CNNCreator_LeNet: train_model= model_helper.ModelHelper(name="train_net", arg_scope=arg_scope) data, label, train_dataset_size = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'train_lmdb'), db_type='lmdb', device_opts=device_opts) predictions = self.create_model(train_model, data, device_opts=device_opts, is_test=False) - self.add_training_operators(train_model, predictions, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) + self.add_training_operators(train_model, predictions, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) self.add_accuracy(train_model, predictions, label, device_opts, eval_metric) with core.DeviceScope(device_opts): brew.add_weight_decay(train_model, weight_decay) diff --git a/src/test/resources/target_code/CNNCreator_VGG16.py b/src/test/resources/target_code/CNNCreator_VGG16.py index fb69eec6416d1891e91137fb37c3ca0de92c277c..f7a72517a2aece0625d100aeefa2dd7dada9969c 100644 --- a/src/test/resources/target_code/CNNCreator_VGG16.py +++ b/src/test/resources/target_code/CNNCreator_VGG16.py @@ -187,7 +187,7 @@ class CNNCreator_VGG16: train_model= model_helper.ModelHelper(name="train_net", arg_scope=arg_scope) data, label, train_dataset_size = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'train_lmdb'), db_type='lmdb', device_opts=device_opts) predictions = self.create_model(train_model, data, device_opts=device_opts, is_test=False) - self.add_training_operators(train_model, predictions, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) + self.add_training_operators(train_model, predictions, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) self.add_accuracy(train_model, predictions, label, device_opts, eval_metric) with core.DeviceScope(device_opts): brew.add_weight_decay(train_model, weight_decay)