Source code for sigkit.metrics.integrity
"""Methods for computing signal integrity metrics like SNR and BER."""
from typing import Union
import numpy as np
import torch
[docs]
def estimate_snr(
clean: Union[np.ndarray, torch.Tensor],
noisy: Union[np.ndarray, torch.Tensor],
) -> float:
"""Compute SNR (dB) between clean and noisy signals."""
if isinstance(clean, torch.Tensor) and isinstance(noisy, torch.Tensor):
return _estimate_snr_torch(clean, noisy)
if isinstance(clean, np.ndarray) and isinstance(noisy, np.ndarray):
return _estimate_snr_np(clean, noisy)
raise ValueError(f"Type mismatch: got {type(clean)} vs {type(noisy)}")
def _estimate_snr_np(clean: np.ndarray, noisy: np.ndarray) -> float:
"""NumPy implementation of SNR in dB."""
sig_power = np.mean(np.abs(clean) ** 2)
noise_power = np.mean(np.abs(noisy - clean) ** 2)
return 10 * np.log10(sig_power / noise_power)
def _estimate_snr_torch(x: torch.Tensor, y: torch.Tensor) -> float:
"""PyTorch implementation of SNR in dB.
Expects real-imag stacked: shape (2, N).
"""
sig_power = (x.pow(2).sum(dim=0)).mean().item()
noise_power = ((y - x).pow(2).sum(dim=0)).mean().item()
return 10 * torch.log10(torch.tensor(sig_power / noise_power)).item()
[docs]
def calculate_ber(
bits: Union[np.ndarray, torch.Tensor],
truth_bits: Union[np.ndarray, torch.Tensor],
) -> float:
"""Compute bit-error rate (fraction of mismatches)."""
# If torch, convert to NumPy
if isinstance(truth_bits, torch.Tensor):
truth_bits = truth_bits.cpu().numpy()
if isinstance(bits, torch.Tensor):
bits = bits.cpu().numpy()
errors = np.count_nonzero(truth_bits != bits)
return errors / truth_bits.size