Skip to content
Snippets Groups Projects
Commit d310c33f authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] callbacks: adds CheckpointModel callback

parent 3d1dff56
No related branches found
No related tags found
No related merge requests found
......@@ -562,6 +562,21 @@ class TensorBoard(tf.keras.callbacks.TensorBoard):
pass
class CheckpointModel(tf.keras.callbacks.Callback):
def __init__(self, savedir="tmp", frequency=1, identifier="cp"):
pin(locals())
def get_index(self, epoch):
return epoch
def checkpoint_dir(self, epoch):
return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}"
def on_epoch_end(self, epoch, logs=None):
if epoch != 0 and epoch % self.frequency == 0:
self.model.save(self.checkpoint_dir(epoch))
class BestTracker(tf.keras.callbacks.Callback):
def __init__(
self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment