Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-11-02 18:52:50 +0300
committerJan Buethe <jbuethe@amazon.de>2023-11-02 18:52:50 +0300
commitda60266f6e11cb8d2d28aafa8ea05e5dadf3e8b6 (patch)
tree7679a3a52846480c6b2fd049e5eb39e9e89c0a8e
parentfeb32828877ea5e8723ea2a446eb20d7b3fba426 (diff)
updated moc method
-rw-r--r--dnn/torch/osce/utils/moc.py (renamed from dnn/torch/osce/utils/compare.py)71
1 files changed, 67 insertions, 4 deletions
diff --git a/dnn/torch/osce/utils/compare.py b/dnn/torch/osce/utils/moc.py
index f6422f63..a29f9338 100644
--- a/dnn/torch/osce/utils/compare.py
+++ b/dnn/torch/osce/utils/moc.py
@@ -1,6 +1,38 @@
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
    """Compute a smooth, sample-level activity mask from frame energies.

    The signal is cut into non-overlapping frames of ``(fs + 49) // 50``
    samples (20 ms, rounded up). Frame energies are smoothed with a 5-frame
    moving average, and frames whose smoothed energy falls below a threshold
    relative to the maximum frame energy are marked inactive. The resulting
    0/1 sequence is smoothed with a normalized half-sine window of one frame
    length so that transitions are gradual.

    Args:
        x: one-dimensional numpy array of samples.
        fs: integer sampling rate in Hz.
        stop_db: threshold in dB relative to the maximum frame energy.

    Returns:
        Tuple ``(x, mask)``: the input truncated to a whole number of frames
        and a mask of the same length. Both are empty when ``x`` is shorter
        than one frame.
    """
    frame_length = (fs + 49) // 50
    num_frames = len(x) // frame_length
    x = x[: frame_length * num_frames]

    if num_frames == 0:
        # Not even one complete frame: nothing to classify, and the
        # max() below would fail on an empty array.
        return x, np.zeros(0)

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)

    # 5-frame moving average. np.convolve(..., mode='same') returns
    # max(M, N) values, i.e. 5 values when there are fewer than 5 frames,
    # which breaks the boolean indexing below. Taking the centre of the
    # full convolution gives identical values for >= 5 frames and always
    # exactly one value per frame.
    smoothing = np.ones(5) / 5
    offset = (len(smoothing) - 1) // 2
    frame_energy_smooth = np.convolve(
        frame_energy, smoothing, mode='full'
    )[offset : offset + num_frames]

    # NOTE(review): stop_db is applied with /20 to an energy quantity;
    # kept as in the original -- confirm whether /10 was intended.
    max_threshold = frame_energy.max() * 10 ** (stop_db / 20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # Normalized half-sine window (sums to 1) used to soften mask edges.
    window = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    window = window / window.sum()

    mask = np.convolve(vactive, window, mode='same')

    return x, mask
+
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Convert a sample-level mask into one value per analysis frame.

    The mask is zero-padded or truncated to exactly
    ``frame_size + (num_frames - 1) * hop_size`` samples, then averaged over
    each window of ``frame_size`` samples taken every ``hop_size`` samples.

    Args:
        mask: one-dimensional array-like of per-sample mask values
            (float, integer or boolean).
        num_frames: number of frames to produce.
        frame_size: window length in samples.
        hop_size: step between consecutive windows in samples.

    Returns:
        numpy array of length ``num_frames`` holding the mean mask value of
        each window.
    """
    mask = np.asarray(mask)
    num_samples = frame_size + (num_frames - 1) * hop_size
    if len(mask) < num_samples:
        # Create the padding in the mask's own dtype: concatenating float64
        # zeros while requesting dtype=mask.dtype fails for bool/int masks
        # under numpy's default same_kind casting.
        padding = np.zeros(num_samples - len(mask), dtype=mask.dtype)
        mask = np.concatenate((mask, padding))
    else:
        mask = mask[:num_samples]

    new_mask = np.array(
        [np.mean(mask[i * hop_size : i * hop_size + frame_size]) for i in range(num_frames)]
    )

    return new_mask
+
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
num_spectra = (len(x) - window_size - hop_size) // hop_size
window = scipy.signal.get_window(window, window_size)
@@ -36,7 +68,7 @@ def rect_fb(band_limits, num_bins=None):
return fb
-def compare(x, y):
+def compare(x, y, apply_vad=False):
""" Modified version of opus_compare for 16 kHz mono signals
Args:
@@ -84,7 +116,38 @@ def compare(x, y):
re = masked_psd_y / masked_psd_x
im = re - np.log(re) - 1
Eb = ((im @ fb.T) / np.sum(fb, axis=1))
- Ef = np.mean(Eb ** 2, axis=1)
- err = np.mean(Ef ** 4, axis=0) ** (1/16)
+ Ef = np.mean(Eb , axis=1)
+
+ if apply_vad:
+ _, mask = compute_vad_mask(x, 16000)
+ mask = convert_mask(mask, Ef.shape[0])
+ else:
+ mask = np.ones_like(Ef)
+
+ err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
+
+ return float(err)
+
if __name__ == "__main__":
    # Command-line entry point: read a reference and a degraded wav file,
    # run compare() on them and print the resulting score.
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # Both files must be sampled at 16 kHz. Testing only max(fs1, fs2)
    # would accept a pair where one file has a lower rate (e.g. 8 kHz).
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError(
            'error: encountered sampling frequency different from 16kHz '
            f'(ref: {fs1} Hz, deg: {fs2} Hz)'
        )

    # Scale by 2**15 to map 16-bit PCM sample values into [-1, 1).
    # NOTE(review): assumes int16 wav data -- confirm for other sample formats.
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")