Source code for sigkit.transforms.frequency_shift
"""Module for FrequencyShift Torch Transform."""
from typing import Tuple, Union
import numpy as np
import torch
from torch import nn
from sigkit.core.base import SigKitError
[docs]
class ApplyFrequencyShift(nn.Module):
"""Apply a constant or random frequency offset to a 1D complex64 torch.Tensor.
Args:
freq_offset:
- If a single float or int: apply that fixed frequency (in Hz).
- If a tuple/list of two floats: (min_freq, max_freq), pick
random uniform from [min_freq, max_freq] per call.
sample_rate:
- Sampling rate of the signal (in samples per second).
"""
def __init__(
self,
freq_offset: Union[float, Tuple[float, float]],
sample_rate: float,
):
super().__init__()
if isinstance(freq_offset, (int, float)):
self.min_f = float(freq_offset)
self.max_f = float(freq_offset)
elif (
isinstance(freq_offset, (tuple, list))
and len(freq_offset) == 2
and all(isinstance(f, (int, float)) for f in freq_offset)
):
self.min_f = float(freq_offset[0])
self.max_f = float(freq_offset[1])
else:
raise SigKitError(
f"ApplyFrequencyShift: freq_offset must be a single number or a tuple"
f" of two numbers, got {freq_offset}"
)
if sample_rate <= 0:
raise SigKitError(
f"ApplyFrequencyShift: sample_rate must be positive, got {sample_rate}"
)
self.sample_rate = float(sample_rate)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the frequency shift to the Tensor."""
if x.dtype != torch.complex64 or x.ndim != 1:
raise SigKitError(
f"ApplyFrequencyShift expects a 1D tensor of dtype=torch.complex64,"
f" got {x.shape=}, {x.dtype=}"
)
if self.min_f == self.max_f:
f = self.min_f
else:
r = torch.rand(1).item()
f = self.min_f + (self.max_f - self.min_f) * r
n = torch.arange(x.shape[0], device=x.device, dtype=torch.float32)
t = n / self.sample_rate
ang = 2 * np.pi * f * t
real = torch.cos(ang)
imag = torch.sin(ang)
phase = torch.complex(real, imag)
return x * phase