Source code for sigkit.core.base

"""Abstract base classes and common exceptions for SigKit."""

from dataclasses import dataclass, field

import numpy as np
import torch


class SigKitError(Exception):
    """Root exception type for errors that SigKit raises deliberately.

    Catch this class to handle any SigKit-specific failure, such as an
    invalid ``samples`` array passed when constructing a ``Signal``.
    """
@dataclass
class Signal:
    """A container for a complex waveform.

    Parameters:
        samples: ``np.ndarray`` of dtype ``complex64``, expected to be 1-D
            with shape (N,). Defaults to 4096 zeros.
        sample_rate: Sampling rate in Hz. Must be strictly positive.
            Defaults to 1.0.
        carrier_frequency: Carrier frequency in Hz. A value of 0.0 means the
            samples are already at baseband. Defaults to 0.0.

    Raises:
        SigKitError: If ``samples`` is not an ``np.ndarray`` of dtype
            ``complex64``, or if ``sample_rate`` is not strictly positive.

    Two signals compare equal when their sample arrays are element-wise
    equal (same shape) and their sample rates and carrier frequencies match.
    Instances are not hashable.
    """

    samples: np.ndarray = field(
        default_factory=lambda: np.zeros(4096, dtype=np.complex64)
    )
    sample_rate: float = 1.0
    carrier_frequency: float = 0.0

    def __post_init__(self):
        # Check the type before touching ``.dtype`` so that a list or tensor
        # produces a SigKitError instead of an AttributeError.
        if (
            not isinstance(self.samples, np.ndarray)
            or self.samples.dtype != np.complex64
        ):
            raise SigKitError("Signal samples must be np.ndarray[complex64]")
        # ``not x > 0`` also rejects NaN. A zero/negative rate would turn the
        # time axis built in ``to_baseband`` into inf/NaN values silently.
        if not self.sample_rate > 0:
            raise SigKitError("Signal sample_rate must be positive")

    def __eq__(self, other):
        """Compare two signals field by field, comparing arrays by value.

        The dataclass-generated ``__eq__`` compares a tuple containing the
        ndarray, which raises ``ValueError`` (ambiguous array truth value)
        for any two distinct multi-sample signals.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            np.array_equal(self.samples, other.samples)
            and self.sample_rate == other.sample_rate
            and self.carrier_frequency == other.carrier_frequency
        )

    def to_tensor(self) -> torch.Tensor:
        """Convert the samples to a PyTorch tensor.

        Returns:
            A ``torch.complex64`` tensor with the same shape as ``samples``
            ((N,) for a 1-D signal). For the training pipeline, N should be
            4096.

        Note:
            ``torch.from_numpy`` shares memory with ``samples``; in-place
            modifications to either one are visible in the other.
        """
        return torch.from_numpy(self.samples)

    def to_baseband(self) -> "Signal":
        """Convert the signal to baseband by removing the carrier frequency.

        Sample n is multiplied by
        ``exp(-1j * 2 * pi * carrier_frequency * n / sample_rate)``.

        Returns:
            A new ``Signal`` with complex64 baseband samples, the same
            ``sample_rate``, and ``carrier_frequency`` of 0.0. If the carrier
            frequency is already 0.0, ``self`` is returned as-is (not a
            copy), so the result shares its ``samples`` array with this one.
        """
        if self.carrier_frequency == 0.0:
            return self
        # Sample times in seconds; float64 because arange yields integers
        # divided by a float.
        t = np.arange(self.samples.size) / self.sample_rate
        mixer = np.exp(-1j * 2 * np.pi * self.carrier_frequency * t)
        # The mixed product is complex128; cast back because __post_init__
        # only accepts complex64 samples.
        baseband_samples = (self.samples * mixer).astype(np.complex64)
        return Signal(
            samples=baseband_samples,
            sample_rate=self.sample_rate,
            carrier_frequency=0.0,
        )