import os
import re

import torch
from PIL import Image
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

def cdm_evaluator(model,
                  device,
                  dataloader,
                  checkpoint,
                  experiment_path,
                  sample_idx=0,
                  **kwargs):
    '''
    Takes a trained diffusion model from 'checkpoint' and evaluates its performance on
    the test dataset 'dataloader' w.r.t. the three most important performance metrics:
    FID, IS and KID. This parallels our evaluation function for the LDM upscaler and
    may be updated accordingly.

    model:           Trained diffusion model (currently unused; evaluation runs on
                     images sampled beforehand)
    device:          Torch device the metrics and image batches are moved to
    checkpoint:      Name of the saved .pth file containing the trained weights and biases
    experiment_path: Path to the experiment folder where the evaluation results will be stored
    dataloader:      Loads the test dataset for evaluation
    sample_idx:      Integer denoting which sample directory sample_{sample_idx} from the
                     checkpoint model shall be used for evaluation
    '''
    
    checkpoint_path = os.path.join(experiment_path, 'trained_cdm', checkpoint)
    # create the evaluation directory for the whole experiment (if evaluating for the first time)
    output_dir = os.path.join(experiment_path, 'evaluations')
    os.makedirs(output_dir, exist_ok=True)

    # create the evaluation directory for the current version of the trained model
    model_name = os.path.basename(checkpoint_path)
    epoch = re.findall(r'\d+', model_name)
    if epoch:
        e = int(epoch[0])
    else:
        raise ValueError(f'No epoch number found in the checkpoint filename: {model_name}')
    model_dir = os.path.join(output_dir, f'epoch_{e}')
    os.makedirs(model_dir, exist_ok=True)
    
    # create the evaluation directory for this run: one past the highest existing
    # evaluation_<n> directory for the current model version
    eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
    indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
    j = max(indx_list, default=-1) + 1
    eval_dir = os.path.join(model_dir, f'evaluation_{j}')
    os.makedirs(eval_dir, exist_ok=True)

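    # Resulting layout, per the paths built above:
    #   <experiment_path>/evaluations/epoch_<e>/evaluation_<j>/eval.txt
    # where e is parsed from the checkpoint filename and j counts evaluation runs.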
    # Compute metrics
    eval_path = os.path.join(eval_dir, 'eval.txt')

    # load the sampled images and convert them to uint8 tensors
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
    sample_path = os.path.join(experiment_path, 'samples', f'epoch_{e}', f'sample_{sample_idx}')
    # name of the tensor file in the sample directory that is skipped when loading
    ignore_tensor = f'image_tensor{j}'
    images = []
    for samplename in os.listdir(sample_path):
        if samplename == ignore_tensor:
            continue
        img = Image.open(os.path.join(sample_path, samplename))
        img = transform(img)
        images.append(img)
    # split into batches matching the test dataloader's batch size; the stack stays
    # on the CPU and batches are moved to the GPU one at a time to limit memory use
    generated = torch.stack(images)
    generated_batches = torch.split(generated, dataloader.batch_size)
    nr_generated_batches = len(generated_batches)
    nr_real_batches = len(dataloader)
    
    # initialize the FID, IS and KID metrics
    fid = FrechetInceptionDistance(normalize=False).to(device)
    iscore = InceptionScore(normalize=False).to(device)
    kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
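    # Note: with normalize=False, torchmetrics expects uint8 images in [0, 255],
    # which is what the transform above produces for the generated samples; the
    # test dataloader is assumed to yield uint8 batches as well. KID additionally
    # needs at least subset_size (here 32) samples per distribution to compute.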
    
    # update the scores over the full test dataset and the corresponding sampled batches
    for idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        fid.update(data, real=True)
        kid.update(data, real=True)
        if idx < nr_generated_batches:
            gen = generated_batches[idx].to(device)
            fid.update(gen, real=False)
            kid.update(gen, real=False)
            iscore.update(gen)
    
    # if there are sampled batches left over, add them as well
    for idx in range(nr_real_batches, nr_generated_batches):
        gen = generated_batches[idx].to(device)
        fid.update(gen, real=False)
        kid.update(gen, real=False)
        iscore.update(gen)
     
    # compute the total FID, IS and KID; IS and KID return (mean, std) tuples
    fid_score = fid.compute()
    i_score = iscore.compute()
    kid_score = kid.compute()
    
    # store the results in a text file
    with open(eval_path, 'a') as txt:
        txt.write(f'FID_epoch_{e}_sample_{sample_idx}:' + str(fid_score.item()) + '\n')
        txt.write(f'KID_epoch_{e}_sample_{sample_idx}:' + str(kid_score) + '\n')
        txt.write(f'IS_epoch_{e}_sample_{sample_idx}:' + str(i_score) + '\n')
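

# ------------------------------------------------------------------------------
# Usage sketch (not part of the original evaluator): cdm_evaluator never calls
# `model` directly, it only reads images sampled beforehand from
# <experiment_path>/samples/epoch_<e>/sample_<sample_idx>/, so `model` can even
# be None here. The dummy dataset stands in for the real test set; the checkpoint
# name and experiment path below are placeholders and must match an existing
# experiment folder with sampled images.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 64 dummy uint8 RGB test images with dummy labels (the evaluator ignores labels)
    fake_images = torch.randint(0, 256, (64, 3, 32, 32), dtype=torch.uint8)
    fake_labels = torch.zeros(64, dtype=torch.long)
    test_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16)

    cdm_evaluator(model=None,
                  device=device,
                  dataloader=test_loader,
                  checkpoint='model_epoch_100.pth',       # placeholder; the epoch is parsed from the digits
                  experiment_path='experiments/my_run/',  # placeholder experiment folder
                  sample_idx=0)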