Daryl Fung
initial commit
39f3704
raw
history blame
11.2 kB
import torch
from infer.lib.rmvpe import STFT
from torch.nn.functional import conv1d, conv2d
from typing import Union, Optional
from .utils import linspace, temperature_sigmoid, amp_to_db
class TorchGate(torch.nn.Module):
"""
A PyTorch module that applies a spectral gate to an input signal.
Arguments:
sr {int} -- Sample rate of the input signal.
nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}).
n_std_thresh_stationary {float} -- Number of standard deviations above mean to threshold noise for
stationary masking (default: {1.5}).
n_thresh_nonstationary {float} -- Number of multiplies above smoothed magnitude spectrogram. for
non-stationary masking (default: {1.3}).
temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}).
n_movemean_nonstationary {int} -- Number of samples for moving average smoothing in non-stationary masking
(default: {20}).
prop_decrease {float} -- Proportion to decrease signal by where the mask is zero (default: {1.0}).
n_fft {int} -- Size of FFT for STFT (default: {1024}).
win_length {[int]} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}).
hop_length {[int]} -- Hop length for STFT. If None, defaults to `win_length` // 4 (default: {None}).
freq_mask_smooth_hz {float} -- Frequency smoothing width for mask (in Hz). If None, no smoothing is applied
(default: {500}).
time_mask_smooth_ms {float} -- Time smoothing width for mask (in ms). If None, no smoothing is applied
(default: {50}).
"""
@torch.no_grad()
def __init__(
self,
sr: int,
nonstationary: bool = False,
n_std_thresh_stationary: float = 1.5,
n_thresh_nonstationary: float = 1.3,
temp_coeff_nonstationary: float = 0.1,
n_movemean_nonstationary: int = 20,
prop_decrease: float = 1.0,
n_fft: int = 1024,
win_length: bool = None,
hop_length: int = None,
freq_mask_smooth_hz: float = 500,
time_mask_smooth_ms: float = 50,
):
super().__init__()
# General Params
self.sr = sr
self.nonstationary = nonstationary
assert 0.0 <= prop_decrease <= 1.0
self.prop_decrease = prop_decrease
# STFT Params
self.n_fft = n_fft
self.win_length = self.n_fft if win_length is None else win_length
self.hop_length = self.win_length // 4 if hop_length is None else hop_length
# Stationary Params
self.n_std_thresh_stationary = n_std_thresh_stationary
# Non-Stationary Params
self.temp_coeff_nonstationary = temp_coeff_nonstationary
self.n_movemean_nonstationary = n_movemean_nonstationary
self.n_thresh_nonstationary = n_thresh_nonstationary
# Smooth Mask Params
self.freq_mask_smooth_hz = freq_mask_smooth_hz
self.time_mask_smooth_ms = time_mask_smooth_ms
self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())
@torch.no_grad()
def _generate_mask_smoothing_filter(self) -> Union[torch.Tensor, None]:
"""
A PyTorch module that applies a spectral gate to an input signal using the STFT.
Returns:
smoothing_filter (torch.Tensor): a 2D tensor representing the smoothing filter,
with shape (n_grad_freq, n_grad_time), where n_grad_freq is the number of frequency
bins to smooth and n_grad_time is the number of time frames to smooth.
If both self.freq_mask_smooth_hz and self.time_mask_smooth_ms are None, returns None.
"""
if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
return None
n_grad_freq = (
1
if self.freq_mask_smooth_hz is None
else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2)))
)
if n_grad_freq < 1:
raise ValueError(
f"freq_mask_smooth_hz needs to be at least {int((self.sr / (self._n_fft / 2)))} Hz"
)
n_grad_time = (
1
if self.time_mask_smooth_ms is None
else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000))
)
if n_grad_time < 1:
raise ValueError(
f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms"
)
if n_grad_time == 1 and n_grad_freq == 1:
return None
v_f = torch.cat(
[
linspace(0, 1, n_grad_freq + 1, endpoint=False),
linspace(1, 0, n_grad_freq + 2),
]
)[1:-1]
v_t = torch.cat(
[
linspace(0, 1, n_grad_time + 1, endpoint=False),
linspace(1, 0, n_grad_time + 2),
]
)[1:-1]
smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0)
return smoothing_filter / smoothing_filter.sum()
@torch.no_grad()
def _stationary_mask(
self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Computes a stationary binary mask to filter out noise in a log-magnitude spectrogram.
Arguments:
X_db (torch.Tensor): 2D tensor of shape (frames, freq_bins) containing the log-magnitude spectrogram.
xn (torch.Tensor): 1D tensor containing the audio signal corresponding to X_db.
Returns:
sig_mask (torch.Tensor): Binary mask of the same shape as X_db, where values greater than the threshold
are set to 1, and the rest are set to 0.
"""
if xn is not None:
if "privateuseone" in str(xn.device):
if not hasattr(self, "stft"):
self.stft = STFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window="hann",
).to(xn.device)
XN = self.stft.transform(xn)
else:
XN = torch.stft(
xn,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(xn.device),
)
XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
else:
XN_db = X_db
# calculate mean and standard deviation along the frequency axis
std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
# compute noise threshold
noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary
# create binary mask by thresholding the spectrogram
sig_mask = X_db > noise_thresh.unsqueeze(2)
return sig_mask
@torch.no_grad()
def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor:
"""
Computes a non-stationary binary mask to filter out noise in a log-magnitude spectrogram.
Arguments:
X_abs (torch.Tensor): 2D tensor of shape (frames, freq_bins) containing the magnitude spectrogram.
Returns:
sig_mask (torch.Tensor): Binary mask of the same shape as X_abs, where values greater than the threshold
are set to 1, and the rest are set to 0.
"""
X_smoothed = (
conv1d(
X_abs.reshape(-1, 1, X_abs.shape[-1]),
torch.ones(
self.n_movemean_nonstationary,
dtype=X_abs.dtype,
device=X_abs.device,
).view(1, 1, -1),
padding="same",
).view(X_abs.shape)
/ self.n_movemean_nonstationary
)
# Compute slowness ratio and apply temperature sigmoid
slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6)
sig_mask = temperature_sigmoid(
slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary
)
return sig_mask
def forward(
self, x: torch.Tensor, xn: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Apply the proposed algorithm to the input signal.
Arguments:
x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length).
xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input
signal is used as the noise signal. Default: `None`.
Returns:
torch.Tensor: The denoised audio signal, with the same shape as the input signal.
"""
# Compute short-time Fourier transform (STFT)
if "privateuseone" in str(x.device):
if not hasattr(self, "stft"):
self.stft = STFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window="hann",
).to(x.device)
X, phase = self.stft.transform(x, return_phase=True)
else:
X = torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(x.device),
)
# Compute signal mask based on stationary or nonstationary assumptions
if self.nonstationary:
sig_mask = self._nonstationary_mask(X.abs())
else:
sig_mask = self._stationary_mask(amp_to_db(X), xn)
# Propagate decrease in signal power
sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0
# Smooth signal mask with 2D convolution
if self.smoothing_filter is not None:
sig_mask = conv2d(
sig_mask.unsqueeze(1),
self.smoothing_filter.to(sig_mask.dtype),
padding="same",
)
# Apply signal mask to STFT magnitude and phase components
Y = X * sig_mask.squeeze(1)
# Inverse STFT to obtain time-domain signal
if "privateuseone" in str(Y.device):
y = self.stft.inverse(Y, phase)
else:
y = torch.istft(
Y,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=True,
window=torch.hann_window(self.win_length).to(Y.device),
)
return y.to(dtype=x.dtype)