sample.py
import os
import re

import torch
from torchvision import transforms
    
def ldm_sampler(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15, sample_all=False, n_times=1):
    '''
    Samples a tensor of 'batch_size' images from a trained diffusion model loaded from 'checkpoint'. The
    generated images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}', where e is
    the epoch of the model being sampled from and j is an index separating the images of each call of the
    sampling function for the given model.

    model:           Diffusion model
    checkpoint:      Name of the saved .pth file containing the trained weights and biases
    experiment_path: Path to the experiment directory; samples are saved under its 'samples' subdirectory
    dataloader:      Dataloader yielding the conditioning inputs from which the samples are generated
    device:          Device on which the sampling runs
    intermediate:    Bool value. If False the sampling function draws a batch of images; otherwise it
                     samples a single image but stores all intermediate noised latents along the reverse chain
    batch_size:      The number of images to sample per batch
    sample_all:      If True, samples a batch of images for every stored checkpoint of the given model
    n_times:         Integer denoting how many batches of 'batch_size' to draw, e.g. to draw 10k images
                     with a batch size of 512 the GPU draws 20 batches.
    '''
    
    # If we want to sample from every checkpoint of the current model, recursively call this function for each checkpoint
    if sample_all:
        checkpoint_dir = f'{experiment_path}trained_ldm/'
        checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(checkpoint_dir) if checkpoint_i.endswith(".pth")]
        for checkpoint_i in checkpoint_list:
            ldm_sampler(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, sample_all=False)
        return 0
    
    # load model
    try:
        checkpoint_path = f'{experiment_path}trained_ldm/{checkpoint}'
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # load weights and biases of the U-Net
        net_state_dict = checkpoint['model']
        model.net.load_state_dict(net_state_dict)
        model = model.to(device)
        model.eval()
    except Exception as e:
        print("Error loading checkpoint. Exception:", e)
        raise  # sampling with unloaded weights would be meaningless

    # move the VQ model, which decodes the sampled latents, to the same device
    model.vq_model.to(device)
    
    # create samples directory for the complete experiment (if first time sampling images)
    output_dir = f'{experiment_path}samples/'
    os.makedirs(output_dir, exist_ok=True)
    
    # create sample directory for the current version of the trained model,
    # using the first number in the checkpoint filename as the epoch (0 if no digit is found)
    model_name = os.path.basename(checkpoint_path)
    epoch = re.findall(r'\d+', model_name)
    e = int(epoch[0]) if epoch else 0
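    # e.g. a checkpoint named 'ldm_epoch_250.pth' (naming scheme assumed here) yields e = 250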
    model_dir = os.path.join(output_dir, f'epoch_{e}')
    os.makedirs(model_dir, exist_ok=True)
    
    # create the sample directory for this sampling run for the current version of the model
    sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
    indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
    j = max(indx_list, default=-1) + 1
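    # e.g. if sample_0 and sample_1 already exist, this run writes into sample_2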
    sample_dir = os.path.join(model_dir, f'sample_{j}')
    os.makedirs(sample_dir, exist_ok=True)
    
    # transform: min-max normalise a tensor to [0, 1], then convert it to a PIL image
    back2pil = transforms.Compose([
        transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min())),
        transforms.ToPILImage(),
    ])
        
    # sample latent encodings from the conditioning inputs, decode and transform to PIL
    run_indx = 0
    for _, y in dataloader:
        y = y.to(device)
        generated = model.sample(y=y, batch_size=y.size(0))
        # save each generated image alongside its conditioning input
        for i in range(generated.size(0)):
            image = back2pil(generated[i])
            image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}.png')
            image_raw = back2pil(y[i])
            image_path_raw = os.path.join(sample_dir, f'raw_{j}_{run_indx}.png')
            try:
                image.save(image_path)
                image_raw.save(image_path_raw)
            except Exception as e:
                print("Error saving image. Exception:", e)
            run_indx += 1
        # stop once at least 10 images have been saved in this run
        if run_indx > 9:
            break
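
# Illustrative usage sketch (an assumption, not part of the original file): the
# checkpoint name, paths, and objects below are placeholders; substitute a trained
# LDM wrapper exposing .net, .vq_model and .sample(), plus a dataloader yielding
# (input, conditioning) batches.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     ldm_sampler(model, "ldm_epoch_100.pth", "experiments/run_1/", dataloader,
#                 device, batch_size=15, sample_all=False)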