From 488eb68c23898f46308f58b9299088c287b9380d Mon Sep 17 00:00:00 2001
From: Weihan Li <tjliweihan@gmail.com>
Date: Thu, 23 Dec 2021 10:38:49 +0100
Subject: [PATCH] Update MTL.py

---
 Modeling Codes/MTL.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/Modeling Codes/MTL.py b/Modeling Codes/MTL.py
index 0fbb076..b78fdbf 100644
--- a/Modeling Codes/MTL.py	
+++ b/Modeling Codes/MTL.py	
@@ -174,6 +174,8 @@ model.layers[19].trainable=True
 model.layers[21].trainable=True
 model.layers[23].trainable=True
 #train model for stage 1
+#sometimes need load checkpoint because the local best solution may not be the best solution for remaining stages.
+#model.load_weights('capir/weight0_best.hdf5')
 model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),loss=CustomLoss.MaskedMAE,metrics=[],loss_weights=([0.0,1.0]))
 model.fit(x0,y0,batch_size=384,epochs=450,verbose=2,validation_data=(x1,y1),shuffle=True,callbacks=[cp_callback1,callback2])
 model.save("2CapIR2_stg1.h5")
@@ -181,9 +183,11 @@ model.save("2CapIR2_stg1.h5")
 for idx in range(len(model.layers)):
     model.layers[idx ].trainable=True
 #train model for stage 2
+#model.load_weights('capir/weight1_best.hdf5')
 model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-5),loss=CustomLoss.MaskedMAE,metrics=[],loss_weights=([1.0,1.0]))
+#model.load_weights('capir/weight2_best.hdf5')
 model.fit(x0,y0,batch_size=512,epochs=300,verbose=1,validation_data=(x1,y1),shuffle=True,callbacks=[cp_callback2,callback2])
 
 #save model
 model.save("2CapIR2.h5")
-tf.keras.utils.plot_model(model,"2CapIR2.png",show_shapes=True,expand_nested=True)
\ No newline at end of file
+tf.keras.utils.plot_model(model,"2CapIR2.png",show_shapes=True,expand_nested=True)
-- 
GitLab