Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-07-01 00:15:56 +0300
committerJean-Marc Valin <jmvalin@jmvalin.ca>2023-07-01 00:15:56 +0300
commit105e1d83fad6393b00edb7eb676be483eb4ee2d7 (patch)
tree74873b0b88818966a9f614f144bb2e55762338a6
parent178672ed1823f2a2fdc7e36e34578383f799f4f6 (diff)
Opus ng lace
-rw-r--r--dnn/torch/osce/README.md4
-rw-r--r--dnn/torch/osce/data/__init__.py30
-rw-r--r--dnn/torch/osce/data/silk_enhancement_set.py140
-rw-r--r--dnn/torch/osce/engine/engine.py101
-rw-r--r--dnn/torch/osce/losses/stft_loss.py277
-rw-r--r--dnn/torch/osce/make_default_setup.py56
-rw-r--r--dnn/torch/osce/models/__init__.py36
-rw-r--r--dnn/torch/osce/models/lace.py176
-rw-r--r--dnn/torch/osce/models/nns_base.py69
-rw-r--r--dnn/torch/osce/models/scale_embedding.py68
-rw-r--r--dnn/torch/osce/models/silk_feature_net.py86
-rw-r--r--dnn/torch/osce/models/silk_feature_net_pl.py90
-rw-r--r--dnn/torch/osce/test_model.py96
-rw-r--r--dnn/torch/osce/train_model.py297
-rw-r--r--dnn/torch/osce/utils/complexity.py35
-rw-r--r--dnn/torch/osce/utils/endoscopy.py234
-rw-r--r--dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py236
-rw-r--r--dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py222
-rw-r--r--dnn/torch/osce/utils/layers/pitch_auto_correlator.py84
-rw-r--r--dnn/torch/osce/utils/misc.py42
-rw-r--r--dnn/torch/osce/utils/pitch.py121
-rw-r--r--dnn/torch/osce/utils/silk_features.py151
-rw-r--r--dnn/torch/osce/utils/spec.py194
-rw-r--r--dnn/torch/osce/utils/templates.py92
24 files changed, 2937 insertions, 0 deletions
diff --git a/dnn/torch/osce/README.md b/dnn/torch/osce/README.md
new file mode 100644
index 00000000..1f940113
--- /dev/null
+++ b/dnn/torch/osce/README.md
@@ -0,0 +1,4 @@
+# Opus Speech Coding Enhancement
+
+This folder hosts models for enhancing SILK. See related Opus repo https://gitlab.xiph.org/xiph/opus/-/tree/exp-neural-silk-enhancement
+for feature generation. \ No newline at end of file
diff --git a/dnn/torch/osce/data/__init__.py b/dnn/torch/osce/data/__init__.py
new file mode 100644
index 00000000..9f7ea183
--- /dev/null
+++ b/dnn/torch/osce/data/__init__.py
@@ -0,0 +1,30 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+from .silk_enhancement_set import SilkEnhancementSet \ No newline at end of file
diff --git a/dnn/torch/osce/data/silk_enhancement_set.py b/dnn/torch/osce/data/silk_enhancement_set.py
new file mode 100644
index 00000000..186333e9
--- /dev/null
+++ b/dnn/torch/osce/data/silk_enhancement_set.py
@@ -0,0 +1,140 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+
+from torch.utils.data import Dataset
+import numpy as np
+
+from utils.silk_features import silk_feature_factory
+from utils.pitch import hangover, calculate_acorr_window
+
+
+class SilkEnhancementSet(Dataset):
+ def __init__(self,
+ path,
+ frames_per_sample=100,
+ no_pitch_value=256,
+ preemph=0.85,
+ skip=91,
+ acorr_radius=2,
+ pitch_hangover=8,
+ num_bands_clean_spec=64,
+ num_bands_noisy_spec=18,
+ noisy_spec_scale='opus',
+ noisy_apply_dct=True,
+ add_offset=False,
+ add_double_lag_acorr=False
+ ):
+
+ assert frames_per_sample % 4 == 0
+
+ self.frame_size = 80
+ self.frames_per_sample = frames_per_sample
+ self.no_pitch_value = no_pitch_value
+ self.preemph = preemph
+ self.skip = skip
+ self.acorr_radius = acorr_radius
+ self.pitch_hangover = pitch_hangover
+ self.num_bands_clean_spec = num_bands_clean_spec
+ self.num_bands_noisy_spec = num_bands_noisy_spec
+ self.noisy_spec_scale = noisy_spec_scale
+ self.add_double_lag_acorr = add_double_lag_acorr
+
+ self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
+ self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
+ self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
+ self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
+ self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
+ self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
+ self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
+
+ self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
+ self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
+
+ self.create_features = silk_feature_factory(no_pitch_value,
+ acorr_radius,
+ pitch_hangover,
+ num_bands_clean_spec,
+ num_bands_noisy_spec,
+ noisy_spec_scale,
+ noisy_apply_dct,
+ add_offset,
+ add_double_lag_acorr)
+
+ self.history_len = 700 if add_double_lag_acorr else 350
+ # discard some frames to have enough signal history
+ self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
+
+ num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames
+
+ self.len = num_frames // frames_per_sample
+
+ def __len__(self):
+ return self.len
+
+ def __getitem__(self, index):
+
+ frame_start = self.frames_per_sample * index + self.skip_frames
+ frame_stop = frame_start + self.frames_per_sample
+
+ signal_start = frame_start * self.frame_size - self.skip
+ signal_stop = frame_stop * self.frame_size - self.skip
+
+ clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
+ coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
+
+ coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15
+
+ features, periods = self.create_features(
+ coded_signal,
+ coded_signal_history,
+ self.lpcs[frame_start : frame_stop],
+ self.gains[frame_start : frame_stop],
+ self.ltps[frame_start : frame_stop],
+ self.periods[frame_start : frame_stop],
+ self.offsets[frame_start : frame_stop]
+ )
+
+ if self.preemph > 0:
+ clean_signal[1:] -= self.preemph * clean_signal[: -1]
+ coded_signal[1:] -= self.preemph * coded_signal[: -1]
+
+ num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
+ num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
+
+ numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)
+
+ return {
+ 'features' : features,
+ 'periods' : periods.astype(np.int64),
+ 'target' : clean_signal.astype(np.float32),
+ 'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
+ 'numbits' : numbits.astype(np.float32)
+ }
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
new file mode 100644
index 00000000..7688e9b4
--- /dev/null
+++ b/dnn/torch/osce/engine/engine.py
@@ -0,0 +1,101 @@
+import torch
+from tqdm import tqdm
+import sys
+
+def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
+
+ model.to(device)
+ model.train()
+
+ running_loss = 0
+ previous_running_loss = 0
+
+
+ with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+ for i, batch in enumerate(tepoch):
+
+ # set gradients to zero
+ optimizer.zero_grad()
+
+
+ # push batch to device
+ for key in batch:
+ batch[key] = batch[key].to(device)
+
+ target = batch['target']
+
+ # calculate model output
+ output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
+
+ # calculate loss
+ if isinstance(output, list):
+ loss = torch.zeros(1, device=device)
+ for y in output:
+ loss = loss + criterion(target, y.squeeze(1))
+ loss = loss / len(output)
+ else:
+ loss = criterion(target, output.squeeze(1))
+
+ # calculate gradients
+ loss.backward()
+
+ # update weights
+ optimizer.step()
+
+ # update learning rate
+ scheduler.step()
+
+ # update running loss
+ running_loss += float(loss.cpu())
+
+ # update status bar
+ if i % log_interval == 0:
+ tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+ previous_running_loss = running_loss
+
+
+ running_loss /= len(dataloader)
+
+ return running_loss
+
+def evaluate(model, criterion, dataloader, device, log_interval=10):
+
+ model.to(device)
+ model.eval()
+
+ running_loss = 0
+ previous_running_loss = 0
+
+
+ with torch.no_grad():
+ with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+ for i, batch in enumerate(tepoch):
+
+
+
+ # push batch to device
+ for key in batch:
+ batch[key] = batch[key].to(device)
+
+ target = batch['target']
+
+ # calculate model output
+ output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
+
+ # calculate loss
+ loss = criterion(target, output.squeeze(1))
+
+ # update running loss
+ running_loss += float(loss.cpu())
+
+ # update status bar
+ if i % log_interval == 0:
+ tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+ previous_running_loss = running_loss
+
+
+ running_loss /= len(dataloader)
+
+ return running_loss \ No newline at end of file
diff --git a/dnn/torch/osce/losses/stft_loss.py b/dnn/torch/osce/losses/stft_loss.py
new file mode 100644
index 00000000..4c164cb6
--- /dev/null
+++ b/dnn/torch/osce/losses/stft_loss.py
@@ -0,0 +1,277 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+import torchaudio
+
+
+def get_window(win_name, win_length, *args, **kwargs):
+ window_dict = {
+ 'bartlett_window' : torch.bartlett_window,
+ 'blackman_window' : torch.blackman_window,
+ 'hamming_window' : torch.hamming_window,
+ 'hann_window' : torch.hann_window,
+ 'kaiser_window' : torch.kaiser_window
+ }
+
+ if not win_name in window_dict:
+ raise ValueError()
+
+ return window_dict[win_name](win_length, *args, **kwargs)
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+ """Perform STFT and convert to magnitude spectrogram.
+ Args:
+ x (Tensor): Input signal tensor (B, T).
+ fft_size (int): FFT size.
+ hop_size (int): Hop size.
+ win_length (int): Window length.
+ window (str): Window function type.
+ Returns:
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+ """
+
+ win = get_window(window, win_length).to(x.device)
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
+
+
+ return torch.clamp(torch.abs(x_stft), min=1e-7)
+
+def spectral_convergence_loss(Y_true, Y_pred):
+ dims=list(range(1, len(Y_pred.shape)))
+ return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
+
+
+def log_magnitude_loss(Y_true, Y_pred):
+ Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
+ Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
+
+ return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
+
+def spectral_xcorr_loss(Y_true, Y_pred):
+ Y_true = Y_true.abs()
+ Y_pred = Y_pred.abs()
+ dims=list(range(1, len(Y_pred.shape)))
+ xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
+
+ return 1 - xcorr.mean()
+
+
+
+class MRLogMelLoss(nn.Module):
+ def __init__(self,
+ fft_sizes=[512, 256, 128, 64],
+ overlap=0.5,
+ fs=16000,
+ n_mels=18
+ ):
+
+ self.fft_sizes = fft_sizes
+ self.overlap = overlap
+ self.fs = fs
+ self.n_mels = n_mels
+
+ super().__init__()
+
+ self.mel_specs = []
+ for fft_size in fft_sizes:
+ hop_size = int(round(fft_size * (1 - self.overlap)))
+
+ n_mels = self.n_mels
+ if fft_size < 128:
+ n_mels //= 2
+
+ self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
+
+ for i, mel_spec in enumerate(self.mel_specs):
+ self.add_module(f'mel_spec_{i+1}', mel_spec)
+
+ def forward(self, y_true, y_pred):
+
+ loss = torch.zeros(1, device=y_true.device)
+
+ for mel_spec in self.mel_specs:
+ Y_true = mel_spec(y_true)
+ Y_pred = mel_spec(y_pred)
+ loss = loss + log_magnitude_loss(Y_true, Y_pred)
+
+ loss = loss / len(self.mel_specs)
+
+ return loss
+
+def create_weight_matrix(num_bins, bins_per_band=10):
+ m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
+
+ r0 = bins_per_band // 2
+ r1 = bins_per_band - r0
+
+ for i in range(num_bins):
+ i0 = max(i - r0, 0)
+ j0 = min(i + r1, num_bins)
+
+ m[i, i0: j0] += 1
+
+ if i < r0:
+ m[i, :r0 - i] += 1
+
+ if i > num_bins - r1:
+ m[i, num_bins - r1 - i:] += 1
+
+ return m / bins_per_band
+
+def weighted_spectral_convergence(Y_true, Y_pred, w):
+
+ # calculate sfm based weights
+ logY = torch.log(torch.abs(Y_true) + 1e-9)
+ Y = torch.abs(Y_true)
+
+ avg_logY = torch.matmul(logY.transpose(1, 2), w)
+ avg_Y = torch.matmul(Y.transpose(1, 2), w)
+
+ sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
+
+ weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
+
+ loss = torch.mean(
+ torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
+ / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
+ )
+
+ return loss
+
+def gen_filterbank(N, Fs=16000):
+ in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
+ out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
+ #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
+ ERB_N = 24.7 + .108*in_freq
+ delta = np.abs(in_freq-out_freq)/ERB_N
+ center = (delta<.5).astype('float32')
+ R = -12*center*delta**2 + (1-center)*(3-12*delta)
+ RE = 10.**(R/10.)
+ norm = np.sum(RE, axis=1)
+ RE = RE/norm[:, np.newaxis]
+ return torch.from_numpy(RE)
+
+def smooth_log_mag(Y_true, Y_pred, filterbank):
+ Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
+ Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
+
+ loss = torch.abs(
+ torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
+ )
+
+ loss = loss.mean()
+
+ return loss
+
+class MRSTFTLoss(nn.Module):
+ def __init__(self,
+ fft_sizes=[2048, 1024, 512, 256, 128, 64],
+ overlap=0.5,
+ window='hann_window',
+ fs=16000,
+ log_mag_weight=1,
+ sc_weight=0,
+ wsc_weight=0,
+ smooth_log_mag_weight=0,
+ sxcorr_weight=0):
+ super().__init__()
+
+ self.fft_sizes = fft_sizes
+ self.overlap = overlap
+ self.window = window
+ self.log_mag_weight = log_mag_weight
+ self.sc_weight = sc_weight
+ self.wsc_weight = wsc_weight
+ self.smooth_log_mag_weight = smooth_log_mag_weight
+ self.sxcorr_weight = sxcorr_weight
+ self.fs = fs
+
+ # weights for SFM weighted spectral convergence loss
+ self.wsc_weights = torch.nn.ParameterDict()
+ for fft_size in fft_sizes:
+ width = min(11, int(1000 * fft_size / self.fs + .5))
+ width += width % 2
+ self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
+ create_weight_matrix(fft_size // 2 + 1, width),
+ requires_grad=False
+ )
+
+ # filterbanks for smooth log magnitude loss
+ self.filterbanks = torch.nn.ParameterDict()
+ for fft_size in fft_sizes:
+ self.filterbanks[str(fft_size)] = torch.nn.Parameter(
+ gen_filterbank(fft_size//2),
+ requires_grad=False
+ )
+
+
+ def __call__(self, y_true, y_pred):
+
+
+ lm_loss = torch.zeros(1, device=y_true.device)
+ sc_loss = torch.zeros(1, device=y_true.device)
+ wsc_loss = torch.zeros(1, device=y_true.device)
+ slm_loss = torch.zeros(1, device=y_true.device)
+ sxcorr_loss = torch.zeros(1, device=y_true.device)
+
+ for fft_size in self.fft_sizes:
+ hop_size = int(round(fft_size * (1 - self.overlap)))
+ win_size = fft_size
+
+ Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
+ Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
+
+ if self.log_mag_weight > 0:
+ lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
+
+ if self.sc_weight > 0:
+ sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
+
+ if self.wsc_weight > 0:
+ wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
+
+ if self.smooth_log_mag_weight > 0:
+ slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
+
+ if self.sxcorr_weight > 0:
+ sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
+
+
+ total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
+ + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
+ + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
+
+ return total_loss \ No newline at end of file
diff --git a/dnn/torch/osce/make_default_setup.py b/dnn/torch/osce/make_default_setup.py
new file mode 100644
index 00000000..06add8fa
--- /dev/null
+++ b/dnn/torch/osce/make_default_setup.py
@@ -0,0 +1,56 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import argparse
+
+import yaml
+
+from utils.templates import setup_dict
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('name', type=str, help='name of default setup file')
+parser.add_argument('--model', choices=['lace'], help='model name', default='lace')
+parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
+
+args = parser.parse_args()
+
+setup = setup_dict[args.model]
+
+# update dataset if given
+if type(args.path2dataset) != type(None):
+ setup['dataset'] = args.path2dataset
+
+name = args.name
+if not name.endswith('.yml'):
+ name += '.yml'
+
+if __name__ == '__main__':
+ with open(name, 'w') as f:
+ f.write(yaml.dump(setup)) \ No newline at end of file
diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py
new file mode 100644
index 00000000..c8dfc5d9
--- /dev/null
+++ b/dnn/torch/osce/models/__init__.py
@@ -0,0 +1,36 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+from .lace import LACE
+
+
+
+model_dict = {
+ 'lace': LACE
+}
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
new file mode 100644
index 00000000..a11dfc41
--- /dev/null
+++ b/dnn/torch/osce/models/lace.py
@@ -0,0 +1,176 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+class LACE(NNSBase):
+ """ Linear-Adaptive Coding Enhancer """
+ FRAME_SIZE=80
+
+ def __init__(self,
+ num_features=47,
+ pitch_embedding_dim=64,
+ cond_dim=256,
+ pitch_max=257,
+ kernel_size=15,
+ preemph=0.85,
+ skip=91,
+ comb_gain_limit_db=-6,
+ global_gain_limits_db=[-6, 6],
+ conv_gain_limits_db=[-6, 6],
+ numbits_range=[50, 650],
+ numbits_embedding_dim=8,
+ hidden_feature_dim=64,
+ partial_lookahead=True,
+ norm_p=2):
+
+ super().__init__(skip=skip, preemph=preemph)
+
+
+ self.num_features = num_features
+ self.cond_dim = cond_dim
+ self.pitch_max = pitch_max
+ self.pitch_embedding_dim = pitch_embedding_dim
+ self.kernel_size = kernel_size
+ self.preemph = preemph
+ self.skip = skip
+ self.numbits_range = numbits_range
+ self.numbits_embedding_dim = numbits_embedding_dim
+ self.hidden_feature_dim = hidden_feature_dim
+ self.partial_lookahead = partial_lookahead
+
+ # pitch embedding
+ self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+ # numbits embedding
+ self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+ # feature net
+ if partial_lookahead:
+ self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+ else:
+ self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+ # comb filters
+ left_pad = self.kernel_size // 2
+ right_pad = self.kernel_size - 1 - left_pad
+ self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+ self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+
+ # spectral shaping
+ self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+ def flop_count(self, rate=16000, verbose=False):
+
+ frame_rate = rate / self.FRAME_SIZE
+
+ # feature net
+ feature_net_flops = self.feature_net.flop_count(frame_rate)
+ comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
+ af_flops = self.af1.flop_count(rate)
+
+ if verbose:
+ print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+ print(f"comb filters: {comb_flops / 1e6} MFLOPS")
+ print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+ return feature_net_flops + comb_flops + af_flops
+
+ def forward(self, x, features, periods, numbits, debug=False):
+
+ periods = periods.squeeze(-1)
+ pitch_embedding = self.pitch_embedding(periods)
+ numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+ full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+ cf = self.feature_net(full_features)
+
+ y = self.cf1(x, cf, periods, debug=debug)
+
+ y = self.cf2(y, cf, periods, debug=debug)
+
+ y = self.af1(y, cf, debug=debug)
+
+ return y
+
+ def get_impulse_responses(self, features, periods, numbits):
+ """ generates impoulse responses on frame centers (input without batch dimension) """
+
+ num_frames = features.size(0)
+ batch_size = 32
+ max_len = 2 * (self.pitch_max + self.kernel_size) + 10
+
+ # spread out some pulses
+ x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
+ for b in range(batch_size):
+ x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1
+
+ # prepare input
+ x = torch.from_numpy(x).float().to(features.device)
+ features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
+ periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
+ numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)
+
+ # run network
+ with torch.no_grad():
+ periods = periods.squeeze(-1)
+ pitch_embedding = self.pitch_embedding(periods)
+ numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+ full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+ cf = self.feature_net(full_features)
+ y = self.cf1(x, cf, periods, debug=False)
+ y = self.cf2(y, cf, periods, debug=False)
+ y = self.af1(y, cf, debug=False)
+
+ # collect responses
+ y = y.detach().squeeze().cpu().numpy()
+ cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
+ num_responses = num_frames - cut_frames
+ responses = np.zeros((num_responses, max_len))
+
+ for i in range(num_responses):
+ b = i % batch_size
+ start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
+ stop = start + max_len
+
+ responses[i, :] = y[b, start:stop]
+
+ return responses
diff --git a/dnn/torch/osce/models/nns_base.py b/dnn/torch/osce/models/nns_base.py
new file mode 100644
index 00000000..6e667b96
--- /dev/null
+++ b/dnn/torch/osce/models/nns_base.py
@@ -0,0 +1,69 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class NNSBase(nn.Module):
+
+ def __init__(self, skip=91, preemph=0.85):
+ super().__init__()
+
+ self.skip = skip
+ self.preemph = preemph
+
+ def process(self, sig, features, periods, numbits, debug=False):
+
+ self.eval()
+ has_numbits = 'numbits' in self.forward.__code__.co_varnames
+ device = next(iter(self.parameters())).device
+ with torch.no_grad():
+
+ # run model
+ x = sig.view(1, 1, -1).to(device)
+ f = features.unsqueeze(0).to(device)
+ p = periods.unsqueeze(0).to(device)
+ n = numbits.unsqueeze(0).to(device)
+
+ if has_numbits:
+ y = self.forward(x, f, p, n, debug=debug).squeeze()
+ else:
+ y = self.forward(x, f, p, debug=debug).squeeze()
+
+ # deemphasis
+ if self.preemph > 0:
+ for i in range(len(y) - 1):
+ y[i + 1] += self.preemph * y[i]
+
+ # delay compensation
+ y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
+ out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
+
+ return out \ No newline at end of file
diff --git a/dnn/torch/osce/models/scale_embedding.py b/dnn/torch/osce/models/scale_embedding.py
new file mode 100644
index 00000000..58695302
--- /dev/null
+++ b/dnn/torch/osce/models/scale_embedding.py
@@ -0,0 +1,68 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import math as m
+import torch
+from torch import nn
+
+
+class ScaleEmbedding(nn.Module):
+ def __init__(self,
+ dim,
+ min_val,
+ max_val,
+ logscale=False):
+
+ super().__init__()
+
+ if min_val >= max_val:
+ raise ValueError('min_val must be smaller than max_val')
+
+ if min_val <= 0 and logscale:
+ raise ValueError('min_val must be positive when logscale is true')
+
+ self.dim = dim
+ self.logscale = logscale
+ self.min_val = min_val
+ self.max_val = max_val
+
+ if logscale:
+ self.min_val = m.log(self.min_val)
+ self.max_val = m.log(self.max_val)
+
+
+ self.offset = (self.min_val + self.max_val) / 2
+ self.scale_factors = nn.Parameter(
+ torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
+ )
+
+ def forward(self, x):
+ if self.logscale: x = torch.log(x)
+ x = torch.clip(x, self.min_val, self.max_val) - self.offset
+ return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5)
diff --git a/dnn/torch/osce/models/silk_feature_net.py b/dnn/torch/osce/models/silk_feature_net.py
new file mode 100644
index 00000000..ed22f52e
--- /dev/null
+++ b/dnn/torch/osce/models/silk_feature_net.py
@@ -0,0 +1,86 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class SilkFeatureNet(nn.Module):
+
+ def __init__(self,
+ feature_dim=47,
+ num_channels=256,
+ lookahead=False):
+
+ super(SilkFeatureNet, self).__init__()
+
+ self.feature_dim = feature_dim
+ self.num_channels = num_channels
+ self.lookahead = lookahead
+
+ self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
+ self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
+
+ self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+ def flop_count(self, rate=200):
+ count = 0
+ for conv in self.conv1, self.conv2:
+ count += _conv1d_flop_count(conv, rate)
+
+ count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
+
+ return count
+
+
+ def forward(self, features, state=None):
+ """ features shape: (batch_size, num_frames, feature_dim) """
+
+ batch_size = features.size(0)
+
+ if state is None:
+ state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
+
+
+ features = features.permute(0, 2, 1)
+ if self.lookahead:
+ c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
+ c = torch.tanh(self.conv2(F.pad(c, [2, 2])))
+ else:
+ c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
+ c = torch.tanh(self.conv2(F.pad(c, [4, 0])))
+
+ c = c.permute(0, 2, 1)
+
+ c, _ = self.gru(c, state)
+
+ return c \ No newline at end of file
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
new file mode 100644
index 00000000..ae37951c
--- /dev/null
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -0,0 +1,90 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class SilkFeatureNetPL(nn.Module):
+ """ feature net with partial lookahead """
+ def __init__(self,
+ feature_dim=47,
+ num_channels=256,
+ hidden_feature_dim=64):
+
+ super(SilkFeatureNetPL, self).__init__()
+
+ self.feature_dim = feature_dim
+ self.num_channels = num_channels
+ self.hidden_feature_dim = hidden_feature_dim
+
+ self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
+ self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
+ self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
+
+ self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
+
+ def flop_count(self, rate=200):
+ count = 0
+ for conv in self.conv1, self.conv2, self.tconv:
+ count += _conv1d_flop_count(conv, rate)
+
+ count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
+
+ return count
+
+
+ def forward(self, features, state=None):
+ """ features shape: (batch_size, num_frames, feature_dim) """
+
+ batch_size = features.size(0)
+ num_frames = features.size(1)
+
+ if state is None:
+ state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
+
+ features = features.permute(0, 2, 1)
+ # dimensionality reduction
+ c = torch.tanh(self.conv1(features))
+
+ # frame accumulation
+ c = c.permute(0, 2, 1)
+ c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
+ c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
+
+ # upsampling
+ c = self.tconv(c)
+ c = c.permute(0, 2, 1)
+
+ c, _ = self.gru(c, state)
+
+ return c \ No newline at end of file
diff --git a/dnn/torch/osce/test_model.py b/dnn/torch/osce/test_model.py
new file mode 100644
index 00000000..616a0ec5
--- /dev/null
+++ b/dnn/torch/osce/test_model.py
@@ -0,0 +1,96 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import argparse
+
+import torch
+
+from scipy.io import wavfile
+
+
+from models import model_dict
+from utils.silk_features import load_inference_data
+from utils import endoscopy
+
+debug = False
+if debug:
+ args = type('dummy', (object,),
+ {
+ 'input' : 'testitems/all_0_orig.se',
+ 'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
+ 'output' : 'out.wav',
+ })()
+else:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('input', type=str, help='path to folder with features and signals')
+ parser.add_argument('checkpoint', type=str, help='checkpoint file')
+ parser.add_argument('output', type=str, help='output file')
+ parser.add_argument('--debug', action='store_true', help='enables debug output')
+
+
+ args = parser.parse_args()
+
+
+torch.set_num_threads(2)
+
+input_folder = args.input
+checkpoint_file = args.checkpoint
+
+
+output_file = args.output
+if not output_file.endswith('.wav'):
+ output_file += '.wav'
+
+checkpoint = torch.load(checkpoint_file, map_location="cpu")
+
+# check model
+if not 'name' in checkpoint['setup']['model']:
+ print(f'warning: did not find model name entry in setup, using pitchpostfilter per default')
+ model_name = 'pitchpostfilter'
+else:
+ model_name = checkpoint['setup']['model']['name']
+
+model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+
+model.load_state_dict(checkpoint['state_dict'])
+
+# generate model input
+setup = checkpoint['setup']
+signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])
+
+if args.debug:
+ endoscopy.init()
+
+output = model.process(signal, features, periods, numbits, debug=args.debug)
+
+wavfile.write(output_file, 16000, output.cpu().numpy())
+
+if args.debug:
+ endoscopy.close()
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py
new file mode 100644
index 00000000..6e2514b9
--- /dev/null
+++ b/dnn/torch/osce/train_model.py
@@ -0,0 +1,297 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+import yaml
+
+try:
+ import git
+ has_git = True
+except:
+ has_git = False
+
+import torch
+from torch.optim.lr_scheduler import LambdaLR
+
+import numpy as np
+
+from scipy.io import wavfile
+
+import pesq
+
+from data import SilkEnhancementSet
+from models import model_dict
+from engine.engine import train_one_epoch, evaluate
+
+
+from utils.silk_features import load_inference_data
+from utils.misc import count_parameters
+
+from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('setup', type=str, help='setup yaml file')
+parser.add_argument('output', type=str, help='output path')
+parser.add_argument('--device', type=str, help='compute device', default=None)
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
+parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
+
+args = parser.parse_args()
+
+
+torch.set_num_threads(4)
+
+with open(args.setup, 'r') as f:
+ setup = yaml.load(f.read(), yaml.FullLoader)
+
+checkpoint_prefix = 'checkpoint'
+output_prefix = 'output'
+setup_name = 'setup.yml'
+output_file='out.txt'
+
+
+# check model
+if not 'name' in setup['model']:
+ print(f'warning: did not find model entry in setup, using default PitchPostFilter')
+ model_name = 'pitchpostfilter'
+else:
+ model_name = setup['model']['name']
+
+# prepare output folder
+if os.path.exists(args.output):
+ print("warning: output folder exists")
+
+ reply = input('continue? (y/n): ')
+ while reply not in {'y', 'n'}:
+ reply = input('continue? (y/n): ')
+
+ if reply == 'n':
+ os._exit()
+else:
+ os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# add repo info to setup
+if has_git:
+ working_dir = os.path.split(__file__)[0]
+ try:
+ repo = git.Repo(working_dir)
+ setup['repo'] = dict()
+ hash = repo.head.object.hexsha
+ urls = list(repo.remote().urls)
+ is_dirty = repo.is_dirty()
+
+ if is_dirty:
+ print("warning: repo is dirty")
+
+ setup['repo']['hash'] = hash
+ setup['repo']['urls'] = urls
+ setup['repo']['dirty'] = is_dirty
+ except:
+ has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+ yaml.dump(setup, f)
+
+ref = None
+if args.testdata is not None:
+
+ testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
+
+ inference_test = True
+ inference_folder = os.path.join(args.output, 'inference_test')
+ os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
+
+ try:
+ ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
+ except:
+ pass
+else:
+ inference_test = False
+
+# training parameters
+batch_size = setup['training']['batch_size']
+epochs = setup['training']['epochs']
+lr = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+
+# load training dataset
+data_config = setup['data']
+data = SilkEnhancementSet(setup['dataset'], **data_config)
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+ validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
+
+ validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
+
+ run_validation = True
+else:
+ run_validation = False
+
+# create model
+model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
+
+if args.initial_checkpoint is not None:
+ print(f"loading state dict from {args.initial_checkpoint}...")
+ chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
+ model.load_state_dict(chkpt['state_dict'])
+
+# set compute device
+if type(args.device) == type(None):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+ device = torch.device(args.device)
+
+# push model to device
+model.to(device)
+
+# dataloader
+dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
+
+# optimizer is introduced to trainable parameters
+parameters = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(parameters, lr=lr)
+
+# learning rate scheduler
+scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+# loss
+w_l1 = setup['training']['loss']['w_l1']
+w_lm = setup['training']['loss']['w_lm']
+w_slm = setup['training']['loss']['w_slm']
+w_sc = setup['training']['loss']['w_sc']
+w_logmel = setup['training']['loss']['w_logmel']
+w_wsc = setup['training']['loss']['w_wsc']
+w_xcorr = setup['training']['loss']['w_xcorr']
+w_sxcorr = setup['training']['loss']['w_sxcorr']
+w_l2 = setup['training']['loss']['w_l2']
+
+w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
+
+stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
+logmelloss = MRLogMelLoss().to(device)
+
+def xcorr_loss(y_true, y_pred):
+ dims = list(range(1, len(y_true.shape)))
+
+ loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
+
+ return torch.mean(loss)
+
+def td_l2_norm(y_true, y_pred):
+ dims = list(range(1, len(y_true.shape)))
+
+ loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
+
+ return loss.mean()
+
+def td_l1(y_true, y_pred, pow=0):
+ dims = list(range(1, len(y_true.shape)))
+ tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
+
+ return torch.mean(tmp)
+
+def criterion(x, y):
+
+ return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
+
+
+
+# model checkpoint
+checkpoint = {
+ 'setup' : setup,
+ 'state_dict' : model.state_dict(),
+ 'loss' : -1
+}
+
+
+
+
+if not args.no_redirect:
+ print(f"re-directing output to {os.path.join(args.output, output_file)}")
+ sys.stdout = open(os.path.join(args.output, output_file), "w")
+
+print("summary:")
+
+print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
+if hasattr(model, 'flop_count'):
+ print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
+
+if ref is not None:
+ noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
+ initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
+ print(f"initial MOS (PESQ): {initial_mos}")
+
+best_loss = 1e9
+
+for ep in range(1, epochs + 1):
+ print(f"training epoch {ep}...")
+ new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
+
+
+ # save checkpoint
+ checkpoint['state_dict'] = model.state_dict()
+ checkpoint['loss'] = new_loss
+
+ if run_validation:
+ print("running validation...")
+ validation_loss = evaluate(model, criterion, validation_dataloader, device)
+ checkpoint['validation_loss'] = validation_loss
+
+ if validation_loss < best_loss:
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
+ best_loss = validation_loss
+
+ if inference_test:
+ print("running inference test...")
+ out = model.process(testsignal, features, periods, numbits).cpu().numpy()
+ wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
+ if ref is not None:
+ mos = pesq.pesq(16000, ref, out, mode='wb')
+ print(f"MOS (PESQ): {mos}")
+
+
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
+
+
+ print()
+
+print('Done')
diff --git a/dnn/torch/osce/utils/complexity.py b/dnn/torch/osce/utils/complexity.py
new file mode 100644
index 00000000..79de22c5
--- /dev/null
+++ b/dnn/torch/osce/utils/complexity.py
@@ -0,0 +1,35 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+def _conv1d_flop_count(layer, rate):
+ return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
+
+
+def _dense_flop_count(layer, rate):
+ return 2 * ((layer.in_features + 1) * layer.out_features * rate ) \ No newline at end of file
diff --git a/dnn/torch/osce/utils/endoscopy.py b/dnn/torch/osce/utils/endoscopy.py
new file mode 100644
index 00000000..141447e2
--- /dev/null
+++ b/dnn/torch/osce/utils/endoscopy.py
@@ -0,0 +1,234 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+""" module for inspecting models during inference """
+
+import os
+
+import yaml
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+import torch
+import numpy as np
+
+# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
+_state = dict()
+_folder = 'endoscopy'
+
+def get_gru_gates(gru, input, state):
+ hidden_size = gru.hidden_size
+
+ direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
+ recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
+
+ # reset gate
+ start, stop = 0 * hidden_size, 1 * hidden_size
+ reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # update gate
+ start, stop = 1 * hidden_size, 2 * hidden_size
+ update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # new gate
+ start, stop = 2 * hidden_size, 3 * hidden_size
+ new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
+
+ return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
+
+
+def init(folder='endoscopy'):
+ """ sets up output folder for endoscopy data """
+
+ global _folder
+ _folder = folder
+
+ if not os.path.exists(folder):
+ os.makedirs(folder)
+ else:
+ print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
+
+def write_data(key, data, fs):
+ """ appends data to previous data written under key """
+
+ global _state
+
+ # convert to numpy if torch.Tensor is given
+ if isinstance(data, torch.Tensor):
+ data = data.detach().numpy()
+
+ if not key in _state:
+ _state[key] = {
+ 'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
+ 'fs' : fs,
+ 'dim' : tuple(data.shape),
+ 'dtype' : str(data.dtype)
+ }
+
+ with open(os.path.join(_folder, key + '.yml'), 'w') as f:
+ f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
+ else:
+ if _state[key]['fs'] != fs:
+ raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
+ if _state[key]['dtype'] != str(data.dtype):
+ raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
+ if _state[key]['dim'] != tuple(data.shape):
+ raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
+
+ _state[key]['fid'].write(data.tobytes())
+
+def close(folder='endoscopy'):
+ """ clean up """
+ for key in _state.keys():
+ _state[key]['fid'].close()
+
+
+def read_data(folder='endoscopy'):
+ """ retrieves written data as numpy arrays """
+
+
+ keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
+
+ return_dict = dict()
+
+ for key in keys:
+ with open(os.path.join(folder, key + '.yml'), 'r') as f:
+ value = yaml.load(f.read(), yaml.FullLoader)
+
+ with open(os.path.join(folder, key + '.bin'), 'rb') as f:
+ data = np.frombuffer(f.read(), dtype=value['dtype'])
+
+ value['data'] = data.reshape((-1,) + value['dim'])
+
+ return_dict[key] = value
+
+ return return_dict
+
+def get_best_reshape(shape, target_ratio=1):
+ """ calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
+
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return (1,)
+
+ num_columns = int((pixel_count / target_ratio)**.5)
+
+ while (pixel_count % num_columns):
+ num_columns -= 1
+
+ num_rows = pixel_count // num_columns
+
+ return (num_rows, num_columns)
+
+def get_type_and_shape(shape):
+
+ # can happen if data is one dimensional
+ if len(shape) == 0:
+ shape = (1,)
+
+ # calculate pixel count
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return 'plot', (1, )
+
+ # stay with shape if already 2-dimensional
+ if len(shape) == 2:
+ if (shape[0] != pixel_count) or (shape[1] != pixel_count):
+ return 'image', shape
+
+ return 'image', get_best_reshape(shape)
+
+def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
+
+ # determine plot setup
+ num_keys = len(data.keys())
+
+ num_rows = int((num_keys * 3/4) ** .5)
+
+ num_cols = (num_keys + num_rows - 1) // num_rows
+
+ fig, axs = plt.subplots(num_rows, num_cols)
+ fig.set_size_inches(num_cols * 5, num_rows * 5)
+
+ display = dict()
+
+ fs_max = max([val['fs'] for val in data.values()])
+
+ num_samples = max([val['data'].shape[0] for val in data.values()])
+
+ keys = sorted(data.keys())
+
+ # inspect data
+ for i, key in enumerate(keys):
+ axs[i // num_cols, i % num_cols].title.set_text(key)
+
+ display[key] = dict()
+
+ display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
+ display[key]['down_factor'] = data[key]['fs'] / fs_max
+
+ start_index = max(start_index, half_signal_window_length)
+ while stop_index < 0:
+ stop_index += num_samples
+
+ stop_index = min(stop_index, num_samples - half_signal_window_length)
+
+ # actual plotting
+ frames = []
+ for index in range(start_index, stop_index):
+ ims = []
+ for i, key in enumerate(keys):
+ feature_index = int(round(index * display[key]['down_factor']))
+
+ if display[key]['type'] == 'plot':
+ ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
+
+ elif display[key]['type'] == 'image':
+ ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
+
+ frames.append(ims)
+
+ ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
+
+ if not filename.endswith('.mp4'):
+ filename += '.mp4'
+
+ ani.save(filename) \ No newline at end of file
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
new file mode 100644
index 00000000..b146240e
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -0,0 +1,236 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveComb1d(nn.Module):
+ COUNTER = 1
+
+ def __init__(self,
+ kernel_size,
+ feature_dim,
+ frame_size=160,
+ overlap_size=40,
+ use_bias=True,
+ padding=None,
+ max_lag=256,
+ name=None,
+ gain_limit_db=10,
+ global_gain_limits_db=[-6, 6],
+ norm_p=2):
+ """
+
+ Parameters:
+ -----------
+
+ feature_dim : int
+ dimension of features from which kernels, biases and gains are computed
+
+ frame_size : int, optional
+ frame size, defaults to 160
+
+ overlap_size : int, optional
+ overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+ use_bias : bool, optional
+ if true, biases will be added to output channels. Defaults to True
+
+ padding : List[int, int], optional
+ left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
+
+ max_lag : int, optional
+ maximal pitch lag, defaults to 256
+
+ have_a0 : bool, optional
+ If true, the filter coefficient a0 will be learned as a positive gain (requires in_channels == out_channels). Otherwise, a0 is set to 0. Defaults to False
+
+ name: str or None, optional
+ specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
+
+ """
+
+ super(LimitedAdaptiveComb1d, self).__init__()
+
+ self.in_channels = 1
+ self.out_channels = 1
+ self.feature_dim = feature_dim
+ self.kernel_size = kernel_size
+ self.frame_size = frame_size
+ self.overlap_size = overlap_size
+ self.use_bias = use_bias
+ self.max_lag = max_lag
+ self.limit_db = gain_limit_db
+ self.norm_p = norm_p
+
+ if name is None:
+ self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
+ LimitedAdaptiveComb1d.COUNTER += 1
+ else:
+ self.name = name
+
+ # network for generating convolution weights
+ self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+
+ if self.use_bias:
+ self.conv_bias = nn.Linear(feature_dim,1)
+
+ # comb filter gain
+ self.filter_gain = nn.Linear(feature_dim, 1)
+ self.log_gain_limit = gain_limit_db * 0.11512925464970229
+ with torch.no_grad():
+ self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
+
+ self.global_filter_gain = nn.Linear(feature_dim, 1)
+ log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
+ self.filter_gain_a = (log_max - log_min) / 2
+ self.filter_gain_b = (log_max + log_min) / 2
+
+ if type(padding) == type(None):
+ self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+ else:
+ self.padding = padding
+
+ self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+ def forward(self, x, features, lags, debug=False):
+ """ adaptive 1d convolution
+
+
+ Parameters:
+ -----------
+ x : torch.tensor
+ input signal of shape (batch_size, in_channels, num_samples)
+
+ feathres : torch.tensor
+ frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+ lags: torch.LongTensor
+ frame-wise lags for comb-filtering
+
+ """
+
+ batch_size = x.size(0)
+ num_frames = features.size(1)
+ num_samples = x.size(2)
+ frame_size = self.frame_size
+ overlap_size = self.overlap_size
+ kernel_size = self.kernel_size
+ win1 = torch.flip(self.overlap_win, [0])
+ win2 = self.overlap_win
+
+ if num_samples // self.frame_size != num_frames:
+ raise ValueError('non matching sizes in AdaptiveConv1d.forward')
+
+ conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+ conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
+
+ if self.use_bias:
+ conv_biases = self.conv_bias(features).permute(0, 2, 1)
+
+ conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
+ # calculate gains
+ global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+
+ if debug and batch_size == 1:
+ key = self.name + "_gains"
+ write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_kernels"
+ write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_lags"
+ write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_global_conv_gains"
+ write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+ # frame-wise convolution with overlap-add
+ output_frames = []
+ overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+ x = F.pad(x, self.padding)
+ x = F.pad(x, [self.max_lag, self.overlap_size])
+
+ idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
+ idx = torch.repeat_interleave(idx, batch_size, 0)
+ idx = torch.repeat_interleave(idx, self.in_channels, 1)
+
+
+ for i in range(num_frames):
+
+ cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
+ xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
+
+ new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+
+ if self.use_bias:
+ new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+ offset = self.max_lag + self.padding[0]
+ new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
+
+ # overlapping part
+ output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+ # non-overlapping part
+ output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+ # mem for next frame
+ overlap_mem = new_chunk[:, :, frame_size :]
+
+ # concatenate chunks
+ output = torch.cat(output_frames, dim=-1)
+
+ return output
+
+ def flop_count(self, rate):
+ frame_rate = rate / self.frame_size
+ overlap = self.overlap_size
+ overhead = overlap / self.frame_size
+
+ count = 0
+
+ # kernel computation and filtering
+ count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+ count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+ count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+ # bias computation
+ if self.use_bias:
+ count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+ # a0 computation
+ count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+ # windowing
+ count += overlap * frame_rate * 3 * self.out_channels
+
+ return count
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
new file mode 100644
index 00000000..5992296f
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -0,0 +1,222 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveConv1d(nn.Module):
+ COUNTER = 1
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ feature_dim,
+ frame_size=160,
+ overlap_size=40,
+ use_bias=True,
+ padding=None,
+ name=None,
+ gain_limits_db=[-6, 6],
+ shape_gain_db=0,
+ norm_p=2):
+ """
+
+ Parameters:
+ -----------
+
+ in_channels : int
+ number of input channels
+
+ out_channels : int
+ number of output channels
+
+ feature_dim : int
+ dimension of features from which kernels, biases and gains are computed
+
+ frame_size : int
+ frame size
+
+ overlap_size : int
+ overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame
+
+ use_bias : bool
+ if true, biases will be added to output channels
+
+
+ padding : List[int, int]
+
+ """
+
+ super(LimitedAdaptiveConv1d, self).__init__()
+
+
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.feature_dim = feature_dim
+ self.kernel_size = kernel_size
+ self.frame_size = frame_size
+ self.overlap_size = overlap_size
+ self.use_bias = use_bias
+ self.gain_limits_db = gain_limits_db
+ self.shape_gain_db = shape_gain_db
+ self.norm_p = norm_p
+
+ if name is None:
+ self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
+ LimitedAdaptiveConv1d.COUNTER += 1
+ else:
+ self.name = name
+
+ # network for generating convolution weights
+ self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
+
+ if self.use_bias:
+ self.conv_bias = nn.Linear(feature_dim, out_channels)
+
+ self.shape_gain = min(1, 10**(shape_gain_db / 20))
+
+ self.filter_gain = nn.Linear(feature_dim, out_channels)
+ log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
+ self.filter_gain_a = (log_max - log_min) / 2
+ self.filter_gain_b = (log_max + log_min) / 2
+
+ if type(padding) == type(None):
+ self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+ else:
+ self.padding = padding
+
+ self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+
+ def flop_count(self, rate):
+ frame_rate = rate / self.frame_size
+ overlap = self.overlap_size
+ overhead = overlap / self.frame_size
+
+ count = 0
+
+ # kernel computation and filtering
+ count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+ count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+
+ # bias computation
+ if self.use_bias:
+ count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+ # gain computation
+
+ count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+ # windowing
+ count += 3 * overlap * frame_rate * self.out_channels
+
+ return count
+
+ def forward(self, x, features, debug=False):
+ """ adaptive 1d convolution
+
+
+ Parameters:
+ -----------
+ x : torch.tensor
+ input signal of shape (batch_size, in_channels, num_samples)
+
+ feathres : torch.tensor
+ frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+ """
+
+ batch_size = x.size(0)
+ num_frames = features.size(1)
+ num_samples = x.size(2)
+ frame_size = self.frame_size
+ overlap_size = self.overlap_size
+ kernel_size = self.kernel_size
+ win1 = torch.flip(self.overlap_win, [0])
+ win2 = self.overlap_win
+
+ if num_samples // self.frame_size != num_frames:
+ raise ValueError('non matching sizes in AdaptiveConv1d.forward')
+
+ conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+
+ # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
+ conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))
+
+ # limit shape
+ id_kernels = torch.zeros_like(conv_kernels)
+ id_kernels[..., self.padding[1]] = 1
+
+ conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
+
+ if self.use_bias:
+ conv_biases = self.conv_bias(features).permute(0, 2, 1)
+
+ # calculate gains
+ conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+ if debug and batch_size == 1:
+ key = self.name + "_gains"
+ write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_kernels"
+ write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+ # frame-wise convolution with overlap-add
+ output_frames = []
+ overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+ x = F.pad(x, self.padding)
+ x = F.pad(x, [0, self.overlap_size])
+
+ for i in range(num_frames):
+ xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
+ new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+ if self.use_bias:
+ new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+ new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
+
+ # overlapping part
+ output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+ # non-overlapping part
+ output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+ # mem for next frame
+ overlap_mem = new_chunk[:, :, frame_size :]
+
+ # concatenate chunks
+ output = torch.cat(output_frames, dim=-1)
+
+ return output \ No newline at end of file
diff --git a/dnn/torch/osce/utils/layers/pitch_auto_correlator.py b/dnn/torch/osce/utils/layers/pitch_auto_correlator.py
new file mode 100644
index 00000000..ef58ae8e
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/pitch_auto_correlator.py
@@ -0,0 +1,84 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PitchAutoCorrelator(nn.Module):
+ def __init__(self,
+ frame_size=80,
+ pitch_min=32,
+ pitch_max=300,
+ radius=2):
+
+ super().__init__()
+
+ self.frame_size = frame_size
+ self.pitch_min = pitch_min
+ self.pitch_max = pitch_max
+ self.radius = radius
+
+
+ def forward(self, x, periods):
+ # x of shape (batch_size, channels, num_samples)
+ # periods of shape (batch_size, num_frames)
+
+ num_frames = periods.size(1)
+ batch_size = periods.size(0)
+ num_samples = self.frame_size * num_frames
+ channels = x.size(1)
+
+ assert num_samples == x.size(-1)
+
+ range = torch.arange(-self.radius, self.radius + 1, device=x.device)
+ idx = torch.arange(self.frame_size * num_frames, device=x.device)
+ p_up = torch.repeat_interleave(periods, self.frame_size, 1)
+ lookup = idx + self.pitch_max - p_up
+ lookup = lookup.unsqueeze(-1) + range
+ lookup = lookup.unsqueeze(1)
+
+ # padding
+ x_pad = F.pad(x, [self.pitch_max, 0])
+ x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)
+
+ # framing
+ x_select = torch.gather(x_ext, 2, lookup)
+ x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
+ lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)
+
+ # calculate auto-correlation
+ dotp = torch.sum(x_frames * lag_frames, dim=-2)
+ frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
+ lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)
+
+ acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)
+
+ return acorr
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
new file mode 100644
index 00000000..d4c03478
--- /dev/null
+++ b/dnn/torch/osce/utils/misc.py
@@ -0,0 +1,42 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+def count_parameters(model, verbose=False):
+ total = 0
+ for name, p in model.named_parameters():
+ count = torch.ones_like(p).sum().item()
+
+ if verbose:
+ print(f"{name}: {count} parameters")
+
+ total += count
+
+ return total \ No newline at end of file
diff --git a/dnn/torch/osce/utils/pitch.py b/dnn/torch/osce/utils/pitch.py
new file mode 100644
index 00000000..32b3bbf8
--- /dev/null
+++ b/dnn/torch/osce/utils/pitch.py
@@ -0,0 +1,121 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+
+def hangover(lags, num_frames=10):
+ lags = lags.copy()
+ count = 0
+ last_lag = 0
+
+ for i in range(len(lags)):
+ lag = lags[i]
+
+ if lag == 0:
+ if count < num_frames:
+ lags[i] = last_lag
+ count += 1
+ else:
+ count = 0
+
+ return lags
+
+
+def smooth_pitch_lags(lags, d=2):
+
+ assert d < 4
+
+ num_silk_frames = len(lags) // 4
+
+ smoothed_lags = lags.copy()
+
+ tmp = np.arange(1, d+1)
+ kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
+ kernel = kernel / np.sum(kernel)
+
+ last = lags[0:d][::-1]
+ for i in range(num_silk_frames):
+ frame = lags[i * 4: (i+1) * 4]
+
+ if np.max(np.abs(frame)) == 0:
+ last = frame[4-d:]
+ continue
+
+ if i == num_silk_frames - 1:
+ next = frame[4-d:][::-1]
+ else:
+ next = lags[(i+1) * 4 : (i+1) * 4 + d]
+
+ if np.max(np.abs(next)) == 0:
+ next = frame[4-d:][::-1]
+
+ if np.max(np.abs(last)) == 0:
+ last = frame[0:d][::-1]
+
+ smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')
+
+ smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)
+
+ last = frame[4-d:]
+
+ return smoothed_lags
+
+def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
+ eps = 1e-9
+
+ lag_multiplier = 2 if add_double_lag_acorr else 1
+
+ if history is None:
+ history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)
+
+ offset = len(history)
+
+ assert offset >= max_lag + radius
+ assert len(x) % frame_size == 0
+
+ num_frames = len(x) // frame_size
+ lags = lags.copy()
+
+ x_ext = np.concatenate((history, x), dtype=x.dtype)
+
+ d = radius
+ num_acorrs = 2 * d + 1
+ acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)
+
+ for idx in range(num_frames):
+ lag = lags[idx].item()
+ frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]
+
+ for k in range(lag_multiplier):
+ lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
+ for j in range(num_acorrs):
+ past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
+ acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)
+
+ return acorrs, lags \ No newline at end of file
diff --git a/dnn/torch/osce/utils/silk_features.py b/dnn/torch/osce/utils/silk_features.py
new file mode 100644
index 00000000..071a6c26
--- /dev/null
+++ b/dnn/torch/osce/utils/silk_features.py
@@ -0,0 +1,151 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import os
+
+import numpy as np
+import torch
+
+import scipy
+
+from utils.pitch import hangover, calculate_acorr_window
+from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
+
+def spec_from_lpc(a, n_fft=128, eps=1e-9):
+ order = a.shape[-1]
+ assert order + 1 < n_fft
+
+ x = np.zeros((*a.shape[:-1], n_fft ))
+ x[..., 0] = 1
+ x[..., 1:1 + order] = -a
+
+ X = np.fft.fft(x, axis=-1)
+ X = np.abs(X[..., :n_fft//2 + 1]) ** 2
+
+ S = 1 / (X + eps)
+
+ return S
+
+def silk_feature_factory(no_pitch_value=256,
+ acorr_radius=2,
+ pitch_hangover=8,
+ num_bands_clean_spec=64,
+ num_bands_noisy_spec=18,
+ noisy_spec_scale='opus',
+ noisy_apply_dct=True,
+ add_offset=False,
+ add_double_lag_acorr=False
+ ):
+
+ w = scipy.signal.windows.cosine(320)
+ fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
+ fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
+
+ def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
+
+ periods = periods.copy()
+
+ if pitch_hangover > 0:
+ periods = hangover(periods, num_frames=pitch_hangover)
+
+ periods[periods == 0] = no_pitch_value
+
+ clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)
+
+ if noisy_apply_dct:
+ noisy_cepstrum = np.repeat(
+ cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
+ else:
+ noisy_cepstrum = np.repeat(
+ log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
+
+ log_gains = np.log(gains + 1e-9).reshape(-1, 1)
+
+ acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
+
+ if add_offset:
+ features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains, offsets.reshape(-1, 1)), axis=-1, dtype=np.float32)
+ else:
+ features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
+
+ return features, periods.astype(np.int64)
+
+ return create_features
+
+
+
+def load_inference_data(path,
+ no_pitch_value=256,
+ skip=92,
+ preemph=0.85,
+ acorr_radius=2,
+ pitch_hangover=8,
+ num_bands_clean_spec=64,
+ num_bands_noisy_spec=18,
+ noisy_spec_scale='opus',
+ noisy_apply_dct=True,
+ add_offset=False,
+ add_double_lag_acorr=False,
+ **kwargs):
+
+ print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
+
+ lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
+ ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
+ gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
+ periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
+ num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
+ num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
+ offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
+
+ # load signal, add back delay and pre-emphasize
+ signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
+ signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
+
+ create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
+
+ num_frames = min((len(signal) // 320) * 4, len(lpcs))
+ signal = signal[: num_frames * 80]
+ lpcs = lpcs[: num_frames]
+ ltps = ltps[: num_frames]
+ gains = gains[: num_frames]
+ periods = periods[: num_frames]
+ num_bits = num_bits[: num_frames // 4]
+ num_bits_smooth = num_bits[: num_frames // 4]
+ offsets = offsets[: num_frames]
+
+ numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
+
+ features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
+
+ if preemph > 0:
+ signal[1:] -= preemph * signal[:-1]
+
+ return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
diff --git a/dnn/torch/osce/utils/spec.py b/dnn/torch/osce/utils/spec.py
new file mode 100644
index 00000000..7e41d84e
--- /dev/null
+++ b/dnn/torch/osce/utils/spec.py
@@ -0,0 +1,194 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import math as m
+import numpy as np
+import scipy
+
+def erb(f):
+ return 24.7 * (4.37 * f + 1)
+
+def inv_erb(e):
+ return (e / 24.7 - 1) / 4.37
+
+def bark(f):
+ return 6 * m.asinh(f/600)
+
+def inv_bark(b):
+ return 600 * m.sinh(b / 6)
+
+
+scale_dict = {
+ 'bark': [bark, inv_bark],
+ 'erb': [erb, inv_erb]
+}
+
+def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
+
+ f0 = 0
+ num_bins = n_fft // 2 + 1
+ f1 = fs / n_fft * (num_bins - 1)
+ fstep = fs / n_fft
+
+ if scale == 'opus':
+ bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
+ fac = 1000 * n_fft / fs / 5
+ if num_bands != 18:
+ print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
+ num_bands = 18
+ center_bins = np.array([fac * bin for bin in bins_5ms])
+ else:
+ to_scale, from_scale = scale_dict[scale]
+
+ s0 = to_scale(f0)
+ s1 = to_scale(f1)
+
+ center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
+ center_bins = (center_freqs - f0) / fstep
+
+ if round_center_bins:
+ center_bins = np.round(center_bins)
+
+ filter_bank = np.zeros((num_bands, num_bins))
+
+ band = 0
+ for bin in range(num_bins):
+ # update band index
+ if bin > center_bins[band + 1]:
+ band += 1
+
+ # calculate filter coefficients
+ frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
+ filter_bank[band][bin] = frac
+ filter_bank[band + 1][bin] = 1 - frac
+
+ if return_upper:
+ extend = n_fft - num_bins
+ filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)
+
+ if normalize:
+ filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
+
+ return filter_bank
+
+
+def compressed_log_spec(pspec):
+
+ lpspec = np.zeros_like(pspec)
+ num_bands = pspec.shape[-1]
+
+ log_max = -2
+ follow = -2
+
+ for i in range(num_bands):
+ tmp = np.log10(pspec[i] + 1e-9)
+ tmp = max(log_max, max(follow - 2.5, tmp))
+ lpspec[i] = tmp
+ log_max = max(log_max, tmp)
+ follow = max(follow - 2.5, tmp)
+
+ return lpspec
+
+def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
+ """ calculates cepstrum from SILK lpcs """
+ order = a.shape[-1]
+ assert order + 1 < n_fft
+
+ a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)
+
+ x = np.zeros((*a.shape[:-1], n_fft ))
+ x[..., 0] = 1
+ x[..., 1:1 + order] = -a
+
+ X = np.fft.fft(x, axis=-1)
+ X = np.abs(X[..., :n_fft//2 + 1]) ** power
+
+ S = 1 / (X + eps)
+
+ if fb is None:
+ Sf = S
+ else:
+ Sf = np.matmul(S, fb.T)
+
+ if compress:
+ Sf = np.apply_along_axis(compressed_log_spec, -1, Sf)
+ else:
+ Sf = np.log(Sf + eps)
+
+ return Sf
+
+def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
+ """ calculates cepstrum from SILK lpcs """
+
+ Sf = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
+
+ cepstrum = scipy.fftpack.dct(Sf, 2, norm='ortho')
+
+ return cepstrum
+
+
+
+def log_spectrum(x, frame_size, fb=None, window=None, power=1):
+ """ calculate cepstrum on 50% overlapping frames """
+
+ assert(2*len(x)) % frame_size == 0
+ assert frame_size % 2 == 0
+
+ n = len(x)
+ num_even = n // frame_size
+ num_odd = (n - frame_size // 2) // frame_size
+ num_bins = frame_size // 2 + 1
+
+ x_even = x[:num_even * frame_size].reshape(-1, frame_size)
+ x_odd = x[frame_size//2 : frame_size//2 + frame_size * num_odd].reshape(-1, frame_size)
+
+ x_unfold = np.empty((x_even.size + x_odd.size), dtype=x.dtype).reshape((-1, frame_size))
+ x_unfold[::2, :] = x_even
+ x_unfold[1::2, :] = x_odd
+
+ if window is not None:
+ x_unfold *= window.reshape(1, -1)
+
+ X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :num_bins] ** power
+
+ if fb is not None:
+ X = np.matmul(X, fb.T)
+
+
+ return np.log(X + 1e-9)
+
+
+def cepstrum(x, frame_size, fb=None, window=None):
+ """ calculate cepstrum on 50% overlapping frames """
+
+ X = log_spectrum(x, frame_size, fb, window)
+
+ cepstrum = scipy.fftpack.dct(X, 2, norm='ortho')
+
+ return cepstrum \ No newline at end of file
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
new file mode 100644
index 00000000..1232710f
--- /dev/null
+++ b/dnn/torch/osce/utils/templates.py
@@ -0,0 +1,92 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+setup_dict = dict()
+
+lace_setup = {
+ 'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
+ 'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
+ 'model': {
+ 'name': 'lace',
+ 'args': [],
+ 'kwargs': {
+ 'comb_gain_limit_db': 10,
+ 'cond_dim': 128,
+ 'conv_gain_limits_db': [-12, 12],
+ 'global_gain_limits_db': [-6, 6],
+ 'hidden_feature_dim': 96,
+ 'kernel_size': 15,
+ 'num_features': 93,
+ 'numbits_embedding_dim': 8,
+ 'numbits_range': [50, 650],
+ 'partial_lookahead': True,
+ 'pitch_embedding_dim': 64,
+ 'pitch_max': 300,
+ 'preemph': 0.85,
+ 'skip': 91
+ }
+ },
+ 'data': {
+ 'frames_per_sample': 100,
+ 'no_pitch_value': 7,
+ 'preemph': 0.85,
+ 'skip': 91,
+ 'pitch_hangover': 8,
+ 'acorr_radius': 2,
+ 'num_bands_clean_spec': 64,
+ 'num_bands_noisy_spec': 18,
+ 'noisy_spec_scale': 'opus',
+ 'pitch_hangover': 8,
+ },
+ 'training': {
+ 'batch_size': 256,
+ 'lr': 5.e-4,
+ 'lr_decay_factor': 2.5e-5,
+ 'epochs': 50,
+ 'frames_per_sample': 50,
+ 'loss': {
+ 'w_l1': 0,
+ 'w_lm': 0,
+ 'w_logmel': 0,
+ 'w_sc': 0,
+ 'w_wsc': 0,
+ 'w_xcorr': 0,
+ 'w_sxcorr': 1,
+ 'w_l2': 10,
+ 'w_slm': 2
+ }
+ }
+}
+
+
+
+setup_dict = {
+ 'lace': lace_setup,
+}