diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-02 18:52:50 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-11-02 18:52:50 +0300 |
commit | da60266f6e11cb8d2d28aafa8ea05e5dadf3e8b6 (patch) | |
tree | 7679a3a52846480c6b2fd049e5eb39e9e89c0a8e | |
parent | feb32828877ea5e8723ea2a446eb20d7b3fba426 (diff) |
updated moc method
-rw-r--r-- | dnn/torch/osce/utils/moc.py (renamed from dnn/torch/osce/utils/compare.py) | 71 |
1 files changed, 67 insertions, 4 deletions
diff --git a/dnn/torch/osce/utils/compare.py b/dnn/torch/osce/utils/moc.py index f6422f63..a29f9338 100644 --- a/dnn/torch/osce/utils/compare.py +++ b/dnn/torch/osce/utils/moc.py @@ -1,6 +1,38 @@ import numpy as np import scipy.signal +def compute_vad_mask(x, fs, stop_db=-70): + + frame_length = (fs + 49) // 50 + x = x[: frame_length * (len(x) // frame_length)] + + frames = x.reshape(-1, frame_length) + frame_energy = np.sum(frames ** 2, axis=1) + frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same') + + max_threshold = frame_energy.max() * 10 ** (stop_db/20) + vactive = np.ones_like(frames) + vactive[frame_energy_smooth < max_threshold, :] = 0 + vactive = vactive.reshape(-1) + + filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1)) + filter = filter / filter.sum() + + mask = np.convolve(vactive, filter, mode='same') + + return x, mask + +def convert_mask(mask, num_frames, frame_size=160, hop_size=40): + num_samples = frame_size + (num_frames - 1) * hop_size + if len(mask) < num_samples: + mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype) + else: + mask = mask[:num_samples] + + new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)]) + + return new_mask + def power_spectrum(x, window_size=160, hop_size=40, window='hamming'): num_spectra = (len(x) - window_size - hop_size) // hop_size window = scipy.signal.get_window(window, window_size) @@ -36,7 +68,7 @@ def rect_fb(band_limits, num_bins=None): return fb -def compare(x, y): +def compare(x, y, apply_vad=False): """ Modified version of opus_compare for 16 kHz mono signals Args: @@ -84,7 +116,38 @@ def compare(x, y): re = masked_psd_y / masked_psd_x im = re - np.log(re) - 1 Eb = ((im @ fb.T) / np.sum(fb, axis=1)) - Ef = np.mean(Eb ** 2, axis=1) - err = np.mean(Ef ** 4, axis=0) ** (1/16) + Ef = np.mean(Eb , axis=1) + + if apply_vad: + _, mask = compute_vad_mask(x, 16000) + mask = convert_mask(mask, Ef.shape[0]) + else: + mask = np.ones_like(Ef) + + err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6) + + return float(err) + +if __name__ == "__main__": + import argparse + from scipy.io import wavfile + + parser = argparse.ArgumentParser() + parser.add_argument('ref', type=str, help='reference wav file') + parser.add_argument('deg', type=str, help='degraded wav file') + parser.add_argument('--apply-vad', action='store_true') + args = parser.parse_args() + + + fs1, x = wavfile.read(args.ref) + fs2, y = wavfile.read(args.deg) + + if max(fs1, fs2) != 16000: + raise ValueError('error: encountered sampling frequency diffrent from 16kHz') + + x = x.astype(np.float32) / 2**15 + y = y.astype(np.float32) / 2**15 + + err = compare(x, y, apply_vad=args.apply_vad) - return float(err)
\ No newline at end of file + print(f"MOC: {err}") |