gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>  2024-01-16 12:13:44 +0300
committer Jan Buethe <jbuethe@amazon.de>  2024-01-16 12:13:44 +0300
commit    58c8bc617461c53d9b20ee896150a58ce2803de1 (patch)
tree      8fdb62224dca7f6bf6fdd69ba9f68f388eb3f512
parent    0e8a527eebd34f3c51960c816c277b873d6792da (diff)
added relegance inspection tool to dnntools
-rw-r--r--  dnn/torch/dnntools/dnntools/relegance/__init__.py       2
-rw-r--r--  dnn/torch/dnntools/dnntools/relegance/meta_critic.py   85
-rw-r--r--  dnn/torch/dnntools/dnntools/relegance/relegance.py    449
3 files changed, 536 insertions(+), 0 deletions(-)
diff --git a/dnn/torch/dnntools/dnntools/relegance/__init__.py b/dnn/torch/dnntools/dnntools/relegance/__init__.py
new file mode 100644
index 00000000..cee0143b
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/__init__.py
@@ -0,0 +1,2 @@
+from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
+from .meta_critic import MetaCritic \ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/relegance/meta_critic.py b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
new file mode 100644
index 00000000..1af0f8ff
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
@@ -0,0 +1,85 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+class MetaCritic():
+ def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
+ """ Class for assessing relevance of discriminator scores
+
+ Args:
+ gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
+ beta (float, optional): Miminum confidence related threshold. Defaults to 0.0.
+ """
+ self.normalize = normalize
+ self.gamma = gamma
+ self.beta = beta
+ self.joint_stats = joint_stats
+
+ self.disc_stats = dict()
+
+ def __call__(self, disc_id, real_scores, generated_scores):
+ """ calculates relevance from normalized scores
+
+ Args:
+ disc_id (any valid key): id for tracking discriminator statistics
+ real_scores (torch.tensor): scores for real data
+ generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device
+
+ Returns:
+ torch.tensor: output-domain relevance
+ """
+
+ if self.normalize:
+ real_std = torch.std(real_scores.detach()).cpu().item()
+ gen_std = torch.std(generated_scores.detach()).cpu().item()
+ std = (real_std**2 + gen_std**2) ** .5
+ mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()
+
+ key = 0 if self.joint_stats else disc_id
+
+ if key in self.disc_stats:
+ self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
+ self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
+ else:
+ self.disc_stats[key] = {
+ 'std': std + 1e-5,
+ 'mean': mean
+ }
+
+ std = self.disc_stats[key]['std']
+ mean = self.disc_stats[key]['mean']
+ else:
+ mean, std = 0, 1
+
+ relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)
+
+        # debug output (disabled): uncomment to inspect relevance statistics per discriminator
+        # print(f"relevance({disc_id}): {relevance.min()=} {relevance.max()=} {relevance.mean()=}")
+
+ return relevance \ No newline at end of file
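
Usage sketch (editorial illustration, not part of the commit): MetaCritic tracks running score statistics per discriminator and turns the gap between real and generated scores into a non-negative relevance tensor. The snippet below assumes a multi-discriminator GAN training step in which disc_outputs holds hypothetical (real_scores, generated_scores) pairs; names and tensor shapes are placeholders.

import torch
from dnntools.relegance import MetaCritic

critic = MetaCritic(normalize=True, gamma=0.9, beta=0.0)

# hypothetical per-discriminator score pairs from one training step
disc_outputs = {
    'disc_a': (torch.randn(8, 1, 100), torch.randn(8, 1, 100)),  # (real_scores, generated_scores)
    'disc_b': (torch.randn(8, 1, 50), torch.randn(8, 1, 50)),
}

for disc_id, (real_scores, generated_scores) in disc_outputs.items():
    # output-domain relevance: large where the discriminator separates real from generated data
    relevance = critic(disc_id, real_scores, generated_scores)
    print(disc_id, tuple(relevance.shape), relevance.mean().item())
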
diff --git a/dnn/torch/dnntools/dnntools/relegance/relegance.py b/dnn/torch/dnntools/dnntools/relegance/relegance.py
new file mode 100644
index 00000000..29c5be23
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/relegance.py
@@ -0,0 +1,449 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+def view_one_hot(index, length):
+ vec = length * [1]
+ vec[index] = -1
+ return vec
+
+def create_smoothing_kernel(widths, gamma=1.5):
+ """ creates a truncated gaussian smoothing kernel for the given widths
+
+ Parameters:
+ -----------
+ widths: list[Int] or torch.LongTensor
+ specifies the shape of the smoothing kernel, entries must be > 0.
+
+ gamma: float, optional
+ decay factor for gaussian relative to kernel size
+
+ Returns:
+ --------
+ kernel: torch.FloatTensor
+ """
+
+ widths = torch.LongTensor(widths)
+ num_dims = len(widths)
+
+ assert(widths.min() > 0)
+
+ centers = widths.float() / 2 - 0.5
+ sigmas = gamma * (centers + 1)
+
+    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
+ vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])
+
+ kernel = torch.exp(- vals)
+ kernel = kernel / kernel.sum()
+
+ return kernel
+
+
+def create_partition_kernel(widths, strides):
+ """ creates a partition kernel for mapping a convolutional network output back to the input domain
+
+ Given a fully convolutional network with receptive field of shape widths and the given strides, this
+    function constructs an interpolation kernel whose translations by multiples of the given strides form
+    a partition of unity on the input domain.
+
+    Parameters:
+ ----------
+ widths: list[Int] or torch.LongTensor
+ shape of receptive field
+
+ strides: list[Int] or torch.LongTensor
+ total strides of convolutional network
+
+    Returns:
+    --------
+    kernel: torch.FloatTensor
+ """
+
+ num_dims = len(widths)
+ assert num_dims == len(strides) and num_dims in {1, 2, 3}
+
+ convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}
+
+ widths = torch.LongTensor(widths)
+ strides = torch.LongTensor(strides)
+
+ proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())
+
+ # create interpolation kernel eta
+ eta_widths = widths - strides + 1
+ if eta_widths.min() <= 0:
+ print("[create_partition_kernel] warning: receptive field does not cover input domain")
+ eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))
+
+
+ eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())
+
+ padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
+ padded_proto_kernel = F.pad(proto_kernel, padding)
+ padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
+ kernel = convs[num_dims](padded_proto_kernel, eta)
+
+ return kernel
+
+
+def receptive_field(conv_model, input_shape, output_position):
+ """ estimates boundaries of receptive field connected to output_position via autograd
+
+ Parameters:
+ -----------
+ conv_model: nn.Module or autograd function
+ function or model implementing fully convolutional model
+
+ input_shape: List[Int]
+ input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]
+
+ output_position: List[Int]
+ output position for which the receptive field is determined; the function raises an exception
+        if output_position is out of bounds for the model output computed from the given input_shape.
+
+ Returns:
+ --------
+ low: List[Int]
+ start indices of receptive field
+
+ high: List[Int]
+ stop indices of receptive field
+
+ """
+
+ x = torch.randn((1,) + tuple(input_shape), requires_grad=True)
+ y = conv_model(x)
+
+ # collapse channels and remove batch dimension
+ y = torch.sum(y, 1)[0]
+
+ # create mask
+ mask = torch.zeros_like(y)
+ index = [torch.tensor(i) for i in output_position]
+ try:
+ mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
+ except IndexError:
+ raise ValueError('output_position out of bounds')
+
+ (mask * y).sum().backward()
+
+ # sum over channels and remove batch dimension
+ grad = torch.sum(x.grad, dim=1)[0]
+ tmp = torch.nonzero(grad, as_tuple=True)
+ low = [t.min().item() for t in tmp]
+ high = [t.max().item() for t in tmp]
+
+ return low, high
+
+def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
+ """ attempts to estimate receptive field size, strides and left paddings for given model
+
+
+ Parameters:
+ -----------
+ model: nn.Module or autograd function
+ fully convolutional model for which parameters are estimated
+
+ num_channels: Int
+ number of input channels for model
+
+ num_dims: Int
+ number of input dimensions for model (without channel dimension)
+
+ width: Int
+ width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
+
+ max_stride: Int, optional
+        assumed maximal stride of the model in any dimension; if set too low, the function may fail for
+        any value of width
+
+ Returns:
+ --------
+ receptive_field_size: List[Int]
+        receptive field size in all dimensions
+
+ strides: List[Int]
+ stride in all dimensions
+
+ left_paddings: List[Int]
+ left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
+
+ Raises:
+ -------
+ ValueError, KeyError
+
+ """
+
+ input_shape = [num_channels] + num_dims * [width]
+ output_position1 = num_dims * [width // (2 * max_stride)]
+ output_position2 = num_dims * [width // (2 * max_stride) + 1]
+
+ low1, high1 = receptive_field(model, input_shape, output_position1)
+ low2, high2 = receptive_field(model, input_shape, output_position2)
+
+ widths1 = [h - l + 1 for l, h in zip(low1, high1)]
+ widths2 = [h - l + 1 for l, h in zip(low2, high2)]
+
+ if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
+ raise ValueError("[estimate_strides]: widths to small to determine strides")
+
+ receptive_field_size = widths1
+ strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
+ left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]
+
+ return receptive_field_size, strides, left_paddings
+
+def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
+ """ determines size of receptive field, strides and padding probabilistically
+
+
+ Parameters:
+ -----------
+ model: nn.Module or autograd function
+ fully convolutional model for which parameters are estimated
+
+ num_channels: Int
+ number of input channels for model
+
+ num_dims: Int
+ number of input dimensions for model (without channel dimension)
+
+ max_width: Int
+        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
+
+    width_hint: Int or None, optional
+        optional hint at the maximal width of the receptive field; used as a starting point for the trial loop
+
+    stride_hint: Int or None, optional
+        optional hint at the maximal stride; used as a starting point for the trial loop
+
+ verbose: bool, optional
+ if true, the function prints parameters for individual trials
+
+ Returns:
+ --------
+ receptive_field_size: List[Int]
+        receptive field size in all dimensions
+
+ strides: List[Int]
+ stride in all dimensions
+
+ left_paddings: List[Int]
+ left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
+
+ Raises:
+ -------
+ ValueError
+
+ """
+
+ max_stride = max_width // 2
+ stride = max_stride // 100
+ width = max_width // 100
+
+ if width_hint is not None: width = 2 * width_hint
+ if stride_hint is not None: stride = stride_hint
+
+ did_it = False
+ while width < max_width and stride < max_stride:
+ try:
+ if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
+ receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
+ did_it = True
+        except Exception:
+            # estimation failed for this width/stride combination; retry with larger values
+            pass
+
+ if did_it: break
+
+ width *= 2
+ if width >= max_width and stride < max_stride:
+ stride *= 2
+ width = 2 * stride
+
+ if not did_it:
+ raise ValueError(f'could not determine conv parameter with given max_width={max_width}')
+
+ return receptive_field_size, strides, left_paddings
+
+
+class GradWeight(torch.autograd.Function):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def forward(ctx, x, weight):
+ ctx.save_for_backward(weight)
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ weight, = ctx.saved_tensors
+
+ grad_input = grad_output * weight
+
+ return grad_input, None
+
+
+# API
+
+def relegance_gradient_weighting(x, weight):
+ """
+
+ Args:
+ x (torch.tensor): input tensor
+ weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass
+
+    Returns:
+        torch.tensor: x, unchanged in value; in the backward pass, gradients w.r.t. x are multiplied by weight
+    """
+ if weight is None:
+ return x
+ else:
+ return GradWeight.apply(x, weight)
+
+
+
+def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
+ """ creates parameters for mapping back output domain relevance to input tomain
+
+ Args:
+ model (nn.Module or autograd.Function): fully convolutional model
+ num_channels (int): number of input channels to model
+ num_dims (int): number of input dimensions of model (without channel and batch dimension)
+ width_hint(int or None): optional hint at maximal width of receptive field
+ stride_hint(int or None): optional hint at maximal stride
+
+ Returns:
+ dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution
+ """
+
+ max_width = int(100000 / (10 ** num_dims))
+
+ did_it = False
+ try:
+ receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
+ did_it = True
+    except Exception:
+        # first attempt failed; retry once with a larger max_width
+ max_width *= 10
+
+    # second attempt; give up if this one also fails
+    try:
+        if not did_it:
+            receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
+    except Exception:
+        raise RuntimeError("could not determine parameters within given compute budget")
+
+ partition_kernel = create_partition_kernel(receptive_field_size, strides)
+ partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)
+
+ tconv_parameters = {
+ 'kernel': partition_kernel,
+ 'receptive_field_shape': receptive_field_size,
+ 'stride': strides,
+ 'left_padding': left_paddings,
+ 'num_dims': num_dims
+ }
+
+ return tconv_parameters
+
+
+
+def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
+ """ maps output-domain relevance to input-domain relevance via transpose convolution
+
+ Args:
+ od_relevance (torch.tensor): output-domain relevance
+ tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel
+
+ Returns:
+        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
+                      However, the size of the output tensor does not necessarily match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
+                      convenient way to adjust the output to the correct size.
+
+ Raises:
+ ValueError: if number of dimensions is not supported
+ """
+
+ kernel = tconv_parameters['kernel'].to(od_relevance.device)
+ rf_shape = tconv_parameters['receptive_field_shape']
+ stride = tconv_parameters['stride']
+ left_padding = tconv_parameters['left_padding']
+
+ num_dims = len(kernel.shape) - 2
+
+ # repeat boundary values
+ od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
+ padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
+ od_padding = od_padding[::2]
+
+ # apply mapping and left trimming
+ if num_dims == 1:
+ id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
+ elif num_dims == 2:
+ id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
+ elif num_dims == 3:
+        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
+ else:
+        raise ValueError(f'[relegance_map_relevance_to_input_domain] error: num_dims = {num_dims} not supported')
+
+ return id_relevance
+
+
+def relegance_resize_relevance_to_input_size(reference_input, relevance):
+ """ adjusts size of relevance tensor to reference input size
+
+ Args:
+ reference_input (torch.tensor): discriminator input tensor for reference
+ relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input
+
+ Returns:
+ torch.tensor: resized relevance
+
+ Raises:
+ ValueError: if number of dimensions is not supported
+ """
+ resized_relevance = torch.zeros_like(reference_input)
+
+ num_dims = len(reference_input.shape) - 2
+ with torch.no_grad():
+ if num_dims == 1:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-1), relevance.size(-1))]
+ elif num_dims == 2:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
+ elif num_dims == 3:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-3), relevance.size(-3)), : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
+ else:
+            raise ValueError(f'[relegance_resize_relevance_to_input_size] error: num_dims = {num_dims} not supported')
+
+ return resized_relevance \ No newline at end of file
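
End-to-end usage sketch (editorial illustration, not part of the commit): the transposed-convolution kernel is derived once from the discriminator with relegance_create_tconv_kernel, output-domain relevance from MetaCritic is mapped back to the input domain and resized, and the result weights the gradients flowing into the generator via relegance_gradient_weighting. The toy 1-D discriminator, the hint values and the adversarial loss are assumptions made for the example.

import torch
import torch.nn as nn

from dnntools.relegance import (MetaCritic, relegance_create_tconv_kernel,
                                relegance_gradient_weighting,
                                relegance_map_relevance_to_input_domain,
                                relegance_resize_relevance_to_input_size)

# toy fully convolutional 1-D discriminator without padding (placeholder, not from the patch)
disc = nn.Sequential(
    nn.Conv1d(1, 16, kernel_size=15, stride=4),
    nn.LeakyReLU(0.2),
    nn.Conv1d(16, 1, kernel_size=15, stride=4),
)

# one-time inspection of the discriminator's receptive field, stride and padding
tconv = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1,
                                      width_hint=256, stride_hint=16)

critic = MetaCritic(normalize=True)

x_real = torch.randn(4, 1, 4096)
x_fake = torch.randn(4, 1, 4096, requires_grad=True)  # stands in for generator output

real_scores = disc(x_real)
gen_scores = disc(x_fake)

# output-domain relevance, mapped back to the input domain and resized to the input shape
od_relevance = critic('toy_disc', real_scores.detach(), gen_scores.detach())
id_relevance = relegance_map_relevance_to_input_domain(od_relevance, tconv)
id_relevance = relegance_resize_relevance_to_input_size(x_fake, id_relevance)

# adversarial loss whose gradient w.r.t. x_fake is weighted by the relevance map
adv_loss = -disc(relegance_gradient_weighting(x_fake, id_relevance)).mean()
adv_loss.backward()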