diff --git a/src/main/resources/templates/caffe2/CNNCreator.ftl b/src/main/resources/templates/caffe2/CNNCreator.ftl index 834fbf8a08afbfc4c656b9b34c54470b395ace4e..941a27a3ad0442c20c90b810a746d313a62253a8 100644 --- a/src/main/resources/templates/caffe2/CNNCreator.ftl +++ b/src/main/resources/templates/caffe2/CNNCreator.ftl @@ -105,8 +105,9 @@ class ${tc.fileNameWithoutEnding}: sys.exit(1) - def train(self, batch_size, + def train(self, batch_size=64, num_epoch=10, + eval_metric='acc', optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, @@ -158,6 +159,7 @@ class ${tc.fileNameWithoutEnding}: self.module.fit( train_data=train_iter, + eval_metric=eval_metric, eval_data=test_iter, optimizer=optimizer, optimizer_params=optimizer_params, diff --git a/src/test/resources/target_code/CNNCreator_Alexnet.py b/src/test/resources/target_code/CNNCreator_Alexnet.py index 678b44c2bc1194dbc7c265e0f45a2a19419af7c5..c987bef75be2184613129fafe418dc7bc842fe17 100644 --- a/src/test/resources/target_code/CNNCreator_Alexnet.py +++ b/src/test/resources/target_code/CNNCreator_Alexnet.py @@ -105,8 +105,9 @@ class CNNCreator_Alexnet: sys.exit(1) - def train(self, batch_size, + def train(self, batch_size=64, num_epoch=10, + eval_metric='acc', optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, @@ -158,6 +159,7 @@ class CNNCreator_Alexnet: self.module.fit( train_data=train_iter, + eval_metric=eval_metric, eval_data=test_iter, optimizer=optimizer, optimizer_params=optimizer_params, diff --git a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py index 6dd81feb3ebfd7f5b8bf13cc56cbefd91d54c338..91a57704254e22d002193b26dce8fa0f0cc9f105 100644 --- a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py +++ b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py @@ -105,8 +105,9 @@ class CNNCreator_CifarClassifierNetwork: sys.exit(1) - def train(self, batch_size, + def train(self, batch_size=64, num_epoch=10, + eval_metric='acc', optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, @@ -158,6 +159,7 @@ class CNNCreator_CifarClassifierNetwork: self.module.fit( train_data=train_iter, + eval_metric=eval_metric, eval_data=test_iter, optimizer=optimizer, optimizer_params=optimizer_params, diff --git a/src/test/resources/target_code/CNNCreator_VGG16.py b/src/test/resources/target_code/CNNCreator_VGG16.py index ba39f3a2365c90ae5c33c9ea4a619263311cb230..f27d5dbc57148025db17a072f7734d439b682e34 100644 --- a/src/test/resources/target_code/CNNCreator_VGG16.py +++ b/src/test/resources/target_code/CNNCreator_VGG16.py @@ -105,8 +105,9 @@ class CNNCreator_VGG16: sys.exit(1) - def train(self, batch_size, + def train(self, batch_size=64, num_epoch=10, + eval_metric='acc', optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, @@ -158,6 +159,7 @@ class CNNCreator_VGG16: self.module.fit( train_data=train_iter, + eval_metric=eval_metric, eval_data=test_iter, optimizer=optimizer, optimizer_params=optimizer_params,