gitlab.xiph.org/xiph/opus.git

author    Jan Buethe <jbuethe@amazon.de>  2023-11-21 20:05:39 +0300
committer Jan Buethe <jbuethe@amazon.de>  2023-11-21 20:05:39 +0300
commit    1bfce89465491da582132eb4a2f1d2c2ab2fe1a0 (patch)
tree      e40e4376be4d38b07fd32dac82074045e60ab6fd
parent    464f33c8dad85b29fdc98562f5ad8d4718a1ff35 (diff)

added TDShaper to wexchange

-rw-r--r--  dnn/torch/weight-exchange/wexchange/torch/torch.py  | 29
1 file changed, 29 insertions(+), 0 deletions(-)
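
For context (not part of the commit itself): wexchange exports torch modules either into C sources via a CWriter or as .npy files, and dump_torch_weights() dispatches on the module type. A minimal usage sketch of the new TDShaper path follows; the TDShaper constructor arguments, output file name, and layer name are illustrative assumptions only.

    # Hypothetical usage sketch, assuming the osce training code (utils.layers)
    # is importable and that CWriter("tdshaper_data") creates tdshaper_data.c/.h.
    from wexchange.c_export import CWriter
    from wexchange.torch.torch import dump_torch_weights
    from utils.layers.td_shaper import TDShaper

    shaper = TDShaper(feature_dim=93, frame_size=80)   # assumed constructor signature
    writer = CWriter("tdshaper_data")
    dump_torch_weights(writer, shaper, name="tdshaper", verbose=True)
    writer.close()                                     # assumed to finalize the generated files
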
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 5f7126b2..7392b522 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -38,6 +38,7 @@ try:
     import utils.layers as osce_layers
     from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
     from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+    from utils.layers.td_shaper import TDShaper
     has_osce=True
 except:
     has_osce=False
@@ -65,6 +66,7 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc
 #define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
 #define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
 #define {name.upper()}_NORM_P {adaconv.norm_p}
+#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
 """
         )
@@ -103,6 +105,8 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc
 #define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
 #define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
 #define {name.upper()}_NORM_P {adaconv.norm_p}
+#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
+#define {name.upper()}_MAX_LAG {adaconv.max_lag}
 """
         )
@@ -119,6 +123,29 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc
         np.save(where, 'weight_global_gain.npy', w_global_gain)
         np.save(where, 'bias_global_gain.npy', b_global_gain)
+def dump_torch_tdshaper(where, shaper, name='tdshaper'):
+
+    if isinstance(where, CWriter):
+        where.header.write(f"""
+#define {name.upper()}_FEATURE_DIM {shaper.feature_dim}
+#define {name.upper()}_FRAME_SIZE {shaper.frame_size}
+#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k}
+#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0}
+#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0}
+"""
+        )
+
+    dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1")
+    dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")
+
+    if shaper.innovate:
+        dump_torch_conv1d_weights(where, shaper.feature_alpha1b, name + "_alpha1b")
+        dump_torch_conv1d_weights(where, shaper.feature_alpha1c, name + "_alpha1c")
+        dump_torch_conv1d_weights(where, shaper.feature_alpha2b, name + "_alpha2b")
+        dump_torch_conv1d_weights(where, shaper.feature_alpha2c, name + "_alpha2c")
+
+
+
 def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
     assert gru.num_layers == 1
@@ -351,6 +378,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
             dump_torch_adaptive_conv1d_weights(where, module, name, **kwargs)
         elif isinstance(module, LimitedAdaptiveComb1d):
             dump_torch_adaptive_comb1d_weights(where, module, name, **kwargs)
+        elif isinstance(module, TDShaper):
+            dump_torch_tdshaper(where, module, name, **kwargs)
         else:
             raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
     else:
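
As a reading aid (not part of the commit): dump_torch_tdshaper() only reads a small, fixed set of attributes from the shaper module, so the export path can be smoke-tested with a duck-typed stand-in when the osce package is not on the path. Everything below is a hypothetical sketch; the stand-in's name and layer shapes are illustrative, and it must be passed to dump_torch_tdshaper() directly, since the isinstance(module, TDShaper) dispatch above would not match it.

    # Hypothetical stand-in exposing exactly the attributes dump_torch_tdshaper reads.
    import torch.nn as nn
    from wexchange.c_export import CWriter
    from wexchange.torch.torch import dump_torch_tdshaper

    class FakeTDShaper(nn.Module):
        def __init__(self, feature_dim=93, frame_size=80, avg_pool_k=4,
                     innovate=False, pool_after=False):
            super().__init__()
            self.feature_dim = feature_dim
            self.frame_size  = frame_size
            self.avg_pool_k  = avg_pool_k
            self.innovate    = innovate
            self.pool_after  = pool_after
            # feature_alpha1/feature_alpha2 are always dumped; the *b/*c branches
            # are only dumped when innovate is True. Shapes here are illustrative.
            self.feature_alpha1 = nn.Conv1d(feature_dim, frame_size, 2)
            self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)

    writer = CWriter("tdshaper_test_data")   # assumed CWriter signature
    dump_torch_tdshaper(writer, FakeTDShaper(), name="tdshaper")
    writer.close()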