diff options
Diffstat (limited to 'dnn/torch/osce/models/silk_feature_net_pl.py')
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py index 72f8531c..799064d2 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -44,7 +44,8 @@ class SilkFeatureNetPL(nn.Module): num_channels=256, hidden_feature_dim=64, softquant=False, - sparsify=True): + sparsify=True, + sparsification_density=0.5): super(SilkFeatureNetPL, self).__init__() @@ -64,17 +65,17 @@ class SilkFeatureNetPL(nn.Module): if sparsify: - mark_for_sparsification(self.conv2, (0.25, [8, 4])) - mark_for_sparsification(self.tconv, (0.25, [8, 4])) + mark_for_sparsification(self.conv2, (sparsification_density, [8, 4])) + mark_for_sparsification(self.tconv, (sparsification_density, [8, 4])) mark_for_sparsification( self.gru, { - 'W_ir' : (0.25, [8, 4], False), - 'W_iz' : (0.25, [8, 4], False), - 'W_in' : (0.25, [8, 4], False), - 'W_hr' : (0.125, [8, 4], True), - 'W_hz' : (0.125, [8, 4], True), - 'W_hn' : (0.125, [8, 4], True), + 'W_ir' : (sparsification_density, [8, 4], False), + 'W_iz' : (sparsification_density, [8, 4], False), + 'W_in' : (sparsification_density, [8, 4], False), + 'W_hr' : (sparsification_density, [8, 4], True), + 'W_hz' : (sparsification_density, [8, 4], True), + 'W_hn' : (sparsification_density, [8, 4], True), } ) |