author     Jan Buethe <jbuethe@amazon.de>  2023-11-07 13:54:22 +0300
committer  Jan Buethe <jbuethe@amazon.de>  2023-11-07 13:54:22 +0300
commit     8af5c6b4a13cb66e0f3dcd465c246d2d2e4128c7
tree       2f2fd813d800f279f22875c1cf7d5b5d341a3a89
parent     b6095cf22d501cb1950685e46b334b0a2ca7e78b
added transposed 1d convolutions to wexchange
-rw-r--r--  dnn/torch/weight-exchange/wexchange/c_export/__init__.py |  2
-rw-r--r--  dnn/torch/weight-exchange/wexchange/c_export/common.py   | 22
-rw-r--r--  dnn/torch/weight-exchange/wexchange/torch/torch.py       | 37
3 files changed, 58 insertions, 3 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/__init__.py b/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
index 46bbf007..2a580c80 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
@@ -28,4 +28,4 @@ from .c_writer import CWriter
 */
 """
 
-from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
+from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py
index 5dd9f138..524f1cc3 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -361,3 +361,25 @@ def print_gru_layer(writer : CWriter,
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n")
 
     return N
+
+
+def print_tconv1d_layer(writer : CWriter,
+                        name : str,
+                        weight : np.ndarray,
+                        bias : np.ndarray,
+                        stride: int,
+                        scale=1/128,
+                        quantize=False):
+
+    in_channels, out_channels, kernel_size = weight.shape
+
+
+    linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0)
+    linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten()
+
+    print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize)
+
+    writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
+    writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")
+    writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n")
+    writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")
\ No newline at end of file
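The key trick in print_tconv1d_layer: a ConvTranspose1d weight of shape (in_channels, out_channels, kernel_size) is flattened into a single (in_channels, kernel_size * out_channels) matrix, so one matrix product per input frame yields kernel_size interleaved output frames, with the bias repeated once per tap. A small numpy/torch sanity check of that layout (not part of the commit; it assumes stride == kernel_size, i.e. non-overlapping taps, so the repeated bias contributes exactly once per output sample):

    import numpy as np
    import torch

    c_in, c_out, k = 4, 6, 5
    tconv = torch.nn.ConvTranspose1d(c_in, c_out, kernel_size=k, stride=k)

    w = tconv.weight.detach().numpy()   # (c_in, c_out, k)
    b = tconv.bias.detach().numpy()     # (c_out,)

    # same reshaping as print_tconv1d_layer
    lin_w = w.transpose(2, 1, 0).reshape(k * c_out, c_in).transpose(1, 0)  # (c_in, k * c_out)
    lin_b = np.repeat(b[np.newaxis, :], k, 0).flatten()                    # (k * c_out,)

    x = torch.randn(1, c_in, 10)
    ref = tconv(x).detach().numpy()[0]  # (c_out, k * 10)

    frames = x.numpy()[0].T             # (10, c_in), one row per input frame
    y = frames @ lin_w + lin_b          # (10, k * c_out), entries ordered (tap, channel)
    out = y.reshape(10 * k, c_out).T    # de-interleave taps -> (c_out, k * 10)

    print(np.abs(out - ref).max())      # agrees to float32 rounding

For stride < kernel_size the per-frame products overlap and the consumer would have to overlap-add them while counting the bias only once; the check above sidesteps that case.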
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 1e56b9d5..281d9be3 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -32,7 +32,7 @@ import os
 import torch
 import numpy as np
 
-from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer
+from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer
 
 
 def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
@@ -162,6 +162,35 @@ def load_torch_conv1d_weights(where, conv):
         conv.bias.set_(torch.from_numpy(b))
 
 
+def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+
+    w = conv.weight.detach().cpu().numpy().copy()
+    if conv.bias is None:
+        b = np.zeros(conv.out_channels, dtype=w.dtype)
+    else:
+        b = conv.bias.detach().cpu().numpy().copy()
+
+    if isinstance(where, CWriter):
+
+        return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize)
+    else:
+        os.makedirs(where, exist_ok=True)
+
+        np.save(os.path.join(where, 'weight_oik.npy'), w)
+
+        np.save(os.path.join(where, 'bias.npy'), b)
+
+
+def load_torch_tconv1d_weights(where, conv):
+
+    with torch.no_grad():
+        w = np.load(os.path.join(where, 'weight_oik.npy'))
+        conv.weight.set_(torch.from_numpy(w))
+        if conv.bias is not None:
+            b = np.load(os.path.join(where, 'bias.npy'))
+            conv.bias.set_(torch.from_numpy(b))
+
+
 def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
     w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
     if conv.bias is None:
@@ -228,6 +257,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
         return dump_torch_conv2d_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Embedding):
         return dump_torch_embedding_weights(where, module)
+    elif isinstance(module, torch.nn.ConvTranspose1d):
+        return dump_torch_tconv1d_weights(where, module, name, **kwargs)
     else:
         raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
 
@@ -243,5 +274,7 @@ def load_torch_weights(where, module):
         load_torch_conv2d_weights(where, module)
     elif isinstance(module, torch.nn.Embedding):
         load_torch_embedding_weights(where, module)
+    elif isinstance(module, torch.nn.ConvTranspose1d):
+        load_torch_tconv1d_weights(where, module)
     else:
-        raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
+        raise ValueError(f'load_torch_weights: layer of type {type(module)} not supported')
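A hypothetical round trip through the directory branch of the new helpers (the layer sizes and the 'tconv_test' path are made up for illustration, and dump_torch_weights/load_torch_weights are assumed to be re-exported by the wexchange.torch package, as they are for the existing layer types):

    import torch
    from wexchange.torch import dump_torch_weights, load_torch_weights

    tconv = torch.nn.ConvTranspose1d(16, 8, kernel_size=4, stride=4)
    dump_torch_weights('tconv_test', tconv)   # path branch: writes weight_oik.npy and bias.npy

    clone = torch.nn.ConvTranspose1d(16, 8, kernel_size=4, stride=4)
    load_torch_weights('tconv_test', clone)   # restores the saved tensors in place

    x = torch.randn(1, 16, 7)
    assert torch.allclose(tconv(x), clone(x))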