Select Git revision
vite.config.js
-
Benedikt Heinrichs authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 6.77 KiB
from diffusers import UNet2DModel
from trainer.train import ddpm_trainer
from evaluation.sample import ddpm_sampler
from evaluation.evaluate import ddpm_evaluator
def train_func(f):
    """Train a diffusion framework from the JSON settings stored in folder ``f``.

    Reads ``meta_setting.json``, ``dataset_setting.json``, ``model_setting.json``,
    ``framework_setting.json`` and ``training_setting.json`` from ``f``, builds
    train/test dataloaders, a hard-coded diffusers ``UNet2DModel`` and the
    framework object wrapping it, then calls the train-loop function named by
    ``meta_setting["trainloop_function"]``.

    The dataset class, framework class and train-loop function are looked up
    by name in this module's globals, and ``json``/``torch`` are used as module
    globals too. The ``__main__`` block provides all of them via its imports.

    Args:
        f: Path of the experiment folder (string). It is also passed to the
            train loop as ``safepath``.
    """
    def _load(name):
        # Settings are JSON, which is UTF-8 by spec. Don't depend on the
        # platform's locale encoding.
        with open(f + "/" + name + ".json", "r", encoding="utf-8") as fp:
            return json.load(fp)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}\n\n")
    print(f"folderpath: {f}\n\n")

    # Load all settings in the original order, so a missing file still fails
    # on the same name.
    meta_setting = _load("meta_setting")
    dataset_setting = _load("dataset_setting")
    model_setting = _load("model_setting")
    framework_setting = _load("framework_setting")
    training_setting = _load("training_setting")
    # NOTE(review): eval() executes arbitrary Python read from the settings
    # file, so only run trusted settings folders. The string presumably names
    # an optimizer class resolvable from module globals (confirm).
    training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])

    batchsize = meta_setting["batchsize"]
    dataset_class = globals()[meta_setting["dataset"]]
    training_dataset = dataset_class(train=True, **dataset_setting)
    test_dataset = dataset_class(train=False, **dataset_setting)
    # NOTE(review): the training loader is not shuffled (DataLoader's default
    # is shuffle=False). Confirm this is intended for training.
    training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batchsize)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize)

    # NOTE: the architecture is hard-coded here. model_setting is only printed
    # below; the generic globals()[meta_setting["modelname"]](**model_setting)
    # path (optionally wrapped in torch.compile) is currently disabled.
    net = UNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    net = net.to(device)
    framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

    print(f"META SETTINGS:\n\n {meta_setting}\n\n")
    print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
    print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
    print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
    print(f"TRAINING SETTINGS:\n\n {training_setting}\n\n")
    print("\n\nSTART TRAINING\n\n")
    globals()[meta_setting["trainloop_function"]](
        model=framework,
        device=device,
        trainloader=training_dataloader,
        testloader=test_dataloader,
        safepath=f,
        **training_setting,
    )
    print("\n\nFINISHED TRAINING\n\n")
def sample_func(f):
    """Generate samples with a diffusion framework configured in folder ``f``.

    Reads ``meta_setting.json``, ``model_setting.json``,
    ``framework_setting.json`` and ``sampling_setting.json`` from ``f``. It
    builds a hard-coded diffusers ``UNet2DModel``, wraps it in the framework
    class named by ``meta_setting["framework"]``, and calls the function named
    by ``meta_setting["sampling_function"]`` with ``sampling_setting`` as
    keyword arguments.

    The framework class and sampling function are looked up by name in this
    module's globals, and ``json``/``torch`` are used as module globals. The
    ``__main__`` block provides all of them via its imports.

    Args:
        f: Path of the experiment folder holding the settings files (string).
    """
    def _load(name):
        # Settings are JSON, which is UTF-8 by spec. Don't depend on the
        # platform's locale encoding.
        with open(f + "/" + name + ".json", "r", encoding="utf-8") as fp:
            return json.load(fp)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}\n\n")
    print(f"folderpath: {f}\n\n")

    # Load settings in the original order.
    meta_setting = _load("meta_setting")
    model_setting = _load("model_setting")
    framework_setting = _load("framework_setting")
    sampling_setting = _load("sampling_setting")

    # init Unet
    # NOTE: the architecture is hard-coded here. model_setting is only printed
    # below; the generic globals()[meta_setting["modelname"]](**model_setting)
    # path (optionally wrapped in torch.compile) is currently disabled.
    net = UNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    net = net.to(device)
    # init unconditional diffusion model
    framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

    print(f"META SETTINGS:\n\n {meta_setting}\n\n")
    print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
    print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
    print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n")
    print("\n\nSTART SAMPLING\n\n")
    globals()[meta_setting["sampling_function"]](model=framework, device=device, **sampling_setting)
    print("\n\nFINISHED SAMPLING\n\n")
def evaluate_func(f):
    """Evaluate a diffusion framework on the test split, using settings in ``f``.

    Reads ``meta_setting.json``, ``model_setting.json``,
    ``framework_setting.json``, ``evaluation_setting.json`` and
    ``dataset_setting.json`` from ``f``. It builds the test dataloader, a
    hard-coded diffusers ``UNet2DModel`` and the framework wrapping it, then
    calls the function named by ``meta_setting["evaluation_function"]``.

    The dataset class, framework class and evaluation function are looked up
    by name in this module's globals, and ``json``/``torch`` are used as module
    globals. The ``__main__`` block provides all of them via its imports.

    Args:
        f: Path of the experiment folder (string). It is also passed to the
            evaluation function as ``safepath``.
    """
    def _load(name):
        # Settings are JSON, which is UTF-8 by spec. Don't depend on the
        # platform's locale encoding.
        with open(f + "/" + name + ".json", "r", encoding="utf-8") as fp:
            return json.load(fp)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}\n\n")
    print(f"folderpath: {f}\n\n")

    # Load settings in the original order.
    meta_setting = _load("meta_setting")
    model_setting = _load("model_setting")
    framework_setting = _load("framework_setting")
    evaluation_setting = _load("evaluation_setting")
    dataset_setting = _load("dataset_setting")

    # load dataset
    batchsize = meta_setting["batchsize"]
    test_dataset = globals()[meta_setting["dataset"]](train=False, **dataset_setting)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize)

    # init Unet
    # NOTE: the architecture is hard-coded here. model_setting is only printed
    # below; the generic globals()[meta_setting["modelname"]](**model_setting)
    # path (optionally wrapped in torch.compile) is currently disabled.
    net = UNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    net = net.to(device)
    # init unconditional diffusion model
    framework = globals()[meta_setting["framework"]](net=net, device=device, **framework_setting)

    print(f"META SETTINGS:\n\n {meta_setting}\n\n")
    print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
    print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
    print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
    print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
    print("\n\nSTART EVALUATION\n\n")
    globals()[meta_setting["evaluation_function"]](
        model=framework,
        device=device,
        testloader=test_dataloader,
        safepath=f,
        **evaluation_setting,
    )
    print("\n\nFINISHED EVALUATION\n\n")
def pipeline_func(f):
    """Run the pipeline stages for the experiment folder ``f``.

    Only the sampling stage is active. Training and evaluation remain
    disabled, as in the original TODO.

    Args:
        f: Path of the experiment folder holding the ``*_setting.json`` files.
    """
    # TODO: enable the remaining stages once the pipeline is finalised.
    # train_func(f)
    # The previous call targeted ``generate_func``, which is not defined in
    # this module and raised NameError. ``sample_func`` is the sampling entry.
    sample_func(f)
    # evaluate_func(f)
def hello(name):
    """Print ``Hello <name>!`` to stdout (registered as the ``hello`` command)."""
    greeting = "Hello {}!".format(name)
    print(greeting)
if __name__ == '__main__':
    # These imports execute at module scope (an ``if`` does not open a new
    # scope), so they populate the module globals that train_func,
    # sample_func and evaluate_func rely on: ``json``, ``torch`` and the
    # dataset / framework / loop names resolved via ``globals()[...]``.
    import json
    import sys
    from trainer.train import *
    from dataloader.load import *
    from models.Framework import *
    from models.unet_unconditional_diffusion import *
    from models.unet import UNet
    import torch
    from torch import nn

    print(sys.argv)
    functions = {'train': train_func, 'sample': sample_func, 'evaluate': evaluate_func, "hello": hello}
    # Exit with a usage message instead of an IndexError/KeyError traceback
    # when the command or its argument is missing or unknown.
    if len(sys.argv) < 3 or sys.argv[1] not in functions:
        commands = "|".join(functions)
        sys.exit(f"usage: {sys.argv[0]} {{{commands}}} <folder-or-name>")
    functions[sys.argv[1]](sys.argv[2])