diff options
Diffstat (limited to 'dnn/torch/osce/stndrd/evaluation')
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/create_input_data.sh | 25 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/env.rc | 7 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/evaluate.py | 113 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py | 330 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/make_boxplots.py | 116 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py | 109 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/make_tables.py | 124 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py | 121 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/moc.py | 182 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/moc2.py | 190 | ||||
-rwxr-xr-x | dnn/torch/osce/stndrd/evaluation/process_dataset.sh | 98 | ||||
-rw-r--r-- | dnn/torch/osce/stndrd/evaluation/run_nomad.py | 138 |
12 files changed, 1553 insertions, 0 deletions
diff --git a/dnn/torch/osce/stndrd/evaluation/create_input_data.sh b/dnn/torch/osce/stndrd/evaluation/create_input_data.sh new file mode 100644 index 00000000..54bacb88 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/create_input_data.sh @@ -0,0 +1,25 @@ +#!/bin/bash + + +INPUT="dataset/LibriSpeech" +OUTPUT="testdata" +OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched" +BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 ) + + +mkdir -p $OUTPUT + +for fn in $(find $INPUT -name "*.wav") +do + name=$(basename ${fn%*.wav}) + sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw + for br in ${BITRATES[@]} + do + folder=${OUTPUT}/"${name}_${br}.se" + echo "creating ${folder}..." + mkdir -p $folder + cp ${OUTPUT}/tmp.raw ${folder}/clean.s16 + (cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16) + done + rm -f ${OUTPUT}/tmp.raw +done diff --git a/dnn/torch/osce/stndrd/evaluation/env.rc b/dnn/torch/osce/stndrd/evaluation/env.rc new file mode 100644 index 00000000..f1266b6f --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/env.rc @@ -0,0 +1,7 @@ +#!/bin/bash + +export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python +export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth" +export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth" +export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py" +export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
\ No newline at end of file diff --git a/dnn/torch/osce/stndrd/evaluation/evaluate.py b/dnn/torch/osce/stndrd/evaluation/evaluate.py new file mode 100644 index 00000000..54700dbe --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/evaluate.py @@ -0,0 +1,113 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + + +from scipy.io import wavfile +from pesq import pesq +import numpy as np +from moc import compare +from moc2 import compare as compare2 +#from warpq import compute_WAPRQ as warpq +from lace_loss_metric import compare as laceloss_compare + + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='folder with processed items') +parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation') + + +def get_bitrates(folder): + with open(os.path.join(folder, 'bitrates.txt')) as f: + x = f.read() + + bitrates = [int(y) for y in x.rstrip('\n').split()] + + return bitrates + +def get_itemlist(folder): + with open(os.path.join(folder, 'items.txt')) as f: + lines = f.readlines() + + items = [x.split()[0] for x in lines] + + return items + + +def process_item(folder, item, bitrate, metric): + fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav")) + fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav")) + fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav")) + fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav")) + + x_clean = x_clean.astype(np.float32) / 2**15 + x_opus = x_opus.astype(np.float32) / 2**15 + x_lace = x_lace.astype(np.float32) / 2**15 + x_nolace = x_nolace.astype(np.float32) / 2**15 + + if metric == 'pesq': + result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)] + elif metric =='moc': + result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)] + elif metric =='moc2': + result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)] + # elif metric == 'warpq': + # result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)] + elif metric == 'laceloss': + result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)] + else: + raise ValueError(f'unknown metric {metric}') + + return result + +def process_bitrate(folder, items, bitrate, metric): + results = np.zeros((len(items), 3)) + + for i, item in enumerate(items): + results[i, :] = np.array(process_item(folder, item, bitrate, metric)) + + return results + + +if __name__ == "__main__": + args = parser.parse_args() + + items = get_itemlist(args.folder) + bitrates = get_bitrates(args.folder) + + results = dict() + for br in bitrates: + print(f"processing bitrate {br}...") + results[br] = process_bitrate(args.folder, items, br, args.metric) + + np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results) + + print("Done.") diff --git a/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py b/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py new file mode 100644 index 00000000..b0790585 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py @@ -0,0 +1,330 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np +import torchaudio + + +def get_window(win_name, win_length, *args, **kwargs): + window_dict = { + 'bartlett_window' : torch.bartlett_window, + 'blackman_window' : torch.blackman_window, + 'hamming_window' : torch.hamming_window, + 'hann_window' : torch.hann_window, + 'kaiser_window' : torch.kaiser_window + } + + if not win_name in window_dict: + raise ValueError() + + return window_dict[win_name](win_length, *args, **kwargs) + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + + win = get_window(window, win_length).to(x.device) + x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True) + + + return torch.clamp(torch.abs(x_stft), min=1e-7) + +def spectral_convergence_loss(Y_true, Y_pred): + dims=list(range(1, len(Y_pred.shape))) + return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6)) + + +def log_magnitude_loss(Y_true, Y_pred): + Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15) + Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15) + + return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs)) + +def spectral_xcorr_loss(Y_true, Y_pred): + Y_true = Y_true.abs() + Y_pred = Y_pred.abs() + dims=list(range(1, len(Y_pred.shape))) + xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9) + + return 1 - xcorr.mean() + + + +class MRLogMelLoss(nn.Module): + def __init__(self, + fft_sizes=[512, 256, 128, 64], + overlap=0.5, + fs=16000, + n_mels=18 + ): + + self.fft_sizes = fft_sizes + self.overlap = overlap + self.fs = fs + self.n_mels = n_mels + + super().__init__() + + self.mel_specs = [] + for fft_size in fft_sizes: + hop_size = int(round(fft_size * (1 - self.overlap))) + + n_mels = self.n_mels + if fft_size < 128: + n_mels //= 2 + + self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels)) + + for i, mel_spec in enumerate(self.mel_specs): + self.add_module(f'mel_spec_{i+1}', mel_spec) + + def forward(self, y_true, y_pred): + + loss = torch.zeros(1, device=y_true.device) + + for mel_spec in self.mel_specs: + Y_true = mel_spec(y_true) + Y_pred = mel_spec(y_pred) + loss = loss + log_magnitude_loss(Y_true, Y_pred) + + loss = loss / len(self.mel_specs) + + return loss + +def create_weight_matrix(num_bins, bins_per_band=10): + m = torch.zeros((num_bins, num_bins), dtype=torch.float32) + + r0 = bins_per_band // 2 + r1 = bins_per_band - r0 + + for i in range(num_bins): + i0 = max(i - r0, 0) + j0 = min(i + r1, num_bins) + + m[i, i0: j0] += 1 + + if i < r0: + m[i, :r0 - i] += 1 + + if i > num_bins - r1: + m[i, num_bins - r1 - i:] += 1 + + return m / bins_per_band + +def weighted_spectral_convergence(Y_true, Y_pred, w): + + # calculate sfm based weights + logY = torch.log(torch.abs(Y_true) + 1e-9) + Y = torch.abs(Y_true) + + avg_logY = torch.matmul(logY.transpose(1, 2), w) + avg_Y = torch.matmul(Y.transpose(1, 2), w) + + sfm = torch.exp(avg_logY) / (avg_Y + 1e-9) + + weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2) + + loss = torch.mean( + torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2]) + / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9) + ) + + return loss + +def gen_filterbank(N, Fs=16000): + in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:] + out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None] + #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73. + ERB_N = 24.7 + .108*in_freq + delta = np.abs(in_freq-out_freq)/ERB_N + center = (delta<.5).astype('float32') + R = -12*center*delta**2 + (1-center)*(3-12*delta) + RE = 10.**(R/10.) + norm = np.sum(RE, axis=1) + RE = RE/norm[:, np.newaxis] + return torch.from_numpy(RE) + +def smooth_log_mag(Y_true, Y_pred, filterbank): + Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true)) + Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred)) + + loss = torch.abs( + torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9) + ) + + loss = loss.mean() + + return loss + +class MRSTFTLoss(nn.Module): + def __init__(self, + fft_sizes=[2048, 1024, 512, 256, 128, 64], + overlap=0.5, + window='hann_window', + fs=16000, + log_mag_weight=0, + sc_weight=0, + wsc_weight=0, + smooth_log_mag_weight=2, + sxcorr_weight=1): + super().__init__() + + self.fft_sizes = fft_sizes + self.overlap = overlap + self.window = window + self.log_mag_weight = log_mag_weight + self.sc_weight = sc_weight + self.wsc_weight = wsc_weight + self.smooth_log_mag_weight = smooth_log_mag_weight + self.sxcorr_weight = sxcorr_weight + self.fs = fs + + # weights for SFM weighted spectral convergence loss + self.wsc_weights = torch.nn.ParameterDict() + for fft_size in fft_sizes: + width = min(11, int(1000 * fft_size / self.fs + .5)) + width += width % 2 + self.wsc_weights[str(fft_size)] = torch.nn.Parameter( + create_weight_matrix(fft_size // 2 + 1, width), + requires_grad=False + ) + + # filterbanks for smooth log magnitude loss + self.filterbanks = torch.nn.ParameterDict() + for fft_size in fft_sizes: + self.filterbanks[str(fft_size)] = torch.nn.Parameter( + gen_filterbank(fft_size//2), + requires_grad=False + ) + + + def __call__(self, y_true, y_pred): + + + lm_loss = torch.zeros(1, device=y_true.device) + sc_loss = torch.zeros(1, device=y_true.device) + wsc_loss = torch.zeros(1, device=y_true.device) + slm_loss = torch.zeros(1, device=y_true.device) + sxcorr_loss = torch.zeros(1, device=y_true.device) + + for fft_size in self.fft_sizes: + hop_size = int(round(fft_size * (1 - self.overlap))) + win_size = fft_size + + Y_true = stft(y_true, fft_size, hop_size, win_size, self.window) + Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window) + + if self.log_mag_weight > 0: + lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred) + + if self.sc_weight > 0: + sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred) + + if self.wsc_weight > 0: + wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)]) + + if self.smooth_log_mag_weight > 0: + slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)]) + + if self.sxcorr_weight > 0: + sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred) + + + total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss + + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss + + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes) + + return total_loss + + +def td_l2_norm(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6) + + return loss.mean() + + +class LaceLoss(nn.Module): + def __init__(self): + super().__init__() + + + self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1) + + + def forward(self, x, y): + specloss = self.stftloss(x, y) + phaseloss = td_l2_norm(x, y) + total_loss = (specloss + 10 * phaseloss) / 13 + + return total_loss + + def compare(self, x_ref, x_deg): + # trim items to same size + n = min(len(x_ref), len(x_deg)) + x_ref = x_ref[:n].copy() + x_deg = x_deg[:n].copy() + + # pre-emphasis + x_ref[1:] -= 0.85 * x_ref[:-1] + x_deg[1:] -= 0.85 * x_deg[:-1] + + device = next(iter(self.parameters())).device + + x = torch.from_numpy(x_ref).to(device) + y = torch.from_numpy(x_deg).to(device) + + with torch.no_grad(): + dist = 10 * self.forward(x, y) + + return dist.cpu().numpy().item() + + +lace_loss = LaceLoss() +device = 'cuda' if torch.cuda.is_available() else 'cpu' +lace_loss.to(device) + +def compare(x, y): + + return lace_loss.compare(x, y) diff --git a/dnn/torch/osce/stndrd/evaluation/make_boxplots.py b/dnn/torch/osce/stndrd/evaluation/make_boxplots.py new file mode 100644 index 00000000..f7ea778a --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/make_boxplots.py @@ -0,0 +1,116 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +import numpy as np +import matplotlib.pyplot as plt +from prettytable import PrettyTable +from matplotlib.patches import Patch + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics') +parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all') +parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder') + +def load_data(folder): + data = dict() + + if os.path.isfile(os.path.join(folder, 'results_moc.npy')): + data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_moc2.npy')): + data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_pesq.npy')): + data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_warpq.npy')): + data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_nomad.npy')): + data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')): + data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item() + + return data + +def plot_data(filename, data, title=None): + compare_dict = dict() + for br in data.keys(): + compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0] + compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1] + compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2] + + plt.rcParams.update({ + "text.usetex": True, + "font.family": "Helvetica", + "font.size": 32 + }) + + black = '#000000' + red = '#ff5745' + blue = '#007dbc' + colors = [black, red, blue] + legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'), + Patch(facecolor=colors[1], label='LACE'), + Patch(facecolor=colors[2], label='NoLACE')] + + fig, ax = plt.subplots() + fig.set_size_inches(40, 20) + bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True) + + for i, patch in enumerate(bplot['boxes']): + patch.set_facecolor(colors[i%3]) + + ax.set_xticklabels(compare_dict.keys(), rotation=290) + + if title is not None: + ax.set_title(title) + + ax.legend(handles=legend_elements) + + fig.savefig(filename, bbox_inches='tight') + +if __name__ == "__main__": + args = parser.parse_args() + data = load_data(args.folder) + + + metrics = list(data.keys()) if args.metric == 'all' else [args.metric] + folder = args.folder if args.output is None else args.output + os.makedirs(folder, exist_ok=True) + + for metric in metrics: + print(f"Plotting data for {metric} metric...") + plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper()) + + print("Done.")
\ No newline at end of file diff --git a/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py b/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py new file mode 100644 index 00000000..ca65aba9 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py @@ -0,0 +1,109 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +import numpy as np +import matplotlib.pyplot as plt +from prettytable import PrettyTable +from matplotlib.patches import Patch + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics') +parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all') +parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder') + +def load_data(folder): + data = dict() + + if os.path.isfile(os.path.join(folder, 'results_moc.npy')): + data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_pesq.npy')): + data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_warpq.npy')): + data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_nomad.npy')): + data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')): + data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item() + + return data + +def plot_data(filename, data, title=None): + compare_dict = dict() + for br in data.keys(): + compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0] + compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1] + compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2] + + plt.rcParams.update({ + "text.usetex": True, + "font.family": "Helvetica", + "font.size": 32 + }) + colors = ['pink', 'lightblue', 'lightgreen'] + legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'), + Patch(facecolor=colors[1], label='MOC loss only'), + Patch(facecolor=colors[2], label='MOC + TD loss')] + + fig, ax = plt.subplots() + fig.set_size_inches(40, 20) + bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True) + + for i, patch in enumerate(bplot['boxes']): + patch.set_facecolor(colors[i%3]) + + ax.set_xticklabels(compare_dict.keys(), rotation=290) + + if title is not None: + ax.set_title(title) + + ax.legend(handles=legend_elements) + + fig.savefig(filename, bbox_inches='tight') + +if __name__ == "__main__": + args = parser.parse_args() + data = load_data(args.folder) + + + metrics = list(data.keys()) if args.metric == 'all' else [args.metric] + folder = args.folder if args.output is None else args.output + os.makedirs(folder, exist_ok=True) + + for metric in metrics: + print(f"Plotting data for {metric} metric...") + plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper()) + + print("Done.")
\ No newline at end of file diff --git a/dnn/torch/osce/stndrd/evaluation/make_tables.py b/dnn/torch/osce/stndrd/evaluation/make_tables.py new file mode 100644 index 00000000..56080127 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/make_tables.py @@ -0,0 +1,124 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +import numpy as np +import matplotlib.pyplot as plt +from prettytable import PrettyTable +from matplotlib.patches import Patch + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics') +parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all') +parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder') + +def load_data(folder): + data = dict() + + if os.path.isfile(os.path.join(folder, 'results_moc.npy')): + data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_moc2.npy')): + data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_pesq.npy')): + data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_warpq.npy')): + data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_nomad.npy')): + data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')): + data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item() + + return data + +def make_table(filename, data, title=None): + + # mean values + tbl = PrettyTable() + tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE'] + for br in data.keys(): + opus = data[br][:, 0] + lace = data[br][:, 1] + nolace = data[br][:, 2] + tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"]) + + with open(filename + ".txt", "w") as f: + f.write(str(tbl)) + + with open(filename + ".html", "w") as f: + f.write(tbl.get_html_string()) + + with open(filename + ".csv", "w") as f: + f.write(tbl.get_csv_string()) + + print(tbl) + + +def make_diff_table(filename, data, title=None): + + # mean values + tbl = PrettyTable() + tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus'] + for br in data.keys(): + opus = data[br][:, 0] + lace = data[br][:, 1] - opus + nolace = data[br][:, 2] - opus + tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"]) + + with open(filename + ".txt", "w") as f: + f.write(str(tbl)) + + with open(filename + ".html", "w") as f: + f.write(tbl.get_html_string()) + + with open(filename + ".csv", "w") as f: + f.write(tbl.get_csv_string()) + + print(tbl) + +if __name__ == "__main__": + args = parser.parse_args() + data = load_data(args.folder) + + metrics = list(data.keys()) if args.metric == 'all' else [args.metric] + folder = args.folder if args.output is None else args.output + os.makedirs(folder, exist_ok=True) + + for metric in metrics: + print(f"Plotting data for {metric} metric...") + make_table(os.path.join(folder, f"table_{metric}"), data[metric]) + make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric]) + + print("Done.")
\ No newline at end of file diff --git a/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py b/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py new file mode 100644 index 00000000..37718068 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py @@ -0,0 +1,121 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +import numpy as np +import matplotlib.pyplot as plt +from prettytable import PrettyTable +from matplotlib.patches import Patch + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics') +parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all') +parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder') + +def load_data(folder): + data = dict() + + if os.path.isfile(os.path.join(folder, 'results_moc.npy')): + data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_pesq.npy')): + data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_warpq.npy')): + data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_nomad.npy')): + data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item() + + if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')): + data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item() + + return data + +def make_table(filename, data, title=None): + + # mean values + tbl = PrettyTable() + tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE'] + for br in data.keys(): + opus = data[br][:, 0] + lace = data[br][:, 1] + nolace = data[br][:, 2] + tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"]) + + with open(filename + ".txt", "w") as f: + f.write(str(tbl)) + + with open(filename + ".html", "w") as f: + f.write(tbl.get_html_string()) + + with open(filename + ".csv", "w") as f: + f.write(tbl.get_csv_string()) + + print(tbl) + + +def make_diff_table(filename, data, title=None): + + # mean values + tbl = PrettyTable() + tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus'] + for br in data.keys(): + opus = data[br][:, 0] + lace = data[br][:, 1] - opus + nolace = data[br][:, 2] - opus + tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"]) + + with open(filename + ".txt", "w") as f: + f.write(str(tbl)) + + with open(filename + ".html", "w") as f: + f.write(tbl.get_html_string()) + + with open(filename + ".csv", "w") as f: + f.write(tbl.get_csv_string()) + + print(tbl) + +if __name__ == "__main__": + args = parser.parse_args() + data = load_data(args.folder) + + metrics = list(data.keys()) if args.metric == 'all' else [args.metric] + folder = args.folder if args.output is None else args.output + os.makedirs(folder, exist_ok=True) + + for metric in metrics: + print(f"Plotting data for {metric} metric...") + make_table(os.path.join(folder, f"table_{metric}"), data[metric]) + make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric]) + + print("Done.")
\ No newline at end of file diff --git a/dnn/torch/osce/stndrd/evaluation/moc.py b/dnn/torch/osce/stndrd/evaluation/moc.py new file mode 100644 index 00000000..bf004de9 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/moc.py @@ -0,0 +1,182 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import numpy as np +import scipy.signal + +def compute_vad_mask(x, fs, stop_db=-70): + + frame_length = (fs + 49) // 50 + x = x[: frame_length * (len(x) // frame_length)] + + frames = x.reshape(-1, frame_length) + frame_energy = np.sum(frames ** 2, axis=1) + frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same') + + max_threshold = frame_energy.max() * 10 ** (stop_db/20) + vactive = np.ones_like(frames) + vactive[frame_energy_smooth < max_threshold, :] = 0 + vactive = vactive.reshape(-1) + + filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1)) + filter = filter / filter.sum() + + mask = np.convolve(vactive, filter, mode='same') + + return x, mask + +def convert_mask(mask, num_frames, frame_size=160, hop_size=40): + num_samples = frame_size + (num_frames - 1) * hop_size + if len(mask) < num_samples: + mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype) + else: + mask = mask[:num_samples] + + new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)]) + + return new_mask + +def power_spectrum(x, window_size=160, hop_size=40, window='hamming'): + num_spectra = (len(x) - window_size - hop_size) // hop_size + window = scipy.signal.get_window(window, window_size) + N = window_size // 2 + + frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window + psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2 + + return psd + + +def frequency_mask(num_bands, up_factor, down_factor): + + up_mask = np.zeros((num_bands, num_bands)) + down_mask = np.zeros((num_bands, num_bands)) + + for i in range(num_bands): + up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1) + down_mask[i, i :] = down_factor ** np.arange(num_bands - i) + + return down_mask @ up_mask + + +def rect_fb(band_limits, num_bins=None): + num_bands = len(band_limits) - 1 + if num_bins is None: + num_bins = band_limits[-1] + + fb = np.zeros((num_bands, num_bins)) + for i in range(num_bands): + fb[i, band_limits[i]:band_limits[i+1]] = 1 + + return fb + + +def compare(x, y, apply_vad=False): + """ Modified version of opus_compare for 16 kHz mono signals + + Args: + x (np.ndarray): reference input signal scaled to [-1, 1] + y (np.ndarray): test signal scaled to [-1, 1] + + Returns: + float: perceptually weighted error + """ + # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz + band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75] + num_bands = len(band_limits) - 1 + fb = rect_fb(band_limits, num_bins=81) + + # trim samples to same size + num_samples = min(len(x), len(y)) + x = x[:num_samples] * 2**15 + y = y[:num_samples] * 2**15 + + psd_x = power_spectrum(x) + 100000 + psd_y = power_spectrum(y) + 100000 + + num_frames = psd_x.shape[0] + + # average band energies + be_x = (psd_x @ fb.T) / np.sum(fb, axis=1) + + # frequecy masking + f_mask = frequency_mask(num_bands, 0.1, 0.03) + mask_x = be_x @ f_mask.T + + # temporal masking + for i in range(1, num_frames): + mask_x[i, :] += 0.5 * mask_x[i-1, :] + + # apply mask + masked_psd_x = psd_x + 0.1 * (mask_x @ fb) + masked_psd_y = psd_y + 0.1 * (mask_x @ fb) + + # 2-frame average + masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1] + masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1] + + # distortion metric + re = masked_psd_y / masked_psd_x + im = np.log(re) ** 2 + Eb = ((im @ fb.T) / np.sum(fb, axis=1)) + Ef = np.mean(Eb , axis=1) + + if apply_vad: + _, mask = compute_vad_mask(x, 16000) + mask = convert_mask(mask, Ef.shape[0]) + else: + mask = np.ones_like(Ef) + + err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6) + + return float(err) + +if __name__ == "__main__": + import argparse + from scipy.io import wavfile + + parser = argparse.ArgumentParser() + parser.add_argument('ref', type=str, help='reference wav file') + parser.add_argument('deg', type=str, help='degraded wav file') + parser.add_argument('--apply-vad', action='store_true') + args = parser.parse_args() + + + fs1, x = wavfile.read(args.ref) + fs2, y = wavfile.read(args.deg) + + if max(fs1, fs2) != 16000: + raise ValueError('error: encountered sampling frequency diffrent from 16kHz') + + x = x.astype(np.float32) / 2**15 + y = y.astype(np.float32) / 2**15 + + err = compare(x, y, apply_vad=args.apply_vad) + + print(f"MOC: {err}") diff --git a/dnn/torch/osce/stndrd/evaluation/moc2.py b/dnn/torch/osce/stndrd/evaluation/moc2.py new file mode 100644 index 00000000..7e155f01 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/moc2.py @@ -0,0 +1,190 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import numpy as np +import scipy.signal + +def compute_vad_mask(x, fs, stop_db=-70): + + frame_length = (fs + 49) // 50 + x = x[: frame_length * (len(x) // frame_length)] + + frames = x.reshape(-1, frame_length) + frame_energy = np.sum(frames ** 2, axis=1) + frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same') + + max_threshold = frame_energy.max() * 10 ** (stop_db/20) + vactive = np.ones_like(frames) + vactive[frame_energy_smooth < max_threshold, :] = 0 + vactive = vactive.reshape(-1) + + filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1)) + filter = filter / filter.sum() + + mask = np.convolve(vactive, filter, mode='same') + + return x, mask + +def convert_mask(mask, num_frames, frame_size=160, hop_size=40): + num_samples = frame_size + (num_frames - 1) * hop_size + if len(mask) < num_samples: + mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype) + else: + mask = mask[:num_samples] + + new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)]) + + return new_mask + +def power_spectrum(x, window_size=160, hop_size=40, window='hamming'): + num_spectra = (len(x) - window_size - hop_size) // hop_size + window = scipy.signal.get_window(window, window_size) + N = window_size // 2 + + frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window + psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2 + + return psd + + +def frequency_mask(num_bands, up_factor, down_factor): + + up_mask = np.zeros((num_bands, num_bands)) + down_mask = np.zeros((num_bands, num_bands)) + + for i in range(num_bands): + up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1) + down_mask[i, i :] = down_factor ** np.arange(num_bands - i) + + return down_mask @ up_mask + + +def rect_fb(band_limits, num_bins=None): + num_bands = len(band_limits) - 1 + if num_bins is None: + num_bins = band_limits[-1] + + fb = np.zeros((num_bands, num_bins)) + for i in range(num_bands): + fb[i, band_limits[i]:band_limits[i+1]] = 1 + + return fb + + +def _compare(x, y, apply_vad=False, factor=1): + """ Modified version of opus_compare for 16 kHz mono signals + + Args: + x (np.ndarray): reference input signal scaled to [-1, 1] + y (np.ndarray): test signal scaled to [-1, 1] + + Returns: + float: perceptually weighted error + """ + # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz + band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]] + window_size = factor * 160 + hop_size = factor * 40 + num_bins = window_size // 2 + 1 + num_bands = len(band_limits) - 1 + fb = rect_fb(band_limits, num_bins=num_bins) + + # trim samples to same size + num_samples = min(len(x), len(y)) + x = x[:num_samples].copy() * 2**15 + y = y[:num_samples].copy() * 2**15 + + psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000 + psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000 + + num_frames = psd_x.shape[0] + + # average band energies + be_x = (psd_x @ fb.T) / np.sum(fb, axis=1) + + # frequecy masking + f_mask = frequency_mask(num_bands, 0.1, 0.03) + mask_x = be_x @ f_mask.T + + # temporal masking + for i in range(1, num_frames): + mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :] + + # apply mask + masked_psd_x = psd_x + 0.1 * (mask_x @ fb) + masked_psd_y = psd_y + 0.1 * (mask_x @ fb) + + # 2-frame average + masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1] + masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1] + + # distortion metric + re = masked_psd_y / masked_psd_x + #im = re - np.log(re) - 1 + im = np.log(re) ** 2 + Eb = ((im @ fb.T) / np.sum(fb, axis=1)) + Ef = np.mean(Eb ** 1, axis=1) + + if apply_vad: + _, mask = compute_vad_mask(x, 16000) + mask = convert_mask(mask, Ef.shape[0]) + else: + mask = np.ones_like(Ef) + + err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6) + + return float(err) + +def compare(x, y, apply_vad=False): + err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2) + return err + +if __name__ == "__main__": + import argparse + from scipy.io import wavfile + + parser = argparse.ArgumentParser() + parser.add_argument('ref', type=str, help='reference wav file') + parser.add_argument('deg', type=str, help='degraded wav file') + parser.add_argument('--apply-vad', action='store_true') + args = parser.parse_args() + + + fs1, x = wavfile.read(args.ref) + fs2, y = wavfile.read(args.deg) + + if max(fs1, fs2) != 16000: + raise ValueError('error: encountered sampling frequency diffrent from 16kHz') + + x = x.astype(np.float32) / 2**15 + y = y.astype(np.float32) / 2**15 + + err = compare(x, y, apply_vad=args.apply_vad) + + print(f"MOC: {err}") diff --git a/dnn/torch/osce/stndrd/evaluation/process_dataset.sh b/dnn/torch/osce/stndrd/evaluation/process_dataset.sh new file mode 100755 index 00000000..a490da93 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/process_dataset.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +if [ ! -f "$PYTHON" ] +then + echo "PYTHON variable does not link to a file. Please point it to your python executable." + exit 1 +fi + +if [ ! -f "$TESTMODEL" ] +then + echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py" + exit 1 +fi + +if [ ! -f "$OPUSDEMO" ] +then + echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo." + exit 1 +fi + +if [ ! -f "$LACE" ] +then + echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint." + exit 1 +fi + +if [ ! -f "$NOLACE" ] +then + echo "LACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint." + exit 1 +fi + +case $# in + 2) INPUT=$1; OUTPUT=$2;; + *) echo "process_dataset.sh <input folder> <output folder>"; exit 1;; +esac + +if [ -d $OUTPUT ] +then + echo "output folder $OUTPUT exists, aborting..." + exit 1 +fi + +mkdir -p $OUTPUT + +if [ "$BITRATES" == "" ] +then + BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 ) + echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}." +fi + + +echo "LACE=${LACE}" > ${OUTPUT}/info.txt +echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt + +ITEMFILE=${OUTPUT}/items.txt +BITRATEFILE=${OUTPUT}/bitrates.txt + +FPROCESSING=${OUTPUT}/processing +FCLEAN=${OUTPUT}/clean +FOPUS=${OUTPUT}/opus +FLACE=${OUTPUT}/lace +FNOLACE=${OUTPUT}/nolace + +mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE + +echo "${BITRATES[@]}" > $BITRATEFILE + +for fn in $(find $INPUT -type f -name "*.wav") +do + UUID=$(uuid) + echo "$UUID $fn" >> $ITEMFILE + PIDS=( ) + for br in ${BITRATES[@]} + do + # run opus + pfolder=${FPROCESSING}/${UUID}_${br} + mkdir -p $pfolder + sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 + (cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16) + + # copy clean and opus + sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav + sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav + + # run LACE + $PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav & + PIDS+=( "$!" ) + + # run NoLACE + $PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav & + PIDS+=( "$!" ) + done + for pid in ${PIDS[@]} + do + wait $pid + done +done diff --git a/dnn/torch/osce/stndrd/evaluation/run_nomad.py b/dnn/torch/osce/stndrd/evaluation/run_nomad.py new file mode 100644 index 00000000..0267bc92 --- /dev/null +++ b/dnn/torch/osce/stndrd/evaluation/run_nomad.py @@ -0,0 +1,138 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse +import tempfile +import shutil + +import pandas as pd +from scipy.spatial.distance import cdist +from scipy.io import wavfile +import numpy as np + +from nomad_audio.nomad import Nomad + + +parser = argparse.ArgumentParser() +parser.add_argument('folder', type=str, help='folder with processed items') +parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric') +parser.add_argument('--device', type=str, default=None, help='device for Nomad') + + +def get_bitrates(folder): + with open(os.path.join(folder, 'bitrates.txt')) as f: + x = f.read() + + bitrates = [int(y) for y in x.rstrip('\n').split()] + + return bitrates + +def get_itemlist(folder): + with open(os.path.join(folder, 'items.txt')) as f: + lines = f.readlines() + + items = [x.split()[0] for x in lines] + + return items + + +def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None): + model = Nomad(device=device) + if not full_reference: + results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD'] + return results, None + else: + if ref_embeddings is None: + print(f"Computing reference embeddings from {ref_folder}") + ref_data = pd.DataFrame(sorted(os.listdir(ref_folder))) + ref_data.columns = ['filename'] + ref_data['filename'] = [os.path.join(ref_folder, x) for x in ref_data['filename']] + ref_embeddings = model.get_embeddings_csv(model.model, ref_data).set_index('filename') + + print(f"Computing degraded embeddings from {deg_folder}") + deg_data = pd.DataFrame(sorted(os.listdir(deg_folder))) + deg_data.columns = ['filename'] + deg_data['filename'] = [os.path.join(deg_folder, x) for x in deg_data['filename']] + deg_embeddings = model.get_embeddings_csv(model.model, deg_data).set_index('filename') + + dist = np.diag(cdist(ref_embeddings, deg_embeddings)) # wasteful + test_files = [x.split('/')[-1].split('.')[0] for x in deg_embeddings.index] + + results = dict(zip(test_files, dist)) + + return results, ref_embeddings + + + + +def nomad_process_all(folder, full_reference=False, device=None): + bitrates = get_bitrates(folder) + items = get_itemlist(folder) + with tempfile.TemporaryDirectory() as dir: + cleandir = os.path.join(dir, 'clean') + opusdir = os.path.join(dir, 'opus') + lacedir = os.path.join(dir, 'lace') + nolacedir = os.path.join(dir, 'nolace') + + # prepare files + for d in [cleandir, opusdir, lacedir, nolacedir]: os.makedirs(d) + for br in bitrates: + for item in items: + for cond in ['clean', 'opus', 'lace', 'nolace']: + shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(dir, cond, f"{item}_{br}.wav")) + + nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None) + nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings) + nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings) + + results = dict() + for br in bitrates: + results[br] = np.zeros((len(items), 3)) + for i, item in enumerate(items): + key = f"{item}_{br}" + results[br][i, 0] = nomad_opus[key] + results[br][i, 1] = nomad_lace[key] + results[br][i, 2] = nomad_nolace[key] + + return results + + + +if __name__ == "__main__": + args = parser.parse_args() + + items = get_itemlist(args.folder) + bitrates = get_bitrates(args.folder) + + results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device) + + np.save(os.path.join(args.folder, f'results_nomad.npy'), results) + + print("Done.") |