Skip to content
Snippets Groups Projects
Commit 25e13fe6 authored by uj194098's avatar uj194098
Browse files

add missing files

parent 6db01786
No related branches found
No related tags found
No related merge requests found
Showing
with 248 additions and 0 deletions
from model import ConvNet1D
from specification_network import IDFT
from util import load_data
import torch
from torch import nn
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sampling_rate = 100
class AugmentedNetwork(nn.Module):
def __init__(self, model, dataset_info, normalize = False, aug_cfg={}, signals=None, device=device, eps=1):
super().__init__()
self.model = model
self.in_shape, aug_domain_shape, data_shape, self.out_shape = None, dataset_info["shape"], dataset_info["shape"], dataset_info["num_classes"]
self.aug_cfg = aug_cfg
self.norm_net = None
self.signals = signals
networks = []
self.augmentation = IDFT(data_shape[-1], eps=eps)
self.augmentation = self.augmentation.to(device)
networks += [self.augmentation]
if normalize:
assert dataset_info is not None, "Dataset info (nchannels, mean, std) needed to normalize"
self.norm_net = Normalization(self.device, data_shape[0], dataset_info["mean"], dataset_info["std"])
self.norm_net = self.norm_net.to(device)
networks += [self.norm_net]
networks += [model]
self.subnetworks = networks
def forward_encoding_network(self, x=None, c_x=None):
x_dec = x
# domain_ins = [x_dec, c_x] if self.aug_domain == "fourier" and self.aug_cfg.get("spec", "ball") == "ball" else [x_dec]
if c_x == None:
c_x = self.signals
domain_ins = [x_dec, c_x]
x_dec = self.augmentation(*domain_ins)
return x_dec
def forward(self, x = None, x_dec=None, **kwargs):
if x_dec is None:
x_dec = self.forward_encoding_network(x, **kwargs)
x_norm = self.norm_net.to(self.device)(x_dec) if self.norm_net else x_dec
return self.model(x_norm)
# , x_dec, x_norm
class Normalization(nn.Module):
def __init__(self, device, num_channels, mean=0., std=1.):
super().__init__()
mean_negative = -torch.tensor(mean, dtype=torch.float).reshape(num_channels, 1, 1) # torch.nn.Parameter( . , requires_grad=False)
std_reciprocal = 1/torch.tensor(std, dtype=torch.float).reshape(num_channels, 1, 1)
self.mean_negative = nn.Parameter(mean_negative, requires_grad=False)
self.std_reciprocal = nn.Parameter(std_reciprocal, requires_grad=False)
def forward(self, x):
return (x + self.mean_negative) * self.std_reciprocal
if __name__ == "__main__":
feature = "afib"
window_size = 250
dataset_info = {
"shape": (12, window_size),
"num_classes": 2
}
X_train, y_train, X_val, y_val, X_test, y_test = load_data(device=device, file=f"./data/{feature}_data_sampling_rate_{sampling_rate}.npz")
# Hyperparameters
input_size = X_train[0].shape[0] # Number of features per time step
num_classes = y_train.shape[1] # Number of output classes (for multi-label classification)
model = ConvNet1D(input_size, num_classes, window_size=window_size)
checkpoint_path = "./results/afib_simple_convnet_100/best_simple_convnet_model.pth"
model.load_state_dict(torch.load(checkpoint_path, weights_only=False, map_location=torch.device("cpu")))
x = torch.zeros((window_size*2))
model.eval()
model = model.cuda()
augmented_model = AugmentedNetwork(model, dataset_info)
signals = model._create_sliding_windows(X_test, window_size=window_size, step_size=int(window_size/2))
print(signals.shape)
predictions = augmented_model(x,signals[0])
# predictions = predictions.cpu().detach().numpy()
print(predictions[0])
File added
File added
results/afib_simple_convnet_100/simple_convnet_confusion_matrix.png

28.3 KiB

results/afib_simple_convnet_100/simple_convnet_losses.png

24 KiB

results/afib_simple_convnet_100/simple_convnet_roc_curve.png

34.1 KiB

results/verification/adv_signal.png

284 KiB

results/verification/adv_signal0.05.png

298 KiB

results/verification/random_noisy_signal0.01.png

246 KiB

results/verification/random_noisy_signal0.02.png

247 KiB

results/verification/random_noisy_signal0.03.png

250 KiB

results/verification/random_noisy_signal0.04.png

253 KiB

results/verification/random_noisy_signal0.05.png

271 KiB

results/verification/random_noisy_signal0.1.png

265 KiB

import torch
from torch import nn
import numpy as np
from scipy.fft import idct
from matplotlib import pyplot as plt
bias = False
class IDFT(nn.Module):
def __init__(self, N: int, real_signal=False, eps=1):
super().__init__()
self.N = N
self.real_signal = real_signal
self.eps = eps
IW = torch.from_numpy(IDFT_matrix(N))
IW_r, IW_i = torch.real(IW).to(torch.float), torch.imag(IW).to(torch.float)
self.init(N, IW_r, IW_i)
def init_DFT(self, N, IW_r, IW_i):
W1_r = torch.empty([N, 2*N])
W1_r[:, :N], W1_r[:, N:] = IW_r, -IW_i
W1_i = torch.empty([N, 2*N])
W1_i[:, :N], W1_i[:, N:] = IW_i, IW_r
self.W1_r, self.W1_i = W1_r, W1_i
def init(self, N, IW_r, IW_i):
self.init_DFT(N, IW_r, IW_i)
self.L_W2_r = nn.Linear(2*N, N, bias=bias, dtype=torch.float)
with torch.no_grad():
self.L_W2_r.weight.copy_(self.W1_r)
def forward(self, z, c_x = None):
# z = self.unmask(z, self.N)
z = self.create_freq_array(z, self.N)
if self.real_signal:
z_ = z[..., :z.shape[-1]//2+1]
z = self.real_z(z_)
x = self.L_W2_r(z)
if c_x is not None:
x = x + c_x
return x
@staticmethod
def real_z(z):
assert z.shape[-1] % 2 == 1, f"{z.shape}, but need input dimension to be odd"
r1 = z[..., :z.shape[-2]//2, :]
r21 = torch.flip(r1[..., :1, 1:], [z.ndim-1])
r22 = torch.flip(r1[..., 1:, 1:], [z.ndim-2, z.ndim-1])
i1 = z[..., z.shape[-2]//2:, :]
i21 = -torch.flip(i1[..., :1, 1:], [z.ndim-1])
i22 = -torch.flip(i1[..., 1:, 1:], [z.ndim-2, z.ndim-1])
z_ = torch.cat([r21, r22, i21, i22], dim=z.ndim-2)
print(z_)
z_hat = torch.cat([z, z_], dim=z.ndim-1)
return z_hat
def unmask(self, x, N):
length = int(x.shape[-1]/2)
x_filled = torch.zeros(x.shape[0], (2*N), dtype=x.dtype, device=x.device)
for i in range(N):
if i < length-1:
x_filled[:,i+1] = x[:,i]
x_filled[:,N+i+1] = x[:,length+i]
fifty = int(50//0.25)
x_filled[:, fifty] = x[:,length-1]
x_filled[:, N + fifty] = x[:,-1]
return x_filled
def create_freq_array(self, x, N):
eps = self.eps
comp = 2
leads = 12
idx = [0, 1, 1+leads, 1+leads+comp, 1+leads+comp*2]
idx += [idx[-1]+1, idx[-1]+1+leads, idx[-1]+1+leads+1, idx[-1]+leads+2]
# C x c x a x ph
# c scaled [-1,1]
# a scaled [0,1]
# ph scaled [0,2pi]
param_bw = x[idx[0]].item() * x[idx[1]:idx[2]].unsqueeze(1)/eps * (x[idx[2]:idx[3]].unsqueeze(0)+eps)/(2*eps)
param_pl = x[idx[4]].item() * x[idx[5]:idx[6]].unsqueeze(1)/eps * (x[idx[6]].item()+eps)/(2*eps)
z = torch.zeros(leads, N*2, dtype=x.dtype, device=x.device)
z[:, 1:1+comp] = param_bw * torch.cos((x[idx[3]:idx[4]]+eps)/eps*torch.pi).unsqueeze(0)
z[:, N+1:N+1+comp] = param_bw * torch.sin((x[idx[3]:idx[4]]+eps)/eps*torch.pi).unsqueeze(0)
z[:, int(50/0.4)] = (param_pl * torch.cos((x[idx[7]]+eps)/eps*torch.pi)).flatten()
z[:, N + int(50/0.4)] = (param_pl * torch.sin((x[idx[7]]+eps)/eps*torch.pi)).flatten()
return z
def IDFT_matrix(N: int):
# https://www.originlab.com/doc/Origin-Help/InverseFFT2-Algorithm
i, j = np.meshgrid(np.arange(N), np.arange(N))
omega = np.exp(2j * np.pi / N)
W = np.power(omega, i * j)
return W
def params_to_tensor(params):
"""Convert dictionary parameters into a single tensor"""
tensor_list = [
torch.tensor([params["C_bw"]], dtype=torch.float32),
torch.tensor(params["c_bw"], dtype=torch.float32),
torch.tensor(params["a_bw"], dtype=torch.float32),
torch.tensor(params["ph_bw"], dtype=torch.float32),
torch.tensor([params["C_pl"]], dtype=torch.float32),
torch.tensor(params["c_pl"], dtype=torch.float32),
torch.tensor([params["a_pl"]], dtype=torch.float32),
torch.tensor([params["ph_pl"]], dtype=torch.float32),
]
return torch.cat(tensor_list) # Concatenate into a single tensor
if __name__ == "__main__":
length = 1000
model = IDFT(length, )
frequencies = 10
x = np.zeros((1,frequencies*2))
i = 2
# f = 0.1 * i
# re = 1
# im = 1
# a = np.sqrt(re**2 + im**2)
# ph = np.arctan(re / im)
a = [ 0.7, 1, 0, 0, 0.2, 0.5, 0.3, 0.1, 0.4, 0.6]
ph = [0, 1, 2, 0, 0, 2, 0.5, 1, 1.5, 0.5]
f = [0.1 * i for i in range(0, 10)]
re = a * np.cos(ph)
im = a * np.sin(ph)
for i in range(frequencies):
x[:,i] = re[i]
x[:,i+frequencies] = im[i]
# x[i] = re
# x[i+length] = im
x = torch.tensor(x, dtype=torch.float)
signal = model(x)
print(signal)
# print(np.angle(x[i].detach().numpy()))
plt.plot(signal[0].detach().numpy())
# t = torch.linspace(0, 10, length)
# sin_signal = np.zeros(length)
# for i in range(0,frequencies):
# w = 2 * np.pi * f[i]
# temp = a[i]* np.cos(w * t + ph[i])
# temp = np.array(temp)
# sin_signal += temp
# # sin_signal = a* torch.cos(w * t + ph)
# plt.plot(sin_signal)
plt.show()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment