Skip to content
Snippets Groups Projects
Commit 488eb68c authored by Weihan Li's avatar Weihan Li
Browse files

Update MTL.py

parent 02990287
No related branches found
No related tags found
No related merge requests found
......@@ -174,6 +174,8 @@ model.layers[19].trainable=True
model.layers[21].trainable=True
model.layers[23].trainable=True
#train model for stage 1
#sometimes need load checkpoint because the local best solution may not be the best solution for remaining stages.
#model.load_weights('capir/weight0_best.hdf5')
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),loss=CustomLoss.MaskedMAE,metrics=[],loss_weights=([0.0,1.0]))
model.fit(x0,y0,batch_size=384,epochs=450,verbose=2,validation_data=(x1,y1),shuffle=True,callbacks=[cp_callback1,callback2])
model.save("2CapIR2_stg1.h5")
......@@ -181,9 +183,11 @@ model.save("2CapIR2_stg1.h5")
for idx in range(len(model.layers)):
model.layers[idx ].trainable=True
#train model for stage 2
#model.load_weights('capir/weight1_best.hdf5')
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-5),loss=CustomLoss.MaskedMAE,metrics=[],loss_weights=([1.0,1.0]))
#model.load_weights('capir/weight2_best.hdf5')
model.fit(x0,y0,batch_size=512,epochs=300,verbose=1,validation_data=(x1,y1),shuffle=True,callbacks=[cp_callback2,callback2])
#save model
model.save("2CapIR2.h5")
tf.keras.utils.plot_model(model,"2CapIR2.png",show_shapes=True,expand_nested=True)
\ No newline at end of file
tf.keras.utils.plot_model(model,"2CapIR2.png",show_shapes=True,expand_nested=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment