diff --git a/keras.py b/keras.py index dd72a76e1fe8317339c44d4c7a8a40b6aeae513f..2c0064b9016d895fcbc4d364eca6ad1dee3a2f5a 100644 --- a/keras.py +++ b/keras.py @@ -588,18 +588,15 @@ class TensorBoard(tf.keras.callbacks.TensorBoard): class CheckpointModel(tf.keras.callbacks.Callback): - def __init__(self, savedir="tmp", frequency=1, identifier="cp"): - pin(locals()) - - def get_index(self, epoch): - return epoch + """Gets a dict of targets (checkpoints), if current epoch is in dict, save model to target.""" - def checkpoint_dir(self, epoch): - return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}" + def __init__(self, checkpoints=None): + self.targets = checkpoints.targets def on_epoch_end(self, epoch, logs=None): - if epoch % self.frequency == 0: - self.model.save(self.checkpoint_dir(epoch)) + if epoch in self.targets: + target = self.targets[epoch] + self.model.save(target.path) class BestTracker(tf.keras.callbacks.Callback):