Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/dnn/torch
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-11-08 16:03:39 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-12-20 11:42:44 +0300
commit7d328f5bfaa321d823ff4d11b62d5357c99e0693 (patch)
tree873593e93c87a7b9b1de7f710696502737f1922b /dnn/torch
parent591c8bad70d8aa414729d1a243a6d930f64d6316 (diff)
Merge LACE/NoLACE under OSCE frameworkopus-ng-lace-integration5
Diffstat (limited to 'dnn/torch')
-rw-r--r--dnn/torch/osce/create_testvectors.py165
-rw-r--r--dnn/torch/osce/data/silk_enhancement_set.py6
-rw-r--r--dnn/torch/osce/export_model_weights.py101
-rw-r--r--dnn/torch/osce/models/lace.py2
-rw-r--r--dnn/torch/osce/models/no_lace.py4
-rw-r--r--dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py18
-rw-r--r--dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py15
-rw-r--r--dnn/torch/osce/utils/silk_features.py16
-rw-r--r--dnn/torch/osce/utils/spec.py1
-rw-r--r--dnn/torch/weight-exchange/wexchange/c_export/c_writer.py14
-rw-r--r--dnn/torch/weight-exchange/wexchange/torch/__init__.py1
-rw-r--r--dnn/torch/weight-exchange/wexchange/torch/torch.py157
12 files changed, 433 insertions, 67 deletions
diff --git a/dnn/torch/osce/create_testvectors.py b/dnn/torch/osce/create_testvectors.py
new file mode 100644
index 00000000..a037d0db
--- /dev/null
+++ b/dnn/torch/osce/create_testvectors.py
@@ -0,0 +1,165 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import torch
+import numpy as np
+
+from models import model_dict
+from utils import endoscopy
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and nolace_checkpoint.pth"')
+parser.add_argument('output_folder', type=str, help='output folder for testvectors')
+parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
+
+
+def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
+ feature_dim = adaconv.feature_dim
+ in_channels = adaconv.in_channels
+ out_channels = adaconv.out_channels
+ frame_size = adaconv.frame_size
+
+ features = torch.randn((1, num_frames, feature_dim))
+ x_in = torch.randn((1, in_channels, num_frames * frame_size))
+
+ x_out = adaconv(x_in, features, debug=debug)
+
+ features = features[0].detach().numpy()
+ x_in = x_in[0].reshape(in_channels, num_frames, frame_size).permute(1, 0, 2).detach().numpy()
+ x_out = x_out[0].reshape(out_channels, num_frames, frame_size).permute(1, 0, 2).detach().numpy()
+
+ features.tofile(prefix + '_features.f32')
+ x_in.tofile(prefix + '_x_in.f32')
+ x_out.tofile(prefix + '_x_out.f32')
+
+def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
+ feature_dim = adacomb.feature_dim
+ in_channels = 1
+ frame_size = adacomb.frame_size
+
+ features = torch.randn((1, num_frames, feature_dim))
+ x_in = torch.randn((1, in_channels, num_frames * frame_size))
+ p_in = torch.randint(adacomb.kernel_size, 250, (1, num_frames))
+
+ x_out = adacomb(x_in, features, p_in, debug=debug)
+
+ features = features[0].detach().numpy()
+ x_in = x_in[0].permute(1, 0).detach().numpy()
+ p_in = p_in[0].detach().numpy().astype(np.int32)
+ x_out = x_out[0].permute(1, 0).detach().numpy()
+
+ features.tofile(prefix + '_features.f32')
+ x_in.tofile(prefix + '_x_in.f32')
+ p_in.tofile(prefix + '_p_in.s32')
+ x_out.tofile(prefix + '_x_out.f32')
+
+def create_adashape_testvector(prefix, adashape, num_frames):
+ feature_dim = adashape.feature_dim
+ frame_size = adashape.frame_size
+
+ features = torch.randn((1, num_frames, feature_dim))
+ x_in = torch.randn((1, 1, num_frames * frame_size))
+
+ x_out = adashape(x_in, features)
+
+ features = features[0].detach().numpy()
+ x_in = x_in.flatten().detach().numpy()
+ x_out = x_out.flatten().detach().numpy()
+
+ features.tofile(prefix + '_features.f32')
+ x_in.tofile(prefix + '_x_in.f32')
+ x_out.tofile(prefix + '_x_out.f32')
+
+def create_feature_net_testvector(prefix, model, num_frames):
+ num_features = model.num_features
+ num_subframes = 4 * num_frames
+
+ input_features = torch.randn((1, num_subframes, num_features))
+ periods = torch.randint(32, 300, (1, num_subframes))
+ numbits = model.numbits_range[0] + torch.rand((1, num_frames, 2)) * (model.numbits_range[1] - model.numbits_range[0])
+
+
+ pembed = model.pitch_embedding(periods)
+ nembed = torch.repeat_interleave(model.numbits_embedding(numbits).flatten(2), 4, dim=1)
+ full_features = torch.cat((input_features, pembed, nembed), dim=-1)
+
+ cf = model.feature_net(full_features)
+
+ input_features.float().numpy().tofile(prefix + "_in_features.f32")
+ periods.numpy().astype(np.int32).tofile(prefix + "_periods.s32")
+ numbits.float().numpy().tofile(prefix + "_numbits.f32")
+ full_features.detach().numpy().tofile(prefix + "_full_features.f32")
+ cf.detach().numpy().tofile(prefix + "_out_features.f32")
+
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ os.makedirs(args.output_folder, exist_ok=True)
+
+ lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
+ nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')
+
+ lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
+ nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])
+
+ lace.load_state_dict(lace_checkpoint['state_dict'])
+ nolace.load_state_dict(nolace_checkpoint['state_dict'])
+
+ if args.debug:
+ endoscopy.init(args.output_folder)
+
+ # lace af1, 1 input channel, 1 output channel
+ create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, 5, debug=args.debug)
+
+ # nolace af1, 1 input channel, 2 output channels
+ create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, 5, debug=args.debug)
+
+ # nolace af4, 2 input channel, 1 output channels
+ create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, 5, debug=args.debug)
+
+ # nolace af2, 2 input channel, 2 output channels
+ create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, 5, debug=args.debug)
+
+ # lace cf1
+ create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, 5, debug=args.debug)
+
+ # nolace tdshape1
+ create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, 5)
+
+ # lace feature net
+ create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, 5)
+
+ if args.debug:
+ endoscopy.close()
diff --git a/dnn/torch/osce/data/silk_enhancement_set.py b/dnn/torch/osce/data/silk_enhancement_set.py
index 65e97508..fd18c4de 100644
--- a/dnn/torch/osce/data/silk_enhancement_set.py
+++ b/dnn/torch/osce/data/silk_enhancement_set.py
@@ -49,7 +49,6 @@ class SilkEnhancementSet(Dataset):
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
- add_offset=False,
add_double_lag_acorr=False,
):
@@ -73,7 +72,6 @@ class SilkEnhancementSet(Dataset):
self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
- self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
@@ -86,7 +84,6 @@ class SilkEnhancementSet(Dataset):
num_bands_noisy_spec,
noisy_spec_scale,
noisy_apply_dct,
- add_offset,
add_double_lag_acorr)
self.history_len = 700 if add_double_lag_acorr else 350
@@ -120,8 +117,7 @@ class SilkEnhancementSet(Dataset):
self.lpcs[frame_start : frame_stop],
self.gains[frame_start : frame_stop],
self.ltps[frame_start : frame_stop],
- self.periods[frame_start : frame_stop],
- self.offsets[frame_start : frame_stop]
+ self.periods[frame_start : frame_stop]
)
if self.preemph > 0:
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index 8b95aca9..f94431d3 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -40,10 +40,53 @@ import wexchange.torch
from wexchange.torch import dump_torch_weights
from models import model_dict
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from wexchange.torch import dump_torch_weights
+
+
+
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
+
+
+schedules = {
+ 'nolace': [
+ ('pitch_embedding', dict()),
+ ('feature_net.conv1', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None)),
+ ('feature_net.tconv', dict(quantize=True, scale=None)),
+ ('feature_net.gru', dict()),
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None)),
+ ('tdshape1', dict()),
+ ('tdshape2', dict()),
+ ('tdshape3', dict()),
+ ('af2', dict(quantize=True, scale=None)),
+ ('af3', dict(quantize=True, scale=None)),
+ ('af4', dict(quantize=True, scale=None)),
+ ('post_cf1', dict(quantize=True, scale=None)),
+ ('post_cf2', dict(quantize=True, scale=None)),
+ ('post_af1', dict(quantize=True, scale=None)),
+ ('post_af2', dict(quantize=True, scale=None)),
+ ('post_af3', dict(quantize=True, scale=None))
+ ],
+ 'lace' : [
+ ('pitch_embedding', dict()),
+ ('feature_net.conv1', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None)),
+ ('feature_net.tconv', dict(quantize=True, scale=None)),
+ ('feature_net.gru', dict()),
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None))
+ ]
+}
# auxiliary functions
@@ -60,8 +103,28 @@ def sha1(filename):
return sha1.hexdigest()
+def osce_dump_generic(writer, name, module):
+ if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv1d) \
+ or isinstance(module, torch.nn.ConvTranspose1d) or isinstance(module, torch.nn.Embedding) \
+ or isinstance(module, LimitedAdaptiveConv1d) or isinstance(module, LimitedAdaptiveComb1d) \
+ or isinstance(module, TDShaper) or isinstance(module, torch.nn.GRU):
+ dump_torch_weights(writer, module, name=name, verbose=True)
+ else:
+ for child_name, child in module.named_children():
+ osce_dump_generic(writer, (name + "_" + child_name).replace("feature_net", "fnet"), child)
+
+
def export_name(name):
- return name.replace('.', '_')
+ name = name.replace('.', '_')
+ name = name.replace('feature_net', 'fnet')
+ return name
+
+def osce_scheduled_dump(writer, prefix, model, schedule):
+ if not prefix.endswith('_'):
+ prefix += '_'
+
+ for name, kwargs in schedule:
+ dump_torch_weights(writer, model.get_submodule(name), prefix + export_name(name), **kwargs, verbose=True)
if __name__ == "__main__":
args = parser.parse_args()
@@ -76,22 +139,34 @@ if __name__ == "__main__":
# create model and load weights
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+ model.load_state_dict(checkpoint['state_dict'])
# CWriter
model_name = checkpoint['setup']['model']['name']
- cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper())
-
- # dump numbits_embedding parameters by hand
- numbits_embedding = model.get_submodule('numbits_embedding')
- weights = next(iter(numbits_embedding.parameters()))
- for i, c in enumerate(weights):
- cwriter.header.write(f"\nNUMBITS_COEF_{i} {float(c.detach())}f")
- cwriter.header.write("\n\n")
+ cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)
+
+ # Add custom includes and global parameters
+ cwriter.header.write(f'''
+#define {model_name.upper()}_PREEMPH {model.preemph}f
+#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
+#define {model_name.upper()}_OVERLAP_SIZE 40
+#define {model_name.upper()}_NUM_FEATURES {model.num_features}
+#define {model_name.upper()}_PITCH_MAX {model.pitch_max}
+#define {model_name.upper()}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
+#define {model_name.upper()}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
+#define {model_name.upper()}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
+#define {model_name.upper()}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
+#define {model_name.upper()}_COND_DIM {model.cond_dim}
+#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
+''')
+
+ for i, s in enumerate(model.numbits_embedding.scale_factors):
+ cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")
# dump layers
- for name, module in model.named_modules():
- if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv1d) \
- or isinstance(module, torch.nn.ConvTranspose1d) or isinstance(module, torch.nn.Embedding):
- dump_torch_weights(cwriter, module, name=export_name(name), verbose=True)
+ if model_name in schedules and args.quantize:
+ osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
+ else:
+ osce_dump_generic(cwriter, model_name, model)
cwriter.close()
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index a11dfc41..58293de4 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -96,7 +96,7 @@ class LACE(NNSBase):
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
# spectral shaping
- self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+ self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
def flop_count(self, rate=16000, verbose=False):
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 2709274c..0e0fb1b3 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -96,8 +96,8 @@ class NoLACE(NNSBase):
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
- self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
- self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+ self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+ self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
index b146240e..3bb6fa07 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -41,13 +41,13 @@ class LimitedAdaptiveComb1d(nn.Module):
feature_dim,
frame_size=160,
overlap_size=40,
- use_bias=True,
padding=None,
max_lag=256,
name=None,
gain_limit_db=10,
global_gain_limits_db=[-6, 6],
- norm_p=2):
+ norm_p=2,
+ **kwargs):
"""
Parameters:
@@ -87,7 +87,6 @@ class LimitedAdaptiveComb1d(nn.Module):
self.kernel_size = kernel_size
self.frame_size = frame_size
self.overlap_size = overlap_size
- self.use_bias = use_bias
self.max_lag = max_lag
self.limit_db = gain_limit_db
self.norm_p = norm_p
@@ -101,8 +100,6 @@ class LimitedAdaptiveComb1d(nn.Module):
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, kernel_size)
- if self.use_bias:
- self.conv_bias = nn.Linear(feature_dim,1)
# comb filter gain
self.filter_gain = nn.Linear(feature_dim, 1)
@@ -154,9 +151,6 @@ class LimitedAdaptiveComb1d(nn.Module):
conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
- if self.use_bias:
- conv_biases = self.conv_bias(features).permute(0, 2, 1)
-
conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
# calculate gains
global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
@@ -190,10 +184,6 @@ class LimitedAdaptiveComb1d(nn.Module):
new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
-
- if self.use_bias:
- new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
-
offset = self.max_lag + self.padding[0]
new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
@@ -223,10 +213,6 @@ class LimitedAdaptiveComb1d(nn.Module):
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
- # bias computation
- if self.use_bias:
- count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
-
# a0 computation
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
index 073ea1b1..a17b0e9b 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -46,12 +46,12 @@ class LimitedAdaptiveConv1d(nn.Module):
feature_dim,
frame_size=160,
overlap_size=40,
- use_bias=True,
padding=None,
name=None,
gain_limits_db=[-6, 6],
shape_gain_db=0,
- norm_p=2):
+ norm_p=2,
+ **kwargs):
"""
Parameters:
@@ -90,7 +90,6 @@ class LimitedAdaptiveConv1d(nn.Module):
self.kernel_size = kernel_size
self.frame_size = frame_size
self.overlap_size = overlap_size
- self.use_bias = use_bias
self.gain_limits_db = gain_limits_db
self.shape_gain_db = shape_gain_db
self.norm_p = norm_p
@@ -104,9 +103,6 @@ class LimitedAdaptiveConv1d(nn.Module):
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
- if self.use_bias:
- self.conv_bias = nn.Linear(feature_dim, out_channels)
-
self.shape_gain = min(1, 10**(shape_gain_db / 20))
self.filter_gain = nn.Linear(feature_dim, out_channels)
@@ -133,10 +129,6 @@ class LimitedAdaptiveConv1d(nn.Module):
count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
- # bias computation
- if self.use_bias:
- count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
-
# gain computation
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
@@ -183,9 +175,6 @@ class LimitedAdaptiveConv1d(nn.Module):
conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
- if self.use_bias:
- conv_biases = self.conv_bias(features).permute(0, 2, 1)
-
# calculate gains
conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
if debug and batch_size == 1:
diff --git a/dnn/torch/osce/utils/silk_features.py b/dnn/torch/osce/utils/silk_features.py
index 2997ef5f..8c5dbf05 100644
--- a/dnn/torch/osce/utils/silk_features.py
+++ b/dnn/torch/osce/utils/silk_features.py
@@ -33,6 +33,7 @@ import numpy as np
import torch
import scipy
+import scipy.signal
from utils.pitch import hangover, calculate_acorr_window
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
@@ -59,7 +60,6 @@ def silk_feature_factory(no_pitch_value=256,
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
- add_offset=False,
add_double_lag_acorr=False
):
@@ -67,7 +67,7 @@ def silk_feature_factory(no_pitch_value=256,
fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
- def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
+ def create_features(noisy, noisy_history, lpcs, gains, ltps, periods):
periods = periods.copy()
@@ -89,10 +89,7 @@ def silk_feature_factory(no_pitch_value=256,
acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
- if add_offset:
- features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains, offsets.reshape(-1, 1)), axis=-1, dtype=np.float32)
- else:
- features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
+ features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
return features, periods.astype(np.int64)
@@ -110,7 +107,6 @@ def load_inference_data(path,
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
- add_offset=False,
add_double_lag_acorr=False,
**kwargs):
@@ -122,13 +118,12 @@ def load_inference_data(path,
periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
- offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
# load signal, add back delay and pre-emphasize
signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
- create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
+ create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_double_lag_acorr)
num_frames = min((len(signal) // 320) * 4, len(lpcs))
signal = signal[: num_frames * 80]
@@ -138,11 +133,10 @@ def load_inference_data(path,
periods = periods[: num_frames]
num_bits = num_bits[: num_frames // 4]
num_bits_smooth = num_bits[: num_frames // 4]
- offsets = offsets[: num_frames]
numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
- features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
+ features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods)
if preemph > 0:
signal[1:] -= preemph * signal[:-1]
diff --git a/dnn/torch/osce/utils/spec.py b/dnn/torch/osce/utils/spec.py
index 01b923ae..59f53538 100644
--- a/dnn/torch/osce/utils/spec.py
+++ b/dnn/torch/osce/utils/spec.py
@@ -30,6 +30,7 @@
import math as m
import numpy as np
import scipy
+import scipy.fftpack
import torch
def erb(f):
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
index 36050881..2745f337 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
@@ -38,7 +38,8 @@ class CWriter:
create_state_struct=False,
enable_binary_blob=True,
model_struct_name="Model",
- nnet_header="nnet.h"):
+ nnet_header="nnet.h",
+ add_typedef=False):
"""
Writer class for creating souce and header files for weight exports to C
@@ -73,6 +74,7 @@ class CWriter:
self.enable_binary_blob = enable_binary_blob
self.create_state_struct = create_state_struct
self.model_struct_name = model_struct_name
+ self.add_typedef = add_typedef
# for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>)
self.layer_dict = OrderedDict()
@@ -119,11 +121,17 @@ f"""
# create model type
if self.enable_binary_blob:
- self.header.write(f"\nstruct {self.model_struct_name} {{")
+ if self.add_typedef:
+ self.header.write(f"\ntypedef struct {{")
+ else:
+ self.header.write(f"\nstruct {self.model_struct_name} {{")
for name, data in self.layer_dict.items():
layer_type = data[0]
self.header.write(f"\n {layer_type} {name};")
- self.header.write(f"\n}};\n")
+ if self.add_typedef:
+ self.header.write(f"\n}} {self.model_struct_name};\n")
+ else:
+ self.header.write(f"\n}};\n")
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.header.write(f"\n{init_prototype};\n")
diff --git a/dnn/torch/weight-exchange/wexchange/torch/__init__.py b/dnn/torch/weight-exchange/wexchange/torch/__init__.py
index 98c96fad..8245566d 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/__init__.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/__init__.py
@@ -34,3 +34,4 @@ from .torch import dump_torch_gru_weights, load_torch_gru_weights
from .torch import dump_torch_grucell_weights
from .torch import dump_torch_embedding_weights, load_torch_embedding_weights
from .torch import dump_torch_weights, load_torch_weights
+from .torch import dump_torch_adaptive_conv1d_weights \ No newline at end of file
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 281d9be3..f7e16032 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -28,12 +28,154 @@
"""
import os
+import sys
import torch
import numpy as np
+sys.path.append(sys.path.append(os.path.join(os.path.dirname(__file__), '../osce')))
+try:
+ import utils.layers as osce_layers
+ from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+ from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+ from utils.layers.td_shaper import TDShaper
+ has_osce=True
+except:
+ has_osce=False
+
from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer
+def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
+
+
+ w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
+ b_kernel = adaconv.conv_kernel.bias.detach().cpu().numpy().copy()
+ w_gain = adaconv.filter_gain.weight.detach().cpu().numpy().copy()
+ b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy()
+
+ if isinstance(where, CWriter):
+ # pad kernel for quantization
+ left_padding = adaconv.padding[0]
+ kernel_size = adaconv.kernel_size
+ in_channels = adaconv.in_channels
+ out_channels = adaconv.out_channels
+ feature_dim = adaconv.feature_dim
+
+ if quantize and kernel_size % 8:
+ kernel_padding = 8 - (kernel_size % 8)
+ w_kernel = np.concatenate(
+ (np.zeros((out_channels, in_channels, kernel_padding, feature_dim)), w_kernel.reshape(out_channels, in_channels, kernel_size, feature_dim)),
+ dtype=w_kernel.dtype,
+ axis=2).reshape(-1, feature_dim)
+ b_kernel = np.concatenate(
+ (np.zeros((out_channels, in_channels, kernel_padding)), b_kernel.reshape(out_channels, in_channels, kernel_size)),
+ dtype=b_kernel.dtype,
+ axis=2).reshape(-1)
+ left_padding += kernel_padding
+ kernel_size += kernel_padding
+
+ # write relevant scalar parameters to header file
+ where.header.write(f"""
+#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
+#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
+#define {name.upper()}_SHAPE_GAIN {adaconv.shape_gain:f}f
+#define {name.upper()}_KERNEL_SIZE {kernel_size}
+#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
+#define {name.upper()}_LEFT_PADDING {left_padding}
+#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
+#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
+#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
+#define {name.upper()}_NORM_P {adaconv.norm_p}
+#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
+"""
+ )
+
+ print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
+ print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
+
+
+ else:
+ np.save(where, 'weight_kernel.npy', w_kernel)
+ np.save(where, 'bias_kernel.npy', b_kernel)
+ np.save(where, 'weight_gain.npy', w_gain)
+ np.save(where, 'bias_gain.npy', b_gain)
+
+
+def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
+
+
+ w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
+ b_kernel = adaconv.conv_kernel.bias.detach().cpu().numpy().copy()
+ w_gain = adaconv.filter_gain.weight.detach().cpu().numpy().copy()
+ b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy()
+ w_global_gain = adaconv.global_filter_gain.weight.detach().cpu().numpy().copy()
+ b_global_gain = adaconv.global_filter_gain.bias.detach().cpu().numpy().copy()
+
+
+ if isinstance(where, CWriter):
+ # pad kernel for quantization
+ left_padding = adaconv.padding[0]
+ kernel_size = adaconv.kernel_size
+
+ if quantize and w_kernel.shape[0] % 8:
+ kernel_padding = 8 - (w_kernel.shape[0] % 8)
+ w_kernel = np.concatenate((np.zeros((kernel_padding, w_kernel.shape[1])), w_kernel), dtype=w_kernel.dtype)
+ b_kernel = np.concatenate((np.zeros((kernel_padding)), b_kernel), dtype=b_kernel.dtype)
+ left_padding += kernel_padding
+ kernel_size += kernel_padding
+ # write relevant scalar parameters to header file
+ where.header.write(f"""
+#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
+#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
+#define {name.upper()}_LOG_GAIN_LIMIT {adaconv.log_gain_limit:f}f
+#define {name.upper()}_KERNEL_SIZE {kernel_size}
+#define {name.upper()}_LEFT_PADDING {left_padding}
+#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
+#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
+#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
+#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
+#define {name.upper()}_NORM_P {adaconv.norm_p}
+#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
+#define {name.upper()}_MAX_LAG {adaconv.max_lag}
+"""
+ )
+
+ print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
+ print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
+ print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, format='torch', sparse=False, diagonal=False, quantize=False)
+
+
+ else:
+ np.save(where, 'weight_kernel.npy', w_kernel)
+ np.save(where, 'bias_kernel.npy', b_kernel)
+ np.save(where, 'weight_gain.npy', w_gain)
+ np.save(where, 'bias_gain.npy', b_gain)
+ np.save(where, 'weight_global_gain.npy', w_global_gain)
+ np.save(where, 'bias_global_gain.npy', b_global_gain)
+
+def dump_torch_tdshaper(where, shaper, name='tdshaper'):
+
+ if isinstance(where, CWriter):
+ where.header.write(f"""
+#define {name.upper()}_FEATURE_DIM {shaper.feature_dim}
+#define {name.upper()}_FRAME_SIZE {shaper.frame_size}
+#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k}
+#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0}
+#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0}
+"""
+ )
+
+ dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1")
+ dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")
+
+ if shaper.innovate:
+ dump_torch_conv1d_weights(where, shaper.feature_alpha1b, name + "_alpha1b")
+ dump_torch_conv1d_weights(where, shaper.feature_alpha1c, name + "_alpha1c")
+ dump_torch_conv1d_weights(where, shaper.feature_alpha2b, name + "_alpha2b")
+ dump_torch_conv1d_weights(where, shaper.feature_alpha2c, name + "_alpha2c")
+
+
+
def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
assert gru.num_layers == 1
@@ -221,7 +363,6 @@ def load_torch_conv2d_weights(where, conv):
def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False):
- print("quantize = ", quantize)
w = embed.weight.detach().cpu().numpy().copy().transpose()
b = np.zeros(w.shape[0], dtype=w.dtype)
@@ -257,11 +398,21 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
elif isinstance(module, torch.nn.Conv2d):
return dump_torch_conv2d_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.Embedding):
- return dump_torch_embedding_weights(where, module)
+ return dump_torch_embedding_weights(where, module, name, **kwargs)
elif isinstance(module, torch.nn.ConvTranspose1d):
return dump_torch_tconv1d_weights(where, module, name, **kwargs)
else:
- raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
+ if has_osce:
+ if isinstance(module, LimitedAdaptiveConv1d):
+ dump_torch_adaptive_conv1d_weights(where, module, name, **kwargs)
+ elif isinstance(module, LimitedAdaptiveComb1d):
+ dump_torch_adaptive_comb1d_weights(where, module, name, **kwargs)
+ elif isinstance(module, TDShaper):
+ dump_torch_tdshaper(where, module, name, **kwargs)
+ else:
+ raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
+ else:
+ raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
def load_torch_weights(where, module):
""" generic function for loading weights of some torch.nn.Module """