Commit 472430c1 authored by Christoph Reich

Initial commit

parent 01fa69b2
Showing 1906 additions and 0 deletions
.gitignore 0 → 100644
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
LICENSE 0 → 100644
MIT License
Copyright (c) 2021 Christoph Reich & Tim Prangemeier (BCS, TU Darmstadt)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md 0 → 100644
# Multi-StyleGAN: Towards Image-Based Simulation of Time-Lapse Live-Cell Microscopy
[![arXiv](https://img.shields.io/badge/cs.CV-arXiv%3A2106.08285-B31B1B.svg)](https://arxiv.org/abs/2106.08285)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/ChristophReich1996/Multi-StyleGAN/blob/master/LICENSE)
**[Christoph Reich*](https://github.com/ChristophReich1996), [Tim Prangemeier*](https://www.bcs.tu-darmstadt.de/bcs_team/prangemeiertim.en.jsp), [Christian Wildner](https://www.bcs.tu-darmstadt.de/bcs_team/wildnerchristian.en.jsp) & [Heinz Koeppl](https://www.bcs.tu-darmstadt.de/bcs_team/koepplheinz.en.jsp)**<br/>
*Christoph Reich and Tim Prangemeier - both authors contributed equally
## | [Project Page](https://christophreich1996.github.io/multi_stylegan) | [Paper](https://arxiv.org/abs/2106.08285) | [Dataset]() |
<p align="center">
<img src="/github/latent_space_interpolation.gif" alt="1" width = 288px height = 192px >
</p>
<p align="center">
This repository includes the <b>official</b> and <b>maintained</b> <a href="https://pytorch.org/">PyTorch</a> implementation of the paper <a href="https://arxiv.org/abs/2106.08285"> Multi-StyleGAN: Towards Image-Based Simulation of Time-Lapse Live-Cell Microscopy</a>
</p>
## Abstract
*Time-lapse fluorescent microscopy (TLFM) combined with
predictive mathematical modelling is a powerful tool to study the inherently dynamic processes of life on the single-cell level. Such experiments
are costly, complex and labour intensive. A complementary approach,
and a step towards completely in silico experiments, is to synthesise
the imagery itself. Here, we propose Multi-StyleGAN as a descriptive
approach to simulate time-lapse fluorescence microscopy imagery of living cells, based on a past experiment. This novel generative adversarial
network synthesises a multi-domain sequence of consecutive timesteps.
We showcase Multi-StyleGAN on imagery of multiple live yeast cells in
microstructured environments and train on a dataset recorded in our laboratory. The simulation captures underlying biophysical factors and time
dependencies, such as cell morphology, growth, physical interactions, as
well as the intensity of a fluorescent reporter protein. An immediate application is to generate additional training and validation data for feature
extraction algorithms or to aid and expedite development of advanced
experimental techniques such as online monitoring or control of cells.*
**If you find this research useful in your work, please cite our paper:**
```bibtex
@inproceedings{Reich2021,
title={{Multi-StyleGAN: Towards Image-Based Simulation of Time-Lapse Live-Cell Microscopy}},
author={Reich, Christoph and Prangemeier, Tim and Wildner, Christian and Koeppl, Heinz},
booktitle={{International Conference on Medical Image Computing and Computer-Assisted Intervention (in press)}},
year={2021},
organization={Springer}
}
```
## Method
<img src="/github/Multi-StyleGAN.png" alt="1" width = 617px height = 176px ><br/>
**Figure 1.** Architecture of Multi-StyleGAN. The style mapping network <img src="https://render.githubusercontent.com/render/math?math=f"> (in purple)
transforms the input noise vector <img src="https://render.githubusercontent.com/render/math?math=z\sim \mathcal{N}_{512}(0, 1)"> into a latent vector <img src="https://render.githubusercontent.com/render/math?math=w\in\mathcal{W}">, which in
turn is incorporated into each stage of the generator by three dual-style-convolutional
blocks. The generator predicts a sequence of three consecutive images for both
the brightfield and green fluorescent protein channels. The U-Net discriminator [2] distinguishes between real and
fake input sequences by making both a scalar and a pixel-wise real/fake prediction.
Standard residual discriminator blocks are shown in gray and non-local blocks in blue.
<img src="/github/Dual-styled-convolutional_block.png" alt="1" width = 451px height = 221px ><br/>
**Figure 2.** Dual-styled-convolutional block of the Multi-StyleGAN. The incoming latent
vector w is transformed into the style vector s by a linear layer. This style vector modulates (mod) the convolutional weights <img src="https://render.githubusercontent.com/render/math?math=\theta_{b}"> and <img src="https://render.githubusercontent.com/render/math?math=\theta_{g}">, which are optionally demodulated
(demod) before convolving the (optionally bilinearly upsampled) incoming features
of the previous block. Learnable biases (<img src="https://render.githubusercontent.com/render/math?math=b_{b}"> and <img src="https://render.githubusercontent.com/render/math?math=b_{g}">) and channel-wise Gaussian noise
(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{N}">) scaled by learnable constants (<img src="https://render.githubusercontent.com/render/math?math=c_{b}"> and <img src="https://render.githubusercontent.com/render/math?math=c_{g}">) are added to the features. The final
output features are obtained by applying a leaky ReLU activation.
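The modulation and demodulation steps in the dual-styled-convolutional block follow the StyleGAN 2 recipe. As a rough PyTorch sketch (not the exact code of this repository; the helper name `modulated_conv2d` and its signature are illustrative only), the core operation can be written as:
```python
import torch
import torch.nn.functional as F


def modulated_conv2d(features: torch.Tensor, weight: torch.Tensor, style: torch.Tensor,
                     demodulate: bool = True, eps: float = 1e-8) -> torch.Tensor:
    """
    features: [batch size, in channels, height, width]
    weight:   [out channels, in channels, k, k] convolutional weights (e.g. the brightfield or GFP weights)
    style:    [batch size, in channels] style vector s produced by a linear layer from w
    """
    batch, in_channels, height, width = features.shape
    out_channels = weight.shape[0]
    # Modulate: scale the weights of each input channel by the style vector
    weight = weight.unsqueeze(0) * style.view(batch, 1, in_channels, 1, 1)
    # Demodulate: rescale the weights so the output features keep roughly unit variance
    if demodulate:
        demod = torch.rsqrt(weight.pow(2).sum(dim=(2, 3, 4), keepdim=True) + eps)
        weight = weight * demod
    # Grouped convolution applies a different (modulated) kernel to every sample in the batch
    features = features.view(1, batch * in_channels, height, width)
    weight = weight.view(batch * out_channels, in_channels, *weight.shape[3:])
    output = F.conv2d(features, weight, padding=weight.shape[-1] // 2, groups=batch)
    return output.view(batch, out_channels, height, width)
```
In the dual-styled-convolutional block this operation is performed with both weight sets (brightfield and GFP), driven by the same style vector, before bias, noise, and the leaky ReLU activation are applied.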
## Results
<img src="/github/prediction_ema_100_bf_0.png" alt="1" width = 288px height = 96px > <img src="/github/prediction_ema_100_bf_12.png" alt="1" width = 288px height = 96px ><br/>
<img src="/github/prediction_ema_100_gfp_0.png" alt="1" width = 288px height = 96px > <img src="/github/prediction_ema_100_gfp_12.png" alt="1" width = 288px height = 96px ><br/>
**Figure 3.** Samples generated by Multi-StyleGAN. Brightfield channel on the top and green fluorescent protein on the bottom.<br/>
**Table 1.** Evaluation metrics for Multi-StyleGAN and baselines.
| Model | FID (BF) <img src="https://render.githubusercontent.com/render/math?math=\downarrow"> | FVD (BF) <img src="https://render.githubusercontent.com/render/math?math=\downarrow"> | FID (GFP) <img src="https://render.githubusercontent.com/render/math?math=\downarrow"> | FVD (GFP) <img src="https://render.githubusercontent.com/render/math?math=\downarrow"> |
| --- | --- | --- | --- | --- |
| Multi-StyleGAN | **33.3687** | **4.4632** | **207.8409** | **30.1650** |
| StyleGAN 2 3d + ADA + U-Net dis. | 200.5408 | 45.6296 | 224.7860 | 35.2169 |
| StyleGAN 2 + ADA + U-Net dis. | 76.0344 | 14.7509 | 298.7545 | 31.4771 |
## Dependencies
All required Python packages can be installed by:
```shell script
pip install -r requirements.txt
```
To install the necessary custom CUDA extensions adapted from [StyleGAN 2](https://github.com/NVlabs/stylegan2) [1] run:
```shell script
cd multi_stylegan/op_static
python setup.py install
```
The code is tested with [PyTorch 1.8.1](https://pytorch.org/get-started/locally/) and CUDA 11.1 on Ubuntu with Python 3.6!
Other PyTorch and CUDA versions newer than [PyTorch 1.7.0](https://pytorch.org/get-started/previous-versions/) and
CUDA 10.1 should also work. Please note that using a different PyTorch version may require a different version
of [Kornia](https://kornia.github.io/) or [Torchvision](https://pytorch.org/vision/stable/index.html).
## Data
**Our proposed time-lapse fluorescent microscopy dataset is available at [this url](https://arxiv.org/pdf/2106.08285.pdf).**
The dataset includes 9696 images structured in sequences of both brightfield and green fluorescent protein (GFP) channels at a resolution of 256 × 256.
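Once downloaded, the sequences can be loaded with the `TFLMDatasetGAN` class included in this repository. A minimal sketch, assuming the dataset is stored in the folder name used throughout this README:
```python
from torch.utils.data import DataLoader

from dataset import TFLMDatasetGAN

# Build the unsupervised TLFM dataset (brightfield + GFP channels, no RFP channel)
dataset = TFLMDatasetGAN(path="./60x_10BF_200GFP_200RFP20_3Z_10min", sequence_length=3, no_rfp=True)
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

for images in data_loader:
    # images has the shape [batch size, channels (BF, GFP), sequence length, height, width]
    print(images.shape)
    break
```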
## Trained Model
**The checkpoint of our trained Multi-StyleGAN is available at [this url]().**
The checkpoint (PyTorch state dict) includes the EMA generator weights (`"generator_ema"`), the generator weights
(`"generator"`), the generator optimizer state (`"generator_optimizer"`), the discriminator weights (`"discriminator"`),
the discriminator optimizer state (`"discriminator_optimizer"`), and the path-length regularization states
(`"path_length_regularization"`).
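A minimal sketch for inspecting the checkpoint and restoring the EMA generator weights (constructing the generator itself follows `train_gan.py` and `get_gan_samples.py` and is omitted here):
```python
import torch

# Load the released checkpoint (file name as used by the scripts below)
checkpoint = torch.load("checkpoint_100.pt", map_location="cpu")
print(list(checkpoint.keys()))  # e.g. ["generator_ema", "generator", "generator_optimizer", ...]

# The EMA generator weights can then be restored into a Multi-StyleGAN generator instance:
# generator.load_state_dict(checkpoint["generator_ema"])
```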
## Usage
To train Multi-StyleGAN in the proposed setting run the following command:
```shell script
python -W ignore train_gan.py --cuda_devices "0, 1, 2, 3" --data_parallel --path_to_data "60x_10BF_200GFP_200RFP20_3Z_10min"
```
Dataset path and cuda devices may differ on other systems!
To perform training runs with different settings, use the command line arguments of the [`train_gan.py`](train_gan.py) script.
The [`train_gan.py`](train_gan.py) script takes the following command line arguments (an example call follows the table):
|Argument | Default value | Info|
|--- | --- | ---|
| --cuda_devices (str) | `"0, 1, 2, 3"` | String of cuda device indexes to be used. |
| --batch_size (int) | `24` | Batch size to be utilized while training. |
| --data_parallel (binary flag) | `False` | Binary flag. If set, data parallelism is utilized. |
| --epochs (int) | `100` | Number of epochs to perform while training. |
| --lr_generator (float) | `2e-04` | Learning rate of the generator network. |
| --lr_discriminator (float) | `6e-04` | Learning rate of the discriminator network. |
| --path_to_data (str) | `"./60x_10BF_200GFP_200RFP20_3Z_10min"` | Path to dataset. |
| --load_checkpoint (str) | `""` | Path to checkpoint to be loaded. If `""` no loading is performed. |
| --resume_training (binary flag) | `False` | Binary flag. If set, training is resumed and cut-mix augmentation/regularization as well as wrong-order augmentation are used. |
| --no_top_k (binary flag) | `False` | Binary flag. If set, top-k training is not utilized. |
| --no_ada (binary flag) | `False` | Binary flag. If set, adaptive discriminator augmentation is not utilized. |
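For example, a single-GPU run with a smaller batch size and without adaptive discriminator augmentation could be started as follows (illustrative values; adjust the dataset path to your system):
```shell script
python -W ignore train_gan.py --cuda_devices "0" --batch_size 8 --no_ada --path_to_data "60x_10BF_200GFP_200RFP20_3Z_10min"
```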
To generate samples with the trained Multi-StyleGAN, use the [`get_gan_samples.py`](scripts/get_gan_samples.py) script.
```shell script
python -W ignore scripts/get_gan_samples.py --cuda_devices "0" --load_checkpoint "checkpoint_100.pt"
```
This script takes the following command line arguments:
|Argument | Default value | Info|
|--- | --- | ---|
| --cuda_devices (str) | `"0"` | String of cuda device indexes to be used. |
| --samples (int) | `100` | Number of samples to be generated. |
| --load_checkpoint (str) | `"checkpoint_100.pt"` | Path to checkpoint to be loaded. |
To generate a latent space interpolation, use the [`gan_latent_space_interpolation.py`](scripts/gan_latent_space_interpolation.py) script.
To produce the final `.mp4` video, [`ffmpeg`](https://www.ffmpeg.org/) is required.
```shell script
python -W ignore scripts/gan_latent_space_interpolation.py --cuda_devices "0" --load_checkpoint "checkpoint_100.pt"
```
This script takes the following command line arguments:
|Argument | Default value | Info|
|--- | --- | ---|
| --cuda_devices (str) | `"0"` | String of cuda device indexes to be used. |
| --load_checkpoint (str) | `"checkpoint_100.pt"` | Path to checkpoint to be loaded. |
## Acknowledgements
We thank [Markus Baier](https://www.bcs.tu-darmstadt.de/bcs_team/index.en.jsp) for aid with the computational setup,
[Klaus-Dieter Voss](https://www.bcs.tu-darmstadt.de/bcs_team/index.en.jsp) for aid with the microfluidics
fabrication, and Tim Kircher, [Tizian Dege](https://github.com/TixXx1337), and
[Florian Schwald](https://github.com/FlorianSchwald59) for aid with the data preparation.
We also thank [piergiaj](https://github.com/piergiaj) for providing a [PyTorch i3d](https://github.com/piergiaj/pytorch-i3d)
implementation and trained models, which we used to compute the FVD score. The code used is indicated and available
under the [original licence](https://github.com/piergiaj/pytorch-i3d/blob/master/LICENSE.txt).
## References
```bibtex
[1] @inproceedings{Karras2020,
title={Analyzing and improving the image quality of stylegan},
author={Karras, Tero and Laine, Samuli and Aittala, Miika and Hellsten, Janne and Lehtinen, Jaakko and Aila, Timo},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8110--8119},
year={2020}
}
```
```bibtex
[2] @inproceedings{Schonfeld2020,
title={A u-net based discriminator for generative adversarial networks},
author={Schonfeld, Edgar and Schiele, Bernt and Khoreva, Anna},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8207--8216},
year={2020}
}
```
# Import datasets
from .tlfm_dataset import TFLMDatasetGAN
from typing import Optional, Tuple, Union, List
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import cv2
import math
from dataset import utils
class TFLMDatasetGAN(Dataset):
"""
This class implements the unsupervised TFLM dataset including data from a trapped yeast cell TFLM experiment for
the generation task.
"""
def __init__(self, path: str,
sequence_length: int = 3,
overlap: bool = True,
transformations: transforms.Compose = transforms.Compose(
[transforms.RandomHorizontalFlip(p=0.5)]),
z_position_indications: Tuple[str] = ("_000_", "_001_", "_002_"),
gfp_min: Union[float, int] = 150.0,
gfp_max: Union[float, int] = 2200.0,
rfp_min: Union[float, int] = 20.0,
rfp_max: Union[float, int] = 2000.0,
flip: bool = True,
positions: Optional[Tuple[str, ...]] = None,
no_rfp: bool = False,
no_gfp: bool = False) -> None:
"""
Constructor method
:param path: (str) Path to dataset
:param sequence_length: (int) Length of sequence to be returned
:param overlap: (bool) If true sequences can overlap
:param transformations: (transforms.Compose) Transformations and augmentations to be applied
:param z_position_indications: (Tuple[str]) String to indicate each z position
:param gfp_min: (Union[float, int]) Minimal value assumed gfp value
:param gfp_max: (Union[float, int]) Maximal value assumed gfp value
:param rfp_min: (Union[float, int]) Minimal value assumed rfp value
:param rfp_max: (Union[float, int]) Maximal value assumed rfp value
:param flip: (bool) If true images are flipped vertically
:param positions: (Optional[Tuple[str, ...]]) If given only positions which are given are loaded
:param no_rfp: (bool) If true no rfp channel is utilized
:param no_gfp: (bool) If true no gfp channel is utilized
"""
# Save parameters
self.transformations = transformations
self.gfp_min = gfp_min
self.gfp_max = gfp_max
self.rfp_min = rfp_min
self.rfp_max = rfp_max
self.flip = flip
self.no_rfp = no_rfp
self.no_gfp = no_gfp
# Load data sample paths
self.paths_to_dataset_samples = []
# Iterate over all position folders
for position_folder in os.listdir(path=path):
if (positions is None) or (position_folder in positions):
# Check that current folder is really a folder
if os.path.isdir(os.path.join(path, position_folder)):
# Load all images in folder
all_images = [os.path.join(path, position_folder, image_file) for image_file in
os.listdir(os.path.join(path, position_folder)) if "tif" in image_file]
# Get all BF images
all_bf_images = [image_file for image_file in all_images if "-BF0_" in image_file]
# Get all GFP images
all_gfp_images = [image_file for image_file in all_images if "-GFP" in image_file]
# Get all RFP images
all_rfp_images = [image_file for image_file in all_images if "-RFP" in image_file]
# Convert list of images to list of z positions including images
bf_images = []
for z_position_indication in z_position_indications:
bf_images.append(
[image_file for image_file in all_bf_images if z_position_indication in image_file])
# Sort images by time steps and trap number
bf_images[-1].sort(key=lambda item:
item.split("-")[-1].split("_")[-1].replace(".tif", "") +
item.split("_")[-5])
gfp_images = []
for z_position_indication in z_position_indications:
gfp_images.append(
[image_file for image_file in all_gfp_images if z_position_indication in image_file])
# Sort images by time steps and trap number
gfp_images[-1].sort(key=lambda item:
item.split("-")[-1].split("_")[-1].replace(".tif", "") +
item.split("_")[-5])
rfp_images = []
for z_position_indication in z_position_indications:
rfp_images.append(
[image_file for image_file in all_rfp_images if z_position_indication in image_file])
# Sort images by time steps and trap number
rfp_images[-1].sort(key=lambda item:
item.split("-")[-1].split("_")[-1].replace(".tif", "") +
item.split("_")[-5])
# Construct image sequences
for z_position in range(len(z_position_indications)):
for index in range(0, len(bf_images[z_position]) - sequence_length + 1,
1 if overlap else sequence_length):
if self._check_if_same_trap(bf_images[z_position][index:index + sequence_length]):
# Save paths
self.paths_to_dataset_samples.append(
(tuple(bf_images[z_position][index:index + sequence_length]),
tuple(gfp_images[z_position][index:index + sequence_length]),
tuple(rfp_images[z_position][index:index + sequence_length])))
def _check_if_same_trap(self, path_list: List[str]) -> bool:
"""
Method checks if a sequence of image paths belongs to the same trap.
:param path_list: (List[str]) List of image paths
:return: (bool) True if all paths belong to the same trap, else False
"""
traps = [path[path.find("trap"):path.find("trap") + 8] for path in path_list]
return all(trap == traps[0] for trap in traps)
def __len__(self) -> int:
"""
Returns the length of the dataset.
:return: (int) Length of the dataset.
"""
return len(self.paths_to_dataset_samples)
def __getitem__(self, item: int) -> torch.Tensor:
"""
Method returns one instance with the index item of the dataset.
:param item: (int) Index of the dataset element to be returned
:return: (torch.Tensor) Image sequence of n images
"""
# Get paths
path_bf_images, path_gfp_images, path_rfp_images = self.paths_to_dataset_samples[item]
# Load bf images
bf_images = []
for path_bf_image in path_bf_images:
image = cv2.imread(path_bf_image, -1).astype(np.float32)
image = torch.from_numpy(image)
bf_images.append(image)
bf_images = torch.stack(bf_images, dim=0)
# Load gfp images
if not self.no_gfp:
gfp_images = []
for path_gfp_image in path_gfp_images:
image = cv2.imread(path_gfp_image, -1).astype(np.float32)
image = torch.from_numpy(image)
gfp_images.append(image)
gfp_images = torch.stack(gfp_images, dim=0)
# Load rfp images
if not self.no_rfp:
rfp_images = []
for path_rfp_image in path_rfp_images:
image = cv2.imread(path_rfp_image, -1).astype(np.float32)
image = torch.from_numpy(image)
rfp_images.append(image)
rfp_images = torch.stack(rfp_images, dim=0)
if self.no_gfp:
# Concat images
images = torch.cat([bf_images], dim=0)
# Perform transformations
images = self.transformations(images)
# Remove batch dimension
images = images[0] if images.ndimension() == 4 else images
# Reshape images to [1, sequence length, height, width]
images = images.unsqueeze(dim=0)
elif self.no_rfp:
# Concat images
images = torch.cat([bf_images, gfp_images], dim=0)
# Perform transformations
images: torch.Tensor = self.transformations(images)
# Remove batch dimension
images = images[0] if images.ndimension() == 4 else images
# Reshape images to [2, sequence length, height, width]
images = torch.stack(images.split(split_size=images.shape[0] // 2, dim=0), dim=0)
else:
# Concat images
images = torch.cat([bf_images, gfp_images, rfp_images], dim=0)
# Perform transformations
images = self.transformations(images)
# Remove batch dimension
images = images[0] if images.ndimension() == 4 else images
# Reshape images to [3, sequence length, height, width]
images = torch.stack(images.split(split_size=images.shape[0] // 3, dim=0), dim=0)
# Normalize bf images
images[0] = utils.normalize_0_1(images[0])
# Normalize gfp images
if not self.no_gfp:
# images[1] = utils.normalize_0_1(images[1])
images[1] = ((images[1] - self.gfp_min).clamp(min=0.0) / self.gfp_max).clamp(max=1.0)
# Normalize rfp images
if not self.no_rfp:
# images[2] = utils.normalize_0_1(images[2])
images[2] = ((images[2] - self.rfp_min).clamp(min=0.0) / self.rfp_max).clamp(max=1.0)
# Flip images if utilized
images = images.flip(dims=(-2,)) if self.flip else images
return images
class ElasticDeformation(nn.Module):
"""
This module implements random elastic deformation.
"""
def __init__(self, sample_mode: str = "bilinear", alpha: int = 80,
sigma: int = 16) -> None:
"""
Constructor method
:param sample_mode: (str) Resampling mode
:param alpha: (int) Scale factor of the deformation
:param sigma: (int) Standard deviation of the gaussian kernel to be applied
"""
# Call super constructor
super(ElasticDeformation, self).__init__()
# Save parameters
self.sample_mode = sample_mode
self.alpha = alpha
self.sigma = sigma
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass applies random elastic deformation
:param input: (torch.Tensor) Input tensor
:return: (torch.Tensor) Augmented output tensor
"""
return elastic_deformation(img=input, sample_mode=self.sample_mode, alpha=self.alpha, sigma=self.sigma)
def elastic_deformation(img: torch.Tensor, sample_mode: str = "bilinear", alpha: int = 50,
sigma: int = 12) -> torch.Tensor:
"""
Performs random elastic deformation to the given Tensor image
:param img: (torch.Tensor) Input image
:param sample_mode: (str) Resampling mode
:param alpha: (int) Scale factor of the deformation
:param sigma: (int) Standard deviation of the gaussian kernel to be applied
:return: (torch.Tensor) Deformed image
"""
# Get image shape
height, width = img.shape[-2:]
# Get kernel size
kernel_size = (sigma * 4) + 1
# Get mean of gaussian kernel
mean = (kernel_size - 1) / 2.
# Make gaussian kernel
# https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/7
x_cord = torch.arange(kernel_size, device=img.device)
x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
y_grid = x_grid.t()
xy_grid = torch.stack([x_grid, y_grid], dim=-1)
gaussian_kernel = (1. / (2. * math.pi * sigma ** 2)) \
* torch.exp(-torch.sum((xy_grid - mean) ** 2., dim=-1) / (2. * sigma ** 2))
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
gaussian_kernel = gaussian_kernel.repeat(1, 1, 1, 1)
gaussian_kernel.requires_grad = False
# Make random deformations in the range of [-1, 1]
dx = (torch.rand((height, width), dtype=torch.float, device=img.device) * 2. - 1.).view(1, 1, height, width)
dy = (torch.rand((height, width), dtype=torch.float, device=img.device) * 2. - 1.).view(1, 1, height, width)
# Apply gaussian filter to deformations
dx, dy = torch.nn.functional.conv2d(input=torch.cat([dx, dy], dim=0), weight=gaussian_kernel, stride=1,
padding=kernel_size // 2).squeeze(dim=0) * alpha
# Add deformations to coordinate grid
grid = torch.stack(torch.meshgrid([torch.arange(height, dtype=torch.float, device=img.device),
torch.arange(width, dtype=torch.float, device=img.device)]),
dim=-1).unsqueeze(dim=0).flip(dims=(-1,))
grid[..., 0] += dx
grid[..., 1] += dy
# Convert grid to relative sampling location in the range of [-1, 1]
grid[..., 0] = 2 * (grid[..., 0] - (height // 2)) / height
grid[..., 1] = 2 * (grid[..., 1] - (width // 2)) / width
# Resample image
img_deformed = torch.nn.functional.grid_sample(input=img[None] if img.ndimension() == 3 else img,
grid=grid, mode=sample_mode, padding_mode='border',
align_corners=False)[0]
return img_deformed
import torch
def normalize_0_1(tensor: torch.Tensor, max: float = None, min: float = None) -> torch.Tensor:
"""
Function normalizes a given input tensor channel-wise to a range between zero and one.
:param tensor: (torch.Tensor) Input tensor of the shape [channels, height, width]
:param max: (float) Max value utilized in the normalization
:param min: (float) Min value of the normalization
:return: (torch.Tensor) Normalized input tensor of the same shape as the input
"""
# Save shape
channels, height, width = tensor.shape
# Flatten input tensor to the shape [channels, height * width]
tensor = tensor.flatten(start_dim=1)
# Get channel wise min and max
tensor_min = tensor.min(dim=1, keepdim=True)[0].float() if min is None else torch.tensor(min, dtype=torch.float)
tensor_max = tensor.max(dim=1, keepdim=True)[0].float() if max is None else torch.tensor(max, dtype=torch.float)
# Normalize tensor
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# Reshape tensor to original shape
tensor = tensor.reshape(channels, height, width)
return tensor
github/Dual-styled-convolutional_block.png (85.1 KiB)
github/Multi-StyleGAN.png (111 KiB)
github/prediction_ema_100_bf_0.png (122 KiB)
github/prediction_ema_100_bf_12.png (146 KiB)
github/prediction_ema_100_gfp_0.png (82 KiB)
github/prediction_ema_100_gfp_12.png (106 KiB)
# Import 2D U-Net discriminator
from .u_net_2d_discriminator import Discriminator as MultiStyleGANDiscriminator
# Import twin 2D StyleGAN2 generator
from .multi_stylegan_generator import Generator as MultiStyleGANGenerator
# Import configs
from .config import generation_hyperparameters, multi_style_gan_generator_config, u_net_2d_discriminator_config
# Import model wrapper
from .model_wrapper import ModelWrapper
# Import data logger
from .misc import Logger
# Import validation metrics
from .validation_metrics import IS, FID, FVD
# Import losses
from .loss import PathLengthRegularization
# Import ADA
from .adaptive_discriminator_augmentation import AdaptiveDiscriminatorAugmentation, AugmentationPipeline
from typing import Union, Tuple
import torch
import torch.nn as nn
import kornia.augmentation.functional as kaf
import numpy as np
import random
import math
class AdaptiveDiscriminatorAugmentation(nn.Module):
"""
This class implements adaptive discriminator augmentation proposed in:
https://arxiv.org/pdf/2006.06676.pdf
The adaptive discriminator augmentation model wraps a given discriminator network.
"""
def __init__(self, discriminator: Union[nn.Module, nn.DataParallel], r_target: float = 0.6,
p_step: float = 5e-03, r_update: int = 8, p_max: float = 0.8) -> None:
"""
Constructor method
:param discriminator: (Union[nn.Module, nn.DataParallel]) Discriminator network
:param r_target: (float) Target value for r
:param p_step: (float) Step size of p
:param r_update: (int) Update frequency of r
:param p_max: (float) Global max value of p
"""
# Call super constructor
super(AdaptiveDiscriminatorAugmentation, self).__init__()
# Save parameters
self.discriminator = discriminator
self.r_target = r_target
self.p_step = p_step
self.r_update = r_update
self.p_max = p_max
# Init augmentation variables
self.r = []
self.p = 0.05
self.r_history = []
# Init augmentation pipeline
self.augmentation_pipeline = AugmentationPipeline()
@torch.no_grad()
def __calc_r(self, prediction_scalar: torch.Tensor, prediction_pixel_wise: torch.Tensor) -> float:
"""
Method computes the overfitting heuristic r.
:param prediction_scalar: (torch.Tensor) Scalar prediction [batch size, 1]
:param prediction_pixel_wise: (torch.Tensor) Pixel-wise prediction [batch size, 1, height, width]
:return: (float) Value of the overfitting heuristic r
"""
return (0.5 * torch.mean(torch.sign(prediction_scalar))
+ 0.5 * torch.mean(torch.sign(prediction_pixel_wise.mean(dim=(-1, -2))))).item()
def forward(self, images: torch.Tensor, is_real: bool = False,
is_cut_mix: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass
:param images: (torch.Tensor) Mini batch of images (real or fake) [batch size, channels, time steps, height, width]
:param is_real: (bool) If true real images are utilized as the input
:param is_cut_mix: (bool) If true cut mix is utilized and no augmentation is performed
:return: (Tuple[torch.Tensor, torch.Tensor]) Scalar and pixel-wise real/fake prediction of the discriminator
"""
# Case if cut mix is utilized
if is_cut_mix:
return self.discriminator(images)
# Reshape images to [batch size, channels * time steps, height, width]
original_shape = images.shape
images = images.flatten(start_dim=1, end_dim=2)
# Apply augmentations
images: torch.Tensor = self.augmentation_pipeline(images, self.p)
# Reshape images again to original shape
images = images.view(original_shape)
# Discriminator prediction
prediction_scalar, prediction_pixel_wise = self.discriminator(images)
# If fake images are given compute overfitting heuristic
if not is_real:
self.r.append(self.__calc_r(prediction_scalar=prediction_scalar.detach(),
prediction_pixel_wise=prediction_pixel_wise.detach()))
# Update p
if len(self.r) >= self.r_update:
# Calc mean r over the last r_update steps
r = np.mean(self.r)
# If r above target value increment p else reduce
if r > self.r_target:
self.p += self.p_step
else:
self.p -= self.p_step
# Check if p is negative
self.p = self.p if self.p >= 0. else 0.
# Check if p is larger than p_max
self.p = self.p if self.p < self.p_max else self.p_max
# Reset r
self.r = []
# Save current r in history
self.r_history.append(r)
return prediction_scalar, prediction_pixel_wise
class AugmentationPipeline(nn.Module):
"""
This class implements the differentiable augmentation pipeline for ADA.
"""
def __init__(self) -> None:
# Call super constructor
super(AugmentationPipeline, self).__init__()
def forward(self, images: torch.Tensor, p: float) -> torch.Tensor:
"""
Forward pass applies augmentation to mini-batch of given images
:param images: (torch.Tensor) Mini-batch images [batch size, channels, height, width]
:param p: (float) Probability of augmentation to be applied
:return: (torch.Tensor) Augmented images [batch size, channels, height, width]
"""
# Perform vertical flip
images_flipped = [index for index, value in enumerate(torch.rand(images.shape[0]) <= p) if value == True]
if len(images_flipped) > 0:
images[images_flipped] = images[images_flipped].flip(dims=(-1,))
# Perform rotation
images_rotated = [index for index, value in enumerate(torch.rand(images.shape[0]) <= p) if value == True]
if len(images_rotated) > 0:
angle = random.choice([torch.tensor(0.), torch.tensor(-90.), torch.tensor(90.), torch.tensor(180.)])
angle = angle.to(images.device)
images[images_rotated] = kaf.rotate(images[images_rotated],
angle=angle)
# Perform integer translation
images_translated = [index for index, value in enumerate(torch.rand(images.shape[0]) <= p) if value == True]
if len(images_translated) > 0:
images[images_translated] = integer_translation(images[images_translated])
# Perform isotropic scaling
images_scaling = [index for index, value in enumerate(torch.rand(images.shape[0]) <= p) if value == True]
if len(images_scaling) > 0:
images[images_scaling] = kaf.apply_affine(
images[images_scaling],
params={"angle": torch.zeros(len(images_scaling), device=images.device),
"translations": torch.zeros(len(images_scaling), 2, device=images.device),
"center": torch.ones(len(images_scaling), 2, device=images.device)
* 0.5 * torch.tensor(images.shape[2:], device=images.device),
"scale": torch.ones(len(images_scaling), 2, device=images.device) *
torch.from_numpy(
np.random.lognormal(mean=0, sigma=(0.2 * math.log(2)) ** 2,
size=(len(images_scaling), 1))).float().to(images.device),
"sx": torch.zeros(len(images_scaling), device=images.device),
"sy": torch.zeros(len(images_scaling), device=images.device)},
flags={"resample": torch.tensor(1, device=images.device),
"padding_mode": torch.tensor(2, device=images.device),
"align_corners": torch.tensor(True, device=images.device)})
# Perform rotation
images_rotated = [index for index, value in enumerate(torch.rand(images.shape[0]) <= (1 - math.sqrt(1 - p)))
if value == True]
if len(images_rotated) > 0:
images[images_rotated] = kaf.apply_affine(
images[images_rotated],
params={"angle": torch.from_numpy(
np.random.uniform(low=-180, high=180, size=len(images_rotated))).to(images.device),
"translations": torch.zeros(len(images_rotated), 2, device=images.device),
"center": torch.ones(len(images_rotated), 2, device=images.device)
* 0.5 * torch.tensor(images.shape[2:], device=images.device),
"scale": torch.ones(len(images_rotated), 2, device=images.device),
"sx": torch.zeros(len(images_rotated), device=images.device),
"sy": torch.zeros(len(images_rotated), device=images.device)},
flags={"resample": torch.tensor(1, device=images.device),
"padding_mode": torch.tensor(2, device=images.device),
"align_corners": torch.tensor(True, device=images.device)})
# Perform anisotropic scaling
images_scaling = [index for index, value in enumerate(torch.rand(images.shape[0]) <= p) if value == True]
if len(images_scaling) > 0:
images[images_scaling] = kaf.apply_affine(
images[images_scaling],
params={"angle": torch.zeros(len(images_scaling), device=images.device),
"translations": torch.zeros(len(images_scaling), 2, device=images.device),
"center": torch.ones(len(images_scaling), 2, device=images.device)
* 0.5 * torch.tensor(images.shape[2:], device=images.device),
"scale": torch.ones(len(images_scaling), 2, device=images.device) *
torch.from_numpy(
np.random.lognormal(mean=0, sigma=(0.2 * math.log(2)) ** 2,
size=(len(images_scaling), 2))).float().to(images.device),
"sx": torch.zeros(len(images_scaling), device=images.device),
"sy": torch.zeros(len(images_scaling), device=images.device)},
flags={"resample": torch.tensor(1, device=images.device),
"padding_mode": torch.tensor(2, device=images.device),
"align_corners": torch.tensor(True, device=images.device)})
# Perform rotation
images_rotated = [index for index, value in enumerate(torch.rand(images.shape[0]) <= (1 - math.sqrt(1 - p)))
if value == True]
if len(images_rotated) > 0:
images[images_rotated] = kaf.apply_affine(
images[images_rotated],
params={"angle": torch.from_numpy(
np.random.uniform(low=-180, high=180, size=len(images_rotated))).to(images.device),
"translations": torch.zeros(len(images_rotated), 2, device=images.device),
"center": torch.ones(len(images_rotated), 2, device=images.device)
* 0.5 * torch.tensor(images.shape[2:], device=images.device),
"scale": torch.ones(len(images_rotated), 2, device=images.device),
"sx": torch.zeros(len(images_rotated), device=images.device),
"sy": torch.zeros(len(images_rotated), device=images.device)},
flags={"resample": torch.tensor(1, device=images.device),
"padding_mode": torch.tensor(2, device=images.device),
"align_corners": torch.tensor(True, device=images.device)})
return images
def integer_translation(images: torch.Tensor) -> torch.Tensor:
"""
Function implements integer translation augmentation
:param images: (torch.Tensor) Input images
:return: (torch.Tensor) Augmented images
"""
# Get translation index
translation_index = (int(images.shape[-2] * random.uniform(-0.125, 0.125)),
int(images.shape[-1] * random.uniform(-0.125, 0.125)))
# Apply translation
return torch.roll(images, shifts=translation_index, dims=(-2, -1))
from typing import Dict, Any
import math
# U-Net 2D discriminator hyperparameters for resolution
u_net_2d_discriminator_config: Dict[str, Any] = {
# Set encoder channels
"encoder_channels": ((3, 128), (128, 256), (256, 384), (384, 768), (768, 1024)),
# Set decoder channels
"decoder_channels": ((1024, 768), (768, 384), (384, 256), (256, 128)),
# Utilize fft input
"fft": False,
}
# StyleGAN 2 2D generator hyperparameters for resolution
multi_style_gan_generator_config: Dict[str, Any] = {
# Channels utilized in each resolution stage
"channels": (512, 512, 512, 512, 512, 512, 512),
# Channel factor
"channel_factor": 1,
# Number of latent dimensions
"latent_dimensions": 512,
# Depth of the style mapping network
"depth_style_mapping": 8,
# Starting resolution
"starting_resolution": (4, 4)
}
# Additional hyperparameters
generation_hyperparameters: Dict[str, Any] = {
# Probability of mixed noise input
"p_mixed_noise": 0.9,
# Lazy generator regularization factor
"lazy_generator_regularization": 16,
# Weights factor for generator regularization
"w_generator_regularization": math.log(2) / ((256 ** 2) * (math.log(256) - math.log(2))),
# Lazy discriminator regularization factor
"lazy_discriminator_regularization": 16,
# Weights factor for discriminator R1 regularization
"w_discriminator_regularization_r1": 10.0,
# Weights factor for discriminator regularization
"w_discriminator_regularization": 4.0,
# Fraction of batch size of wrongly ordered samples
"batch_factor_wrong_order": 1. / 4.,
# Batch size for path length regularization
"batch_size_shrink_path_length_regularization": 2. / 4.,
# Beta for optimizers
"betas": (0.0, 0.999),
# Factor of total training steps when top-k should be started
"top_k_start": 1. / 4.,
# Factor of total training steps when top-k should be finished with v=0.5
"top_k_finish": 3. / 4.,
# Factor of total training steps when wrong time order is utilized
"wrong_order_start": 3. / 4.,
# Factor of total training epochs when to apply the trap region weights map
"trap_weight": 1. / 4.
}
from typing import Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class EqualizedConv2d(nn.Module):
"""
Implementation of a 2d equalized Convolution
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 1,
bias: bool = True):
"""
Constructor method
:param in_channels: (int) Number of input channels
:param out_channels: (int) Number of output channels
:param kernel_size: (Union[int, Tuple[int, int]]) Kernel size
:param stride: (Union[int, Tuple[int, int]]) Stride factor used in the convolution
:param padding: (Union[int, Tuple[int, int]]) Padding factor used in the convolution
:param bias: (bool) Use bias
"""
# Init super constructor
super(EqualizedConv2d, self).__init__()
# Save parameters
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.stride = (stride, stride) if isinstance(stride, int) else stride
self.padding = (padding, padding) if isinstance(padding, int) else padding
# Init weights tensor for convolution
self.weight = nn.Parameter(
nn.init.normal_(torch.empty(out_channels, in_channels, *self.kernel_size, dtype=torch.float)),
requires_grad=True)
# Init bias weight if needed
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels, dtype=torch.float), requires_grad=True)
else:
self.bias = None
# Init scale factor
self.scale = torch.tensor(
np.sqrt(2) / np.sqrt(in_channels * (self.kernel_size[0] * self.kernel_size[1]))).float()
self.scale_bias = torch.tensor(np.sqrt(2) / np.sqrt(out_channels)).float()
def __repr__(self) -> str:
"""
Method returns information about the module
:return: (str) Info string
"""
return ('{}({}, {}, kernel_size=({}, {}), stride=({}, {}), padding=({}, {}), bias={})'.format(
self.__class__.__name__,
self.weight.shape[1],
self.weight.shape[0],
self.weight.shape[2],
self.weight.shape[3],
self.stride[0],
self.stride[1],
self.padding[0],
self.padding[1],
self.bias is not None))
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass
:param input: (Torch Tensor) Input tensor 4D
:return: (Torch tensor) Output tensor 4D
"""
if self.bias is None:
output = F.conv2d(input=input, weight=self.weight * self.scale, stride=self.stride, padding=self.padding)
else:
output = F.conv2d(input=input, weight=self.weight * self.scale, bias=self.bias * self.scale_bias,
stride=self.stride, padding=self.padding)
return output
class EqualizedTransposedConv2d(nn.Module):
"""
Implementation of a 2d equalized transposed Convolution
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 2,
stride: Union[int, Tuple[int, int]] = 2, padding: Union[int, Tuple[int, int]] = 0,
bias: bool = True) -> None:
"""
Constructor method
:param in_channels: (int) Number of input channels
:param out_channels: (int) Number of output channels
:param kernel_size: (Union[int, Tuple[int, int]]) Kernel size
:param stride: (Union[int, Tuple[int, int]]) Stride factor used in the convolution
:param padding: (Union[int, Tuple[int, int]]) Padding factor used in the convolution
:param bias: (bool) Use bias
"""
# Init super constructor
super(EqualizedTransposedConv2d, self).__init__()
# Save parameters
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.stride = (stride, stride) if isinstance(stride, int) else stride
self.padding = (padding, padding) if isinstance(padding, int) else padding
# Init weights tensor for convolution
self.weight = nn.Parameter(
nn.init.normal_(torch.empty(in_channels, out_channels, *self.kernel_size, dtype=torch.float)),
requires_grad=True)
# Init bias weight if needed
if bias:
self.bias = nn.Parameter(torch.ones(out_channels, dtype=torch.float), requires_grad=True)
else:
self.bias = None
# Init scale factor
self.scale = torch.tensor(
np.sqrt(2) / np.sqrt(in_channels * (self.kernel_size[0] * self.kernel_size[1]))).float()
self.scale_bias = torch.tensor(np.sqrt(2) / np.sqrt(out_channels)).float()
def __repr__(self) -> str:
"""
Method returns information about the module
:return: (str) Info string
"""
return ('{}({}, {}, kernel_size=({}, {}), stride=({}, {}), padding=({}, {}), bias={})'.format(
self.__class__.__name__,
self.weight.shape[1],
self.weight.shape[0],
self.weight.shape[2],
self.weight.shape[3],
self.stride[0],
self.stride[1],
self.padding[0],
self.padding[1],
self.bias is not None))
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass
:param input: (Torch Tensor) Input tensor 4D
:return: (Torch tensor) Output tensor 4D
"""
if self.bias is None:
output = F.conv_transpose2d(input=input, weight=self.weight * self.scale, stride=self.stride,
padding=self.padding)
else:
output = F.conv_transpose2d(input=input, weight=self.weight * self.scale, bias=self.bias * self.scale_bias,
stride=self.stride, padding=self.padding)
return output
class EqualizedConv1d(nn.Module):
"""
This class implements an equalized 1d convolution
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
stride: int = 1, padding: int = 1, bias: bool = True) -> None:
"""
Constructor method
:param in_channels: (int) Number of input channels
:param out_channels: (int) Number of output channels
:param kernel_size: (int) Kernel size
:param stride: (int) Stride factor used in the convolution
:param padding: (int) Padding factor used in the convolution
:param bias: (bool) Use bias
"""
# Call super constructor
super(EqualizedConv1d, self).__init__()
# Save parameters
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
# Init weight parameter
self.weight = nn.Parameter(
nn.init.normal_(torch.empty(out_channels, in_channels, kernel_size, dtype=torch.float)), requires_grad=True)
# Init bias if utilized
if bias:
self.bias = nn.Parameter(torch.ones(out_channels, dtype=torch.float), requires_grad=True)
else:
self.bias = None
# Init scale factor
self.scale = torch.tensor(
np.sqrt(2) / np.sqrt(in_channels * (self.kernel_size))).float()
self.scale_bias = torch.tensor(np.sqrt(2) / np.sqrt(out_channels)).float()
def __repr__(self) -> str:
"""
Method returns information about the module
:return: (str) Info string
"""
return ('{}({}, {}, kernel_size={}, stride={}, padding={}, bias={})'.format(
self.__class__.__name__,
self.weight.shape[1],
self.weight.shape[0],
self.weight.shape[2],
self.stride,
self.padding,
self.bias is not None))
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass
:param input: (torch.Tensor) Input tensor 3D
:return: (torch.Tensor) Output tensor 3D
"""
if self.bias is None:
output = F.conv1d(input=input, weight=self.weight * self.scale, stride=self.stride,
padding=self.padding)
else:
output = F.conv1d(input=input, weight=self.weight * self.scale, bias=self.bias * self.scale_bias,
stride=self.stride, padding=self.padding)
return output
class EqualizedLinear(nn.Module):
"""
Implementation of an equalized linear layer
"""
def __init__(self, in_channels: int, out_channels: int, bias: bool = True) -> None:
"""
Constructor method
:param in_channels: (int) Number of input channels
:param out_channels: (int) Number of output channels
:param bias: (bool) True if bias should be used
"""
# Init super constructor
super(EqualizedLinear, self).__init__()
# Init weights tensor for convolution
self.weight = nn.Parameter(
nn.init.normal_(torch.empty(out_channels, in_channels, dtype=torch.float)), requires_grad=True)
# Init bias weight if needed
if bias:
self.bias = nn.Parameter(torch.FloatTensor(out_channels).fill_(0), requires_grad=True)
else:
self.bias = None
# Init scale factor
self.scale = np.sqrt(2) / np.sqrt(in_channels)
self.scale_bias = np.sqrt(2) / np.sqrt(out_channels)
def __repr__(self) -> str:
"""
Method returns information about the module
:return: (str) Info string
"""
return ('{}({}, {}, bias={})'.format(self.__class__.__name__, self.weight.shape[1], self.weight.shape[0],
self.bias is not None))
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass
:param input: (torch.Tensor) Input tensor 2D or 3D
:return: (torch.Tensor) Output tensor 2D or 3D
"""
if self.bias is None:
output = F.linear(input=input, weight=self.weight * self.scale)
else:
output = F.linear(input=input, weight=self.weight * self.scale, bias=self.bias * self.scale_bias)
return output
class PixelwiseNormalization(nn.Module):
"""
Pixelwise Normalization module
"""
def __init__(self, alpha: float = 1e-8) -> None:
"""
Constructor method
:param alpha: (float) Small constants for numeric stability
"""
super(PixelwiseNormalization, self).__init__()
self.alpha = alpha
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass
:param input: (Torch Tensor) Input tensor
:return: (Torch Tensor) Normalized output tensor with same shape as input
"""
output = input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + self.alpha)
return output
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
class WassersteinDiscriminatorLoss(nn.Module):
"""
This class implements the Wasserstein loss for a discriminator network.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(WassersteinDiscriminatorLoss, self).__init__()
def forward(self, prediction_real: torch.Tensor,
prediction_fake: torch.Tensor, weight: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the loss module
:param prediction_real: (torch.Tensor) Prediction for real samples
:param prediction_fake: (torch.Tensor) Prediction for fake samples
:param weight: (torch.Tensor) Weights map to be applied
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
# Compute loss
if weight is not None:
loss_real = - torch.mean(
prediction_real * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(prediction_real.device))
loss_fake = torch.mean(
prediction_fake * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(prediction_fake.device))
return loss_real, loss_fake
else:
loss_real = - torch.mean(prediction_real)
loss_fake = torch.mean(prediction_fake)
return loss_real, loss_fake
class WassersteinDiscriminatorLossCutMix(nn.Module):
"""
This class implements the Wasserstein loss for a discriminator network when utilizing cut mix augmentation.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(WassersteinDiscriminatorLossCutMix, self).__init__()
def forward(self, prediction: torch.Tensor,
label: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass. Loss parts are not summed up to not retain the whole backward graph later.
:param prediction: (torch.Tensor) Pixel-wise prediction of the discriminator
:param label: (torch.Tensor) Cut-mix label map (1 for real regions, 0 for fake regions)
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
# Compute loss
loss_real = - torch.mean(prediction * label)
loss_fake = torch.mean(prediction * (-label + 1.))
return loss_real, loss_fake
class WassersteinGeneratorLoss(nn.Module):
"""
This class implements the Wasserstein loss for a generator network.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(WassersteinGeneratorLoss, self).__init__()
def forward(self, prediction_fake: torch.Tensor, weight: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass of the loss module
:param prediction_fake: (torch.Tensor) Prediction for fake samples
:param weight: (torch.Tensor) Weights map to be applied
:return: (torch.Tensor) Scalar loss value
"""
# Compute loss
if weight is not None:
loss = - torch.mean(
prediction_fake * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(prediction_fake.device))
return loss
else:
loss = - torch.mean(prediction_fake)
return loss
class NonSaturatingLogisticGeneratorLoss(nn.Module):
"""
Implementation of the non saturating GAN loss for the generator network.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(NonSaturatingLogisticGeneratorLoss, self).__init__()
def __repr__(self):
"""
Get representation of the loss module
:return: (str) String including information
"""
return '{}'.format(self.__class__.__name__)
def forward(self, prediction_fake: torch.Tensor, weight: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass to compute the generator loss
:param prediction_fake: (torch.Tensor) Prediction of the discriminator for fake samples
:param weight: (torch.Tensor) Weights map to be applied
:return: (torch.Tensor) Loss value
"""
# Calc loss
if weight is not None:
loss = torch.mean(
F.softplus(-prediction_fake) * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(
prediction_fake.device))
return loss
else:
loss = torch.mean(F.softplus(-prediction_fake))
return loss
class NonSaturatingLogisticDiscriminatorLoss(nn.Module):
"""
Implementation of the non saturating GAN loss for the discriminator network.
"""
def __init__(self) -> None:
"""
Constructor
"""
# Call super constructor
super(NonSaturatingLogisticDiscriminatorLoss, self).__init__()
def forward(self, prediction_real: torch.Tensor,
prediction_fake: torch.Tensor, weight: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass. Loss parts are not summed up to not retain the whole backward graph later.
:param prediction_real: (torch.Tensor) Prediction of the discriminator for real images
:param prediction_fake: (torch.Tensor) Prediction of the discriminator for fake images
:param weight: (torch.Tensor) Weights map to be applied
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
if weight is not None:
# Calc real loss part
loss_real = torch.mean(
F.softplus(-prediction_real) * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(
prediction_real.device))
# Calc fake loss part
loss_fake = torch.mean(
F.softplus(prediction_fake) * weight.view(1, 1, 1, weight.shape[-2], weight.shape[-1]).to(
prediction_fake.device))
return loss_real, loss_fake
else:
# Calc real loss part
loss_real = torch.mean(F.softplus(-prediction_real))
# Calc fake loss part
loss_fake = torch.mean(F.softplus(prediction_fake))
return loss_real, loss_fake
class NonSaturatingLogisticDiscriminatorLossCutMix(nn.Module):
"""
Implementation of the non saturating GAN loss for the discriminator network when performing cut mix augmentation.
"""
def __init__(self) -> None:
"""
Constructor
"""
# Call super constructor
super(NonSaturatingLogisticDiscriminatorLossCutMix, self).__init__()
def forward(self, prediction: torch.Tensor, label: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass. Loss parts are not summed up to not retain the whole backward graph later.
:param prediction: (torch.Tensor) Pixel-wise prediction of the discriminator
:param label: (torch.Tensor) Cut-mix label map (1 for real regions, 0 for fake regions)
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
# Calc real loss part
loss_real = torch.mean(F.softplus(-prediction) * label)
# Calc fake loss part
loss_fake = torch.mean(F.softplus(prediction) * (-label + 1.))
return loss_real, loss_fake
class HingeGeneratorLoss(WassersteinGeneratorLoss):
"""
This class implements the hinge gan loss for the generator network. Note that the generator hinge loss is equivalent
to the generator Wasserstein loss!
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(HingeGeneratorLoss, self).__init__()
class HingeDiscriminatorLoss(nn.Module):
"""
This class implements the hinge gan loss for the discriminator network.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(HingeDiscriminatorLoss, self).__init__()
def forward(self, prediction_real: torch.Tensor, prediction_fake: torch.Tensor,
weight: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass. Loss parts are not summed up to not retain the whole backward graph later.
:param prediction_real: (torch.Tensor) Prediction of the discriminator for real images
:param prediction_fake: (torch.Tensor) Prediction of the discriminator for fake images
:param weight: (torch.Tensor) Weights map to be applied
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
if weight is not None:
# Calc loss for real prediction
loss_real = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction_real.device),
prediction_real - 1.) * weight.view(1, 1, 1, weight.shape[-2],
weight.shape[-1]).to(
prediction_real.device))
# Calc loss for fake prediction
loss_fake = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction_real.device),
- prediction_fake - 1.) * weight.view(1, 1, 1, weight.shape[-2],
weight.shape[-1]).to(
prediction_fake.device))
return loss_real, loss_fake
else:
# Calc loss for real prediction
loss_real = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction_real.device),
prediction_real - 1.))
# Calc loss for fake prediction
loss_fake = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction_real.device),
- prediction_fake - 1.))
return loss_real, loss_fake
class HingeDiscriminatorLossCutMix(nn.Module):
"""
This class implements the hinge gan loss for the discriminator network when utilizing cut mix augmentation.
"""
def __init__(self) -> None:
"""
Constructor method
"""
# Call super constructor
super(HingeDiscriminatorLossCutMix, self).__init__()
def forward(self, prediction: torch.Tensor,
label: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass. Loss parts are not summed up to not retain the whole backward graph later.
:param prediction: (torch.Tensor) Pixel-wise prediction of the discriminator
:param label: (torch.Tensor) Cut-mix label map (1 for real regions, 0 for fake regions)
:return: (Tuple[torch.Tensor, torch.Tensor]) Loss values for real and fake part
"""
# Calc loss for real prediction
loss_real = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction.device),
prediction - 1.) * label)
# Calc loss for fake prediction
loss_fake = - torch.mean(torch.minimum(torch.tensor(0., dtype=torch.float, device=prediction.device),
- prediction - 1.) * (- label + 1.))
return loss_real, loss_fake
class R1Regularization(nn.Module):
"""
Implementation of the R1 GAN regularization.
"""
def __init__(self):
"""
Constructor method
"""
# Call super constructor
super(R1Regularization, self).__init__()
def __repr__(self):
"""
Get representation of the loss module
:return: (str) String including information
"""
return '{}'.format(self.__class__.__name__)
def forward(self, prediction_real: torch.Tensor, image_real: torch.Tensor,
prediction_real_pixel_wise: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Forward pass to compute the regularization
:param prediction_real: (torch.Tensor) Prediction of the discriminator for a batch of real images
:param image_real: (torch.Tensor) Batch of the corresponding real images
:param prediction_real_pixel_wise: (Optional[torch.Tensor]) Optional pixel-wise discriminator prediction for the same batch
:return: (torch.Tensor) Loss value
"""
# Calc gradient
grad_real, = autograd.grad(
outputs=(prediction_real.sum(), prediction_real_pixel_wise.sum())
if prediction_real_pixel_wise is not None else prediction_real.sum(),
inputs=image_real, create_graph=True)
# Calc regularization
regularization_loss = 0.5 * grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return regularization_loss
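# --- Illustrative usage sketch (added; not part of the original training code) ---
# R1 penalizes the gradient of the discriminator output w.r.t. the real images, so the real
# batch has to require gradients. The tiny convolution below is only a stand-in discriminator.
def _example_r1_regularization() -> None:
    import torch
    r1_regularization = R1Regularization()
    image_real = torch.randn(2, 3, 16, 16, requires_grad=True)
    stand_in_discriminator = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=16)
    prediction_real = stand_in_discriminator(image_real)
    penalty = r1_regularization(prediction_real, image_real)
    penalty.backward()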
class R2Regularization(nn.Module):
"""
Implementation of the R2 GAN regularization.
"""
def __init__(self):
"""
Constructor method
"""
# Call super constructor
super(R2Regularization, self).__init__()
def __repr__(self):
"""
Get representation of the loss module
:return: (str) String including information
"""
return '{}'.format(self.__class__.__name__)
def forward(self, prediction_fake: torch.Tensor, image_fake: torch.Tensor) -> torch.Tensor:
"""
Forward pass to compute the regularization
:param prediction_fake: (torch.Tensor) Prediction of the discriminator for a batch of fake images
:param image_fake: (torch.Tensor) Batch of the corresponding fake images
:return: (torch.Tensor) Loss value
"""
# Calc gradient (autograd.grad returns a tuple, so the single gradient tensor is unpacked)
grad_fake, = autograd.grad(outputs=prediction_fake.sum(), inputs=image_fake, create_graph=True)
# Calc regularization
regularization_loss = 0.5 * grad_fake.pow(2).view(grad_fake.shape[0], -1).sum(1).mean()
return regularization_loss
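# --- Illustrative usage sketch (added; not part of the original training code) ---
# R2 mirrors R1 but penalizes the gradient w.r.t. fake images, which therefore have to
# require gradients (e.g. fresh generator output). The stand-in discriminator is illustrative.
def _example_r2_regularization() -> None:
    import torch
    r2_regularization = R2Regularization()
    image_fake = torch.randn(2, 3, 16, 16, requires_grad=True)
    stand_in_discriminator = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=16)
    prediction_fake = stand_in_discriminator(image_fake)
    penalty = r2_regularization(prediction_fake, image_fake)
    penalty.backward()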
class PathLengthRegularization(nn.Module):
"""
Module implements the StyleGAN path length regularization.
"""
def __init__(self, decay: float = 0.01) -> None:
"""
Constructor method
:param decay: (float) Decay of the current mean path length
"""
# Call super constructor
super(PathLengthRegularization, self).__init__()
# Save parameter
self.decay = decay
# Init mean path length
self.mean_path_length = torch.zeros(1, dtype=torch.float, requires_grad=False)
def __repr__(self):
"""
Get representation of the loss module
:return: (str) String including information
"""
return '{}'.format(self.__class__.__name__)
def forward(self, grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass
:param grad: (torch.Tensor) Path length gradients
:return: (Tuple[torch.Tensor, torch.Tensor]) Path length penalty and path lengths
"""
# Detach mean path length
self.mean_path_length.detach_()
# Get new path lengths by reducing over the style and latent dimensions
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1) + 1e-08).mean()
# Mean path length to device
self.mean_path_length = self.mean_path_length.to(grad.device)
# Calc path length mean
self.mean_path_length = self.mean_path_length + self.decay * (path_lengths.mean() - self.mean_path_length)
# Get path length penalty
path_length_penalty = torch.mean((path_lengths - self.mean_path_length) ** 2)
return path_length_penalty, path_lengths
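# --- Illustrative usage sketch (added; not part of the original training code) ---
# The module expects the gradients of a noise-weighted generator output w.r.t. the style
# codes, shaped roughly [batch size, number of styles, latent dimension]. The stand-in
# "generator" below is only there to produce such a gradient; all shapes are illustrative.
def _example_path_length_regularization() -> None:
    import torch
    path_length_regularization = PathLengthRegularization()
    latents = torch.randn(2, 8, 64, requires_grad=True)  # [batch size, styles, latent dimension]
    fake_images = latents.sum(dim=1)  # stand-in generator output of shape [batch size, 64]
    noise = torch.randn_like(fake_images) / (fake_images.shape[-1] ** 0.5)
    grad, = torch.autograd.grad(outputs=(fake_images * noise).sum(), inputs=latents,
                                create_graph=True)
    penalty, path_lengths = path_length_regularization(grad)
    penalty.backward()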
class TopK(nn.Module):
"""
This class implements the top-k method proposed in:
https://arxiv.org/pdf/2002.06224.pdf
"""
def __init__(self, starting_iteration: int, final_iteration: int) -> None:
"""
Constructor method
:param starting_iteration: (int) Iteration at which top-k training starts
:param final_iteration: (int) Iteration at which the top-k fraction stops decreasing
"""
# Call super constructor
super(TopK, self).__init__()
# Save parameters
self.starting_iteration = starting_iteration
self.final_iteration = final_iteration
self.iterations = 0
def calc_v(self) -> float:
"""
Method tracks the iterations and estimates v.
:return: (float) v factor
"""
# Update iterations
self.iterations += 1
if self.iterations <= self.starting_iteration:
return 1.
elif self.iterations >= self.final_iteration:
return 0.5
else:
return 0.5 * (1. - float(self.iterations - self.starting_iteration)
/ float(self.final_iteration - self.starting_iteration)) + 0.5
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass
:param input: (torch.Tensor) Input tensor
:return: (Tuple[torch.Tensor, torch.Tensor]) Top-k values and their corresponding indices
"""
# Calc v
v = self.calc_v()
# Flatten input
input = input.view(-1)
# Apply top k
output = torch.topk(input, k=max(1, int(input.shape[0] * v)))
return output
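# --- Illustrative usage sketch (added; not part of the original training code) ---
# Top-k training keeps only the k highest discriminator scores of the fake batch; the
# iteration bounds and batch size below are illustrative values.
def _example_top_k() -> None:
    import torch
    top_k = TopK(starting_iteration=10, final_iteration=20)
    prediction_fake = torch.randn(8)  # discriminator scores for a batch of fake samples
    values, indices = top_k(prediction_fake)
    # Before starting_iteration all scores are kept; afterwards the kept fraction decays to 0.5
    assert values.shape[0] == 8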
from typing import Any, Dict, Union, Iterable, Optional, List
import torch
import torchvision
import torch.nn as nn
import os
import json
from datetime import datetime
import random
import numpy as np
class Logger(object):
"""
Class to log different metrics.
"""
def __init__(self,
experiment_path: str =
os.path.join(os.getcwd(), "experiments", datetime.now().strftime("%d_%m_%Y__%H_%M_%S")),
experiment_path_extension: str = "",
path_metrics: str = "metrics",
path_hyperparameters: str = "hyperparameters",
path_plots: str = "plots",
path_models: str = "models") -> None:
"""
Constructor method
:param experiment_path: (str) Path to the experiment folder
:param experiment_path_extension: (str) Extension appended to the experiment folder
:param path_metrics: (str) Path to folder in which all metrics are stored
:param path_hyperparameters: (str) Path to folder in which all hyperparameters are stored
:param path_plots: (str) Path to folder in which all plots are stored
:param path_models: (str) Path to folder in which all models are stored
"""
experiment_path = experiment_path + experiment_path_extension
# Save parameters
self.path_metrics = os.path.join(experiment_path, path_metrics)
self.path_hyperparameters = os.path.join(experiment_path, path_hyperparameters)
self.path_plots = os.path.join(experiment_path, path_plots)
self.path_models = os.path.join(experiment_path, path_models)
# Init folders
os.makedirs(self.path_metrics, exist_ok=True)
os.makedirs(self.path_hyperparameters, exist_ok=True)
os.makedirs(self.path_plots, exist_ok=True)
os.makedirs(self.path_models, exist_ok=True)
# Init dicts to store the metrics and hyperparameters
self.metrics = dict()
self.temp_metrics = dict()
self.hyperparameters = dict()
def log_metric(self, metric_name: str, value: Any) -> None:
"""
Method writes a given metric value into a dict including list for every metric.
:param metric_name: (str) Name of the metric
:param value: (float) Value of the metric
"""
if metric_name in self.metrics:
self.metrics[metric_name].append(float(value))
else:
self.metrics[metric_name] = [float(value)]
def log_temp_metric(self, metric_name: str, value: Any) -> None:
"""
Method writes a given metric value into a dict including temporal metrics.
:param metric_name: (str) Name of the metric
:param value: (float) Value of the metric
"""
if metric_name in self.temp_metrics:
self.temp_metrics[metric_name].append(float(value))
else:
self.temp_metrics[metric_name] = [float(value)]
def save_temp_metric(self, metric_name: Union[Iterable[str], str]) -> Dict[str, float]:
"""
Method writes temporal metrics into the metrics dict by averaging.
:param metric_name: (Union[Iterable[str], str]) One temporal metric name or a list of names
:return: (Dict[str, float]) Dict of the averaged temporal metrics
"""
averaged_temp_dict = dict()
# Case if only one metric is given
if isinstance(metric_name, str):
# Calc average
value = float(torch.tensor(self.temp_metrics[metric_name]).mean())
# Save metric in log dict
self.log_metric(metric_name=metric_name, value=value)
# Put metric also in dict to be returned
averaged_temp_dict[metric_name] = value
# Case if multiple metrics are given
else:
for name in metric_name:
# Calc average
value = float(torch.tensor(self.temp_metrics[name]).mean())
# Save metric in log dict
self.log_metric(metric_name=name, value=value)
# Put metric also in dict to be returned
averaged_temp_dict[name] = value
# Reset temp metrics
self.temp_metrics = dict()
# Save logs
self.save()
return averaged_temp_dict
def log_hyperparameter(self, hyperparameter_name: str = None, value: Any = None,
hyperparameter_dict: Dict[str, Any] = None) -> None:
"""
Method writes a given hyperparameter into a dict including all other hyperparameters.
:param hyperparameter_name: (str) Name of the hyperparameter
:param value: (Any) Value of the hyperparameter, must be convertible to str
:param hyperparameter_dict: (Dict[str, Any]) Dict of multiple hyperparameter to be saved
"""
# Case if name and value are given
if (hyperparameter_name is not None) and (value is not None):
if hyperparameter_name in self.hyperparameters:
self.hyperparameters[hyperparameter_name].append(str(value))
else:
self.hyperparameters[hyperparameter_name] = [str(value)]
# Case if dict of hyperparameters is given
if hyperparameter_dict is not None:
# Iterate over given dict, cast data and store in internal hyperparameters dict
for key in hyperparameter_dict.keys():
if key in self.hyperparameters.keys():
self.hyperparameters[key].append(str(hyperparameter_dict[key]))
else:
self.hyperparameters[key] = [str(hyperparameter_dict[key])]
def save_checkpoint(self, file_name: str, checkpoint_dict: Dict) -> None:
"""
This method saves a given checkpoint.
:param file_name: (str) File name including the file extension
:param checkpoint_dict: (Dict) Dict including all modules to be saved
"""
torch.save(checkpoint_dict, os.path.join(self.path_models, file_name))
def save_prediction(self, prediction: torch.Tensor, name: str) -> None:
"""
This method saves the image predictions as PNG images.
:param prediction: (torch.Tensor) Prediction of the shape [batch size, channels, time steps, height, width]
:param name: (str) Name of the images without file ending
"""
for batch_index in range(prediction.shape[0]):
# Get images and normalize shape [time steps, 1, height, width]
bf_images = prediction[batch_index, 0][:, None]
# Make bf to rgb
bf_images = bf_images.repeat_interleave(3, dim=1)
if prediction.shape[1] > 1:
gfp_images = prediction[batch_index, 1][:, None]
# Make gfp to rgb only green shades
gfp_images = gfp_images.repeat_interleave(3, dim=1)
gfp_images[:, 0] = 0.0
gfp_images[:, 2] = 0.0
if prediction.shape[1] > 2:
rfp_images = prediction[batch_index, 2][:, None]
# Make rfp to rgb only red shades
rfp_images = rfp_images.repeat_interleave(3, dim=1)
rfp_images[:, 1] = 0.0
rfp_images[:, 2] = 0.0
# Save images
torchvision.utils.save_image(tensor=bf_images,
fp=os.path.join(self.path_plots, name + "_bf_{}.png".format(batch_index)),
nrow=bf_images.shape[0], padding=0)
if prediction.shape[1] > 1:
torchvision.utils.save_image(tensor=gfp_images,
fp=os.path.join(self.path_plots, name + "_gfp_{}.png".format(batch_index)),
nrow=gfp_images.shape[0], padding=0)
if prediction.shape[1] > 2:
torchvision.utils.save_image(tensor=rfp_images,
fp=os.path.join(self.path_plots, name + "_rfp_{}.png".format(batch_index)),
nrow=rfp_images.shape[0], padding=0)
def save(self) -> None:
"""
Method saves all current logs (metrics and hyperparameters). Plots are saved directly.
"""
# Save dict of hyperparameter as json file
with open(os.path.join(self.path_hyperparameters, 'hyperparameter.txt'), 'w') as json_file:
json.dump(self.hyperparameters, json_file)
# Iterate items in metrics dict
for metric_name, values in self.metrics.items():
# Convert list of values to torch tensor to use build in save method from torch
values = torch.tensor(values)
# Save values
torch.save(values, os.path.join(self.path_metrics, '{}.pt'.format(metric_name)))
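# --- Illustrative usage sketch (added; not part of the original training code) ---
# Typical logging flow during training. The metric and hyperparameter names, the experiment
# path extension and the checkpoint content are illustrative assumptions.
def _example_logger() -> None:
    import torch
    logger = Logger(experiment_path_extension="_example")
    logger.log_hyperparameter(hyperparameter_name="batch_size", value=4)
    for _ in range(10):
        logger.log_temp_metric(metric_name="loss_discriminator", value=torch.rand(1).item())
    # Average the temporal metric, store it in the metrics dict and write all logs to disk
    logger.save_temp_metric(metric_name="loss_discriminator")
    logger.save_checkpoint(file_name="checkpoint_0.pt", checkpoint_dict={"iteration": 0})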
@torch.no_grad()
def exponential_moving_average(model_ema: Union[torch.nn.Module, nn.DataParallel],
model_train: Union[torch.nn.Module, nn.DataParallel], decay: float = 0.999) -> None:
"""
Function applies one exponential moving average step to a given model to be accumulated and a given training model.
:param model_ema: (Union[torch.nn.Module, nn.DataParallel]) Model to be accumulated
:param model_train: (Union[torch.nn.Module, nn.DataParallel]) Training model
:param decay: (float) Decay factor
"""
# Check types
assert type(model_ema) is type(model_train), 'EMA can only be performed on networks of the same type!'
# Get parameter dicts
model_ema_dict = dict(model_ema.named_parameters())
model_train_dict = dict(model_train.named_parameters())
# Apply ema
for key in model_ema_dict.keys():
model_ema_dict[key].data.mul_(decay).add_(model_train_dict[key].data, alpha=1 - decay)
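# --- Illustrative usage sketch (added; not part of the original training code) ---
# The EMA ("shadow") model is typically a deep copy of the training model that is updated
# once per training step; the tiny linear networks below are placeholders.
def _example_exponential_moving_average() -> None:
    import copy
    import torch.nn as nn
    model_train = nn.Linear(8, 8)
    model_ema = copy.deepcopy(model_train)
    # ... perform an optimizer step on model_train, then accumulate into model_ema ...
    exponential_moving_average(model_ema=model_ema, model_train=model_train, decay=0.999)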
def random_permutation(n: int) -> torch.Tensor:
"""
Function generates a random permutation that is guaranteed not to be the identity permutation ([0, 1, 2, ...]).
:param n: (int) Number of elements
:return: (torch.Tensor) Permutation tensor
"""
# Get random permutation
permutation = torch.from_numpy(np.random.permutation(n))
# If the identity permutation was drawn, fall back to the reversed order
if torch.equal(permutation, torch.arange(n)):
permutation = torch.arange(start=n - 1, end=-1, step=-1)
return permutation
def normalize_0_1_batch(input: torch.tensor) -> torch.tensor:
"""
Normalize a given tensor batch-wise to a range of [0, 1] (clamped at a minimum of 1e-03)
:param input: (Torch tensor) Input tensor
:return: (Torch tensor) Normalized output tensor
"""
input_flatten = input.view(input.shape[0], -1)
return ((input - torch.min(input_flatten, dim=1)[0][:, None, None, None, None]) / (
torch.max(input_flatten, dim=1)[0][:, None, None, None, None] -
torch.min(input_flatten, dim=1)[0][:, None, None, None, None])).clamp(min=1e-03)
def normalize_m1_1_batch(input: torch.tensor) -> torch.tensor:
"""
Normalize a given tensor batch-wise to a range of [-1, 1]
:param input: (Torch tensor) Input tensor
:return: (Torch tensor) Normalized output tensor
"""
output = 2. * normalize_0_1_batch(input) - 1.
return output
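# --- Illustrative usage sketch (added; not part of the original training code) ---
# The normalization helpers expect five-dimensional batches (matching the indexing
# [:, None, None, None, None] above), e.g. [batch size, channels, time steps, height, width];
# the concrete shape below is illustrative.
def _example_normalization() -> None:
    import torch
    batch = torch.randn(2, 3, 4, 16, 16)
    batch_0_1 = normalize_0_1_batch(batch)  # values roughly in [0, 1] (clamped at a minimum of 1e-03)
    batch_m1_1 = normalize_m1_1_batch(batch)  # values roughly in [-1, 1]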
def get_noise(batch_size: int, latent_dimension: int, p_mixed_noise: float = 0.9, device: str = 'cuda') -> Union[
torch.Tensor, List[torch.Tensor]]:
"""
Function returns an input noise for the StyleGAN 2 generator.
Either a list of two noise vectors (for style mixing) or a single noise vector is returned.
:param batch_size: (int) Batch size to be used
:param latent_dimension: (int) Latent dimension to be utilized
:param p_mixed_noise: (float) Probability that a mixed noise (two noise vectors) is returned
:param device: (str) Device to be utilized
:return: (Union[torch.Tensor, List[torch.Tensor]]) List of noise tensors or single noise tensor
"""
if (p_mixed_noise > 0) and (random.random() < p_mixed_noise):
return list(torch.randn(2, batch_size, latent_dimension, dtype=torch.float32, device=device).unbind(0))
else:
return torch.randn(batch_size, latent_dimension, dtype=torch.float32, device=device)
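# --- Illustrative usage sketch (added; not part of the original training code) ---
# get_noise either returns a single latent batch or, with probability p_mixed_noise, a list of
# two latent batches for style mixing. Device is set to "cpu" here purely for illustration.
def _example_get_noise() -> None:
    noise = get_noise(batch_size=4, latent_dimension=512, p_mixed_noise=0.9, device="cpu")
    if isinstance(noise, list):
        noise_1, noise_2 = noise  # style mixing: two noise batches of shape [4, 512]
    else:
        noise_1 = noise  # single noise batch of shape [4, 512]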