Source code for sigkit.transforms.awgn
"""Module for AWGN Torch Transform."""
import torch
from torch import nn
from sigkit.core.base import SigKitError
[docs]
class ApplyAWGN(nn.Module):
"""Applies Additive White Gaussian Noise to reach a target SNR.
Args:
snr_db:
- If float or int: use that fixed SNR (in dB) on every forward().
- If tuple/list of two floats: (min_snr_db, max_snr_db), sample
uniformly from [min_snr_db, max_snr_db] each all.
"""
def __init__(self, snr_db: float | tuple[float, float]):
super().__init__()
if isinstance(snr_db, (int, float)):
self.min_snr = float(snr_db)
self.max_snr = float(snr_db)
elif (
isinstance(snr_db, (tuple, list))
and len(snr_db) == 2
and all(isinstance(v, (int, float)) for v in snr_db)
):
self.min_snr = float(snr_db[0])
self.max_snr = float(snr_db[1])
else:
raise SigKitError(
"ApplyAWGN: snr_db must be a number or a tuple of two numbers, "
f"got {snr_db!r}"
)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies AWGN to the input to reach the target SNR.
Args:
x: torch.Tensor of shape [N], dtype=torch.complex64.
Returns:
torch.Tensor of shape [N], dtype=torch.complex64 with AWGN.
"""
if x.dtype != torch.complex64 or x.ndim != 1:
raise SigKitError("Expected input of shape [N] and dtype=torch.complex64")
if self.min_snr == self.max_snr:
snr_db = self.min_snr
else:
r = torch.rand(1).item()
snr_db = self.min_snr + (self.max_snr - self.min_snr) * r
sig_power = (x.abs() ** 2).mean()
snr_lin = 10.0 ** (snr_db / 10.0)
noise_power = sig_power / snr_lin
std_dev = torch.sqrt(noise_power / 2.0)
real_noise = std_dev * torch.randn_like(x.real)
imag_noise = std_dev * torch.randn_like(x.real)
noise = (real_noise + 1j * imag_noise).to(torch.complex64)
return x + noise