diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-08 14:00:49 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-08 14:05:04 +0300 |
commit | 999ddbe09ca1e521cc8d71005c1ecd8513e47611 (patch) | |
tree | 85022704a9352cb03751863f9a1ae788afcc805b /dnn/torch/osce/models/no_lace.py | |
parent | e968878f06fb96e022f09c0fab60b88ab3f3ac81 (diff) |
more sparsification stuff
Diffstat (limited to 'dnn/torch/osce/models/no_lace.py')
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 18 |
1 file changed, 10 insertions, 8 deletions
```diff
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 5654db6c..7e021930 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -72,7 +72,9 @@ class NoLACE(NNSBase):
                  avg_pool_k=4,
                  pool_after=False,
                  softquant=False,
-                 sparsify=False):
+                 sparsify=False,
+                 sparsification_schedule=[100, 1000, 100],
+                 sparsification_density=0.5):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -97,7 +99,7 @@ class NoLACE(NNSBase):
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
@@ -136,13 +138,13 @@ class NoLACE(NNSBase):
 
 
         if sparsify:
-            mark_for_sparsification(self.post_cf1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_cf2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af3, (0.25, [8, 4]))
+            mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
 
-            self.sparsify = create_sparsifier(self, 500, 1000, 100)
+            self.sparsifier = create_sparsifier(self, *sparsification_schedule)
 
     def flop_count(self, rate=16000, verbose=False):
 
```