Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/dnntools/dnntools/relegance/meta_critic.py')
-rw-r--r--dnn/torch/dnntools/dnntools/relegance/meta_critic.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/dnn/torch/dnntools/dnntools/relegance/meta_critic.py b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
new file mode 100644
index 00000000..1af0f8ff
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
@@ -0,0 +1,85 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+class MetaCritic():
+ def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
+ """ Class for assessing relevance of discriminator scores
+
+ Args:
+ gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
+ beta (float, optional): Miminum confidence related threshold. Defaults to 0.0.
+ """
+ self.normalize = normalize
+ self.gamma = gamma
+ self.beta = beta
+ self.joint_stats = joint_stats
+
+ self.disc_stats = dict()
+
+ def __call__(self, disc_id, real_scores, generated_scores):
+ """ calculates relevance from normalized scores
+
+ Args:
+ disc_id (any valid key): id for tracking discriminator statistics
+ real_scores (torch.tensor): scores for real data
+ generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device
+
+ Returns:
+ torch.tensor: output-domain relevance
+ """
+
+ if self.normalize:
+ real_std = torch.std(real_scores.detach()).cpu().item()
+ gen_std = torch.std(generated_scores.detach()).cpu().item()
+ std = (real_std**2 + gen_std**2) ** .5
+ mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()
+
+ key = 0 if self.joint_stats else disc_id
+
+ if key in self.disc_stats:
+ self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
+ self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
+ else:
+ self.disc_stats[key] = {
+ 'std': std + 1e-5,
+ 'mean': mean
+ }
+
+ std = self.disc_stats[key]['std']
+ mean = self.disc_stats[key]['mean']
+ else:
+ mean, std = 0, 1
+
+ relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)
+
+ if False: print(f"relevance({disc_id}): {relevance.min()=} {relevance.max()=} {relevance.mean()=}")
+
+ return relevance \ No newline at end of file