Source code for sigkit.transforms.phase_shift

"""Module for PhaseShift Torch Transform."""

import math
from typing import Tuple, Union

import torch
from torch import nn

from sigkit.core.base import SigKitError


[docs] class ApplyPhaseShift(nn.Module): """Apply a constant or random phase offset to a 1D complex64 torch.Tensor. Args: phase_offset: - If a single float or int: apply that fixed phase (radians). - If a tuple/list of two floats: (min_phase, max_phase), pick random uniform phi from [min_phase, max_phase] per call """ def __init__(self, phase_offset: Union[float, Tuple[float, float]]): super().__init__() if isinstance(phase_offset, (int, float)): self.min_phi = float(phase_offset) self.max_phi = float(phase_offset) elif ( isinstance(phase_offset, (tuple, list)) and len(phase_offset) == 2 and all(isinstance(p, (int, float)) for p in phase_offset) ): self.min_phi = float(phase_offset[0]) self.max_phi = float(phase_offset[1]) else: raise SigKitError( f"ApplyPhaseShift: phase_offset must be a single number or a tuple of " f"two numbers, got {phase_offset!r}" )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the PhaseShift to the Tensor.""" if x.dtype != torch.complex64 or x.ndim != 1: raise SigKitError( f"ApplyPhaseShift expects a 1D tensor of dtype=torch.complex64, " f"got {x.shape=}, {x.dtype=}" ) if self.min_phi == self.max_phi: phi = self.min_phi else: r = torch.rand(1).item() phi = self.min_phi + (self.max_phi - self.min_phi) * r c = math.cos(phi) s = math.sin(phi) phase_factor = torch.complex( torch.tensor(c, dtype=torch.float32), torch.tensor(s, dtype=torch.float32) ) return x * phase_factor