diff --git a/evaluation/sample.py b/evaluation/sample.py index a7a6e3b42ae5681159bff338ac41a2bf25ff56f1..4d98deb349a554754a0b51dfab1b75cb8606427b 100644 --- a/evaluation/sample.py +++ b/evaluation/sample.py @@ -75,6 +75,8 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, y = torch.randint(0, 3, (batch_size,)).to(device) # generate images generated = model.sample(y=y, batch_size=y.size(0)) + # clip the values to between -1 and 1 + generated = generated.clamp(-1, 1) # save images for i in range(generated.size(0)): image = back2pil(generated[i])