Skip to content
Snippets Groups Projects
Select Git revision
  • c4410482ec88746cdc926fdf72769cbe1cf0d00b
  • main default protected
  • celebAHQ
  • ddpm-diffusers
4 results

main.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    main.py 5.10 KiB
    
    import json
    import sys
    from dataloader.load import  *
    from models.Framework import *
    from trainer.train import ddpm_trainer
    from evaluation.sample import ddpm_sampler
    from evaluation.evaluate import ddpm_evaluator
    from models.all_unets import *
    import torch 
    
    
    def train_func(f):
      """Train a model described by the JSON settings files in folder ``f``.

      Reads ``meta_setting.json``, ``dataset_setting.json``, ``model_setting.json``,
      ``framework_setting.json`` and ``training_setting.json`` from ``f``.  The
      dataset class, model class, framework class and train-loop callable are
      looked up by the names stored in ``meta_setting`` (keys ``"dataset"``,
      ``"modelname"``, ``"framework"``, ``"trainloop_function"``) in this
      module's globals, i.e. among the names brought in by the imports at the
      top of the file.

      Args:
        f: Path of the settings folder.  It is also forwarded to the train loop
          as ``safepath``.
      """
      def _read_setting(name):
        # Load "<f>/<name>_setting.json".  JSON text is UTF-8 by specification,
        # so do not depend on the platform's default locale encoding.
        with open(f + "/" + name + "_setting.json", "r", encoding="utf-8") as fp:
          return json.load(fp)

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      print(f"device: {device}\n\n")
      print(f"folderpath: {f}\n\n")

      meta_setting = _read_setting("meta")
      dataset_setting = _read_setting("dataset")
      model_setting = _read_setting("model")
      framework_setting = _read_setting("framework")
      training_setting = _read_setting("training")
      # SECURITY NOTE(review): eval() executes arbitrary code taken from
      # training_setting.json -- only run this on trusted settings folders.
      # Passing globals() explicitly keeps this function's locals out of the
      # evaluated namespace.  The string presumably names an optimizer class
      # (e.g. a dotted path into torch.optim) -- TODO confirm against configs.
      training_setting["optimizer_class"] = eval(training_setting["optimizer_class"], globals())

      batchsize = meta_setting["batchsize"]
      dataset_class = globals()[meta_setting["dataset"]]
      training_dataset = dataset_class(train=True, **dataset_setting)
      test_dataset = dataset_class(train=False, **dataset_setting)

      # NOTE(review): DataLoader defaults to shuffle=False, so the training set
      # is iterated in the same order every epoch.  Consider shuffle=True here
      # unless the dataset already randomizes its own order.
      training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batchsize)
      test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize)

      # Build the network and move it to the device once (torch.compile could be
      # applied here, before the move, if desired).
      net = globals()[meta_setting["modelname"]](**model_setting)
      net = net.to(device)

      framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

      print(f"META SETTINGS:\n\n {meta_setting}\n\n")
      print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
      print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
      print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
      print(f"TRAINING SETTINGS:\n\n {training_setting}\n\n")

      print("\n\nSTART TRAINING\n\n")
      train_loop = globals()[meta_setting["trainloop_function"]]
      train_loop(
        model=framework,
        device=device,
        trainloader=training_dataloader,
        testloader=test_dataloader,
        safepath=f,
        **training_setting,
      )
      print("\n\nFINISHED TRAINING\n\n")
    
    
    
    def sample_func(f):
      """Generate samples with a model described by the JSON settings in ``f``.

      Reads ``meta_setting.json``, ``model_setting.json``,
      ``framework_setting.json`` and ``sampling_setting.json`` from ``f``.
      Builds the model (``meta_setting["modelname"]``) and wraps it in the
      framework (``meta_setting["framework"]``); both are looked up by name in
      this module's globals.  Then calls ``meta_setting["sampling_function"]``
      with the framework, the device and the sampling settings as keyword
      arguments.

      Args:
        f: Path of the settings folder.
      """
      def _read_setting(name):
        # Load "<f>/<name>_setting.json" as UTF-8, the encoding JSON mandates.
        with open(f + "/" + name + "_setting.json", "r", encoding="utf-8") as fp:
          return json.load(fp)

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      print(f"device: {device}\n\n")
      print(f"folderpath: {f}\n\n")

      meta_setting = _read_setting("meta")
      model_setting = _read_setting("model")
      framework_setting = _read_setting("framework")
      sampling_setting = _read_setting("sampling")

      # Build the network and move it to the device once (torch.compile could be
      # applied here, before the move, if desired).
      net = globals()[meta_setting["modelname"]](**model_setting)
      net = net.to(device)

      # Wrap the network in the configured framework class.
      framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

      print(f"META SETTINGS:\n\n {meta_setting}\n\n")
      print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
      print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
      print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n")

      print("\n\nSTART SAMPLING\n\n")
      sampling_function = globals()[meta_setting["sampling_function"]]
      sampling_function(model=framework, device=device, **sampling_setting)
      print("\n\nFINISHED SAMPLING\n\n")
    
    
    
    def evaluate_func(f):
      """Evaluate a model described by the JSON settings in folder ``f``.

      Reads ``meta_setting.json``, ``model_setting.json``,
      ``framework_setting.json``, ``evaluation_setting.json`` and
      ``dataset_setting.json`` from ``f``.  Builds the test split of the
      configured dataset (``train=False``) and an unshuffled DataLoader over
      it.  Builds the model and framework, all looked up by name in this
      module's globals.  Then calls ``meta_setting["evaluation_function"]``.

      Args:
        f: Path of the settings folder.  It is also forwarded to the evaluation
          function as ``safepath``.
      """
      def _read_setting(name):
        # Load "<f>/<name>_setting.json" as UTF-8, the encoding JSON mandates.
        with open(f + "/" + name + "_setting.json", "r", encoding="utf-8") as fp:
          return json.load(fp)

      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      print(f"device: {device}\n\n")
      print(f"folderpath: {f}\n\n")

      meta_setting = _read_setting("meta")
      model_setting = _read_setting("model")
      framework_setting = _read_setting("framework")
      evaluation_setting = _read_setting("evaluation")
      dataset_setting = _read_setting("dataset")

      # Test split only; a fixed order keeps evaluation runs reproducible.
      batchsize = meta_setting["batchsize"]
      test_dataset = globals()[meta_setting["dataset"]](train=False, **dataset_setting)
      test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)

      # Build the network and move it to the device once (torch.compile could be
      # applied here, before the move, if desired).
      net = globals()[meta_setting["modelname"]](**model_setting)
      net = net.to(device)

      # Wrap the network in the configured framework class.
      framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

      print(f"META SETTINGS:\n\n {meta_setting}\n\n")
      print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
      print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
      print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
      print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")

      print("\n\nSTART EVALUATION\n\n")
      evaluation_function = globals()[meta_setting["evaluation_function"]]
      evaluation_function(
        model=framework,
        device=device,
        dataloader=test_dataloader,
        safepath=f,
        **evaluation_setting,
      )
      print("\n\nFINISHED EVALUATION\n\n")
    
    
    
    
      
    if __name__ == '__main__':
      # Command line: <script> {train|sample|evaluate} <settings_folder>
      print(sys.argv)
      functions = {'train': train_func, 'sample': sample_func, 'evaluate': evaluate_func}
      # Fail with a usage message instead of a bare IndexError/KeyError traceback
      # when the subcommand or folder argument is missing or unknown.
      if len(sys.argv) < 3 or sys.argv[1] not in functions:
        sys.exit(f"usage: {sys.argv[0]} {{{'|'.join(functions)}}} <settings_folder>")
      functions[sys.argv[1]](sys.argv[2])