gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>    2023-09-29 15:25:26 +0300
committer Jan Buethe <jbuethe@amazon.de>    2023-09-29 15:25:26 +0300
commit    c5c214df1b214375fce964949598f5b4405c655e (patch)
tree      a68b3273a91cfb4ad6067440a919d3ede9dc41f1
parent    25c65a0c0b9ce8282cfc713a7c0581664c93ab18 (diff)
added rudimentary support for dumping nn.Conv2d layers
-rw-r--r--  dnn/torch/weight-exchange/wexchange/c_export/__init__.py  |  2
-rw-r--r--  dnn/torch/weight-exchange/wexchange/c_export/common.py    | 24
-rw-r--r--  dnn/torch/weight-exchange/wexchange/torch/__init__.py     |  1
-rw-r--r--  dnn/torch/weight-exchange/wexchange/torch/torch.py        | 33
4 files changed, 58 insertions, 2 deletions
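A minimal sketch of what this commit enables, assuming the wexchange package from dnn/torch/weight-exchange is importable; the channel counts, kernel size, and output path are illustrative:

import torch
from wexchange.torch import dump_torch_weights

# A toy 2-D convolution to export.
conv = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(3, 3))

# With a directory path, dump_torch_weights now dispatches to
# dump_torch_conv2d_weights, which writes weight_oiwh.npy and bias.npy.
dump_torch_weights('exported/conv', conv, name='conv')
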
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/__init__.py b/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
index 331c2e05..46bbf007 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
@@ -28,4 +28,4 @@ from .c_writer import CWriter
 */
 """
-from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_vector
\ No newline at end of file
+from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py
index a8986816..98a45c3f 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -291,6 +291,7 @@ def print_conv1d_layer(writer : CWriter,
     lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
     print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)
+
     writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
     writer.header.write(f"\n#define {name.upper()}_IN_SIZE {weight.shape[1]}\n")
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE ({weight.shape[1]} * ({weight.shape[0] - 1}))\n")
@@ -298,6 +299,29 @@ def print_conv1d_layer(writer : CWriter,
     return weight.shape[0] * weight.shape[1]
+def print_conv2d_layer(writer : CWriter,
+                       name : str,
+                       weight : np.ndarray,
+                       bias : np.ndarray,
+                       scale : float=1/128,
+                       quantize : bool=False):
+
+    if quantize:
+        print("[print_conv2d_layer] warning: quantize argument ignored")
+
+    bias_name = name + "_bias"
+    float_weight_name = name + "_weight_float"
+
+    print_vector(writer, weight, float_weight_name)
+    print_vector(writer, bias, bias_name)
+
+    # init function
+    out_channels, in_channels, ksize1, ksize2 = weight.shape
+    init_call = f'conv2d_init(&model->{name}, arrays, "{bias_name}", "{float_weight_name}", {in_channels}, {out_channels}, {ksize1}, {ksize2})'
+
+    writer.layer_dict[name] = ('Conv2dLayer', init_call)
+
+
 def print_gru_layer(writer : CWriter,
                     name : str,
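For illustration, the init call that print_conv2d_layer registers in writer.layer_dict can be reproduced by hand from a weight shape; the layer name and shape below are made up:

import numpy as np

name = 'myconv'
weight = np.zeros((8, 4, 3, 3), dtype=np.float32)   # (out_channels, in_channels, ksize1, ksize2)

out_channels, in_channels, ksize1, ksize2 = weight.shape
init_call = (f'conv2d_init(&model->{name}, arrays, "{name}_bias", "{name}_weight_float", '
             f'{in_channels}, {out_channels}, {ksize1}, {ksize2})')
print(init_call)
# -> conv2d_init(&model->myconv, arrays, "myconv_bias", "myconv_weight_float", 4, 8, 3, 3)
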
diff --git a/dnn/torch/weight-exchange/wexchange/torch/__init__.py b/dnn/torch/weight-exchange/wexchange/torch/__init__.py
index 09e80e93..2a9b9792 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/__init__.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/__init__.py
@@ -28,6 +28,7 @@
"""
from .torch import dump_torch_conv1d_weights, load_torch_conv1d_weights
+from .torch import dump_torch_conv2d_weights, load_torch_conv2d_weights
from .torch import dump_torch_dense_weights, load_torch_dense_weights
from .torch import dump_torch_gru_weights, load_torch_gru_weights
from .torch import dump_torch_embedding_weights, load_torch_embedding_weights
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 580ea3bf..35723c22 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -32,7 +32,7 @@ import os
 import torch
 import numpy as np
-from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer
+from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer
 def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
@@ -138,6 +138,33 @@ def load_torch_conv1d_weights(where, conv):
                 conv.bias.set_(torch.from_numpy(b))
+def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+    w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
+    if conv.bias is None:
+        b = np.zeros(conv.out_channels, dtype=w.dtype)
+    else:
+        b = conv.bias.detach().cpu().numpy().copy()
+
+    if isinstance(where, CWriter):
+        return print_conv2d_layer(where, name, w, b, scale=scale, quantize=quantize)
+
+    else:
+        os.makedirs(where, exist_ok=True)
+
+        np.save(os.path.join(where, 'weight_oiwh.npy'), w)
+
+        np.save(os.path.join(where, 'bias.npy'), b)
+
+def load_torch_conv2d_weights(where, conv):
+    with torch.no_grad():
+        w = np.load(os.path.join(where, 'weight_oiwh.npy'))
+        conv.weight.set_(torch.from_numpy(w).permute(0, 1, 3, 2))
+        if conv.bias is not None:
+            b = np.load(os.path.join(where, 'bias.npy'))
+            conv.bias.set_(torch.from_numpy(b))
+
+
 def dump_torch_embedding_weights(where, emb):
     os.makedirs(where, exist_ok=True)
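Note the (0, 1, 3, 2) permute applied symmetrically in dump_torch_conv2d_weights and load_torch_conv2d_weights above: PyTorch stores Conv2d weights as (out, in, H, W), so the dumped array is in (out, in, W, H) order, matching the weight_oiwh.npy filename. The permutation is its own inverse, as a quick sanity check shows:

import torch

w = torch.randn(8, 4, 5, 3)                        # PyTorch layout: (O, I, H, W)
w_oiwh = w.permute(0, 1, 3, 2)                     # dumped layout:  (O, I, W, H)
assert torch.equal(w_oiwh.permute(0, 1, 3, 2), w)  # the load path undoes the dump path
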
@@ -162,6 +189,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
         return dump_torch_gru_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Conv1d):
         return dump_torch_conv1d_weights(where, module, name, **kwargs)
+    elif isinstance(module, torch.nn.Conv2d):
+        return dump_torch_conv2d_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Embedding):
         return dump_torch_embedding_weights(where, module)
     else:
@@ -175,6 +204,8 @@ def load_torch_weights(where, module):
         load_torch_gru_weights(where, module)
     elif isinstance(module, torch.nn.Conv1d):
         load_torch_conv1d_weights(where, module)
+    elif isinstance(module, torch.nn.Conv2d):
+        load_torch_conv2d_weights(where, module)
     elif isinstance(module, torch.nn.Embedding):
         load_torch_embedding_weights(where, module)
     else:
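A round-trip check of the new file-based path, assuming a writable scratch directory; module shapes and the directory name are illustrative:

import torch
from wexchange.torch import dump_torch_conv2d_weights, load_torch_conv2d_weights

src = torch.nn.Conv2d(4, 8, kernel_size=(3, 3))
dst = torch.nn.Conv2d(4, 8, kernel_size=(3, 3))

dump_torch_conv2d_weights('scratch/conv', src)   # writes weight_oiwh.npy and bias.npy
load_torch_conv2d_weights('scratch/conv', dst)   # loads them back into dst

assert torch.allclose(src.weight, dst.weight)
assert torch.allclose(src.bias, dst.bias)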