Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/osce/export_model_weights.py')
-rw-r--r--dnn/torch/osce/export_model_weights.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index f94431d3..0bec9604 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -43,6 +43,7 @@ from models import model_dict
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
+from utils.misc import remove_all_weight_norm
from wexchange.torch import dump_torch_weights
@@ -58,30 +59,30 @@ schedules = {
'nolace': [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None)),
- ('feature_net.tconv', dict(quantize=True, scale=None)),
- ('feature_net.gru', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None)),
- ('tdshape1', dict()),
- ('tdshape2', dict()),
- ('tdshape3', dict()),
+ ('tdshape1', dict(quantize=True, scale=None)),
+ ('tdshape2', dict(quantize=True, scale=None)),
+ ('tdshape3', dict(quantize=True, scale=None)),
('af2', dict(quantize=True, scale=None)),
('af3', dict(quantize=True, scale=None)),
('af4', dict(quantize=True, scale=None)),
- ('post_cf1', dict(quantize=True, scale=None)),
- ('post_cf2', dict(quantize=True, scale=None)),
- ('post_af1', dict(quantize=True, scale=None)),
- ('post_af2', dict(quantize=True, scale=None)),
- ('post_af3', dict(quantize=True, scale=None))
+ ('post_cf1', dict(quantize=True, scale=None, sparse=True)),
+ ('post_cf2', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af1', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af2', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af3', dict(quantize=True, scale=None, sparse=True))
],
'lace' : [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None)),
- ('feature_net.tconv', dict(quantize=True, scale=None)),
- ('feature_net.gru', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))
@@ -140,6 +141,7 @@ if __name__ == "__main__":
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])
+ remove_all_weight_norm(model, verbose=True)
# CWriter
model_name = checkpoint['setup']['model']['name']