Source code for sigkit.transforms.utils
"""Utility module for SigKit transforms."""
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from sigkit.core.base import SigKitError
[docs]
class ComplexTo2D(nn.Module):
"""Convert to the expected input of the model for training.
Transform a 1D torch.Tensor of dtype=torch.complex64 and shape (N,)
into a 2×N torch.Tensor of dtype=torch.float32:
- Row 0 = real part
- Row 1 = imaginary part
Example:
x = torch.randn(4096) + 1j * torch.randn(4096)
x = x.to(torch.complex64)
iq = ComplexTo2D(x)
"""
[docs]
def forward(self, x: torch.Tensor):
if not isinstance(x, torch.Tensor):
raise SigKitError(f"Expected a torch.Tensor, got {type(x)}")
if x.dtype != torch.complex64:
raise SigKitError(f"Expected dtype=torch.complex64, got {x.dtype}")
if x.ndim != 1:
raise SigKitError(f"ComplexTo2D expects a 1D tensor, got {x.shape=}")
real = x.real.to(torch.float32) # shape (N,), dtype float32
imag = x.imag.to(torch.float32) # shape (N,), dtype float32
return torch.stack([real, imag], dim=0) # shape (2, N), dtype float32
[docs]
class Normalize(nn.Module):
"""Normalize the input data."""
def __init__(self, norm=float("inf")):
super().__init__()
self.norm_order = norm
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
norm = torch.linalg.norm(x, ord=self.norm_order, dim=-1, keepdim=True)
return (x / norm).to(torch.complex64)
[docs]
class RandomApplyProb(nn.Module):
"""Apply a list of transforms with a given probability per transform."""
def __init__(self, transforms_p: list[tuple[nn.Module, float]]):
self.transforms_p = transforms_p
def __call__(self, x):
for transform, p in self.transforms_p:
if torch.rand(1).item() < p:
x = transform(x)
return x
InferenceTransform = Compose( # convert complex tensor for inference
[
Normalize(norm=float("inf")),
ComplexTo2D(),
]
)