diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-21 20:05:39 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-11-21 20:05:39 +0300 |
commit | 1bfce89465491da582132eb4a2f1d2c2ab2fe1a0 (patch) | |
tree | e40e4376be4d38b07fd32dac82074045e60ab6fd | |
parent | 464f33c8dad85b29fdc98562f5ad8d4718a1ff35 (diff) |
added TDShaper to wexchange
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/torch/torch.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index 5f7126b2..7392b522 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -38,6 +38,7 @@ try: import utils.layers as osce_layers from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d + from utils.layers.td_shaper import TDShaper has_osce=True except: has_osce=False @@ -65,6 +66,7 @@ def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', kernel_sc #define {name.upper()}_IN_CHANNELS {adaconv.in_channels} #define {name.upper()}_OUT_CHANNELS {adaconv.out_channels} #define {name.upper()}_NORM_P {adaconv.norm_p} +#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim} """ ) @@ -103,6 +105,8 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc #define {name.upper()}_IN_CHANNELS {adaconv.in_channels} #define {name.upper()}_OUT_CHANNELS {adaconv.out_channels} #define {name.upper()}_NORM_P {adaconv.norm_p} +#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim} +#define {name.upper()}_MAX_LAG {adaconv.max_lag} """ ) @@ -119,6 +123,29 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', kernel_sc np.save(where, 'weight_global_gain.npy', w_global_gain) np.save(where, 'bias_global_gain.npy', b_global_gain) +def dump_torch_tdshaper(where, shaper, name='tdshaper'): + + if isinstance(where, CWriter): + where.header.write(f""" +#define {name.upper()}_FEATURE_DIM {shaper.feature_dim} +#define {name.upper()}_FRAME_SIZE {shaper.frame_size} +#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k} +#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0} +#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0} +""" + ) + + dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1") + dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2") + + if shaper.innovate: + dump_torch_conv1d_weights(where, shaper.feature_alpha1b, name + "_alpha1b") + dump_torch_conv1d_weights(where, shaper.feature_alpha1c, name + "_alpha1c") + dump_torch_conv1d_weights(where, shaper.feature_alpha2b, name + "_alpha2b") + dump_torch_conv1d_weights(where, shaper.feature_alpha2c, name + "_alpha2c") + + + def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128): assert gru.num_layers == 1 @@ -351,6 +378,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs): dump_torch_adaptive_conv1d_weights(where, module, name, **kwargs) elif isinstance(module, LimitedAdaptiveComb1d): dump_torch_adaptive_comb1d_weights(where, module, name, **kwargs) + elif isinstance(module, TDShaper): + dump_torch_tdshaper(where, module, name, **kwargs) else: raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported') else: |