Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author    Jean-Marc Valin <jmvalin@amazon.com> 2023-09-16 00:27:44 +0300
committer Jean-Marc Valin <jmvalin@amazon.com> 2023-09-16 00:28:59 +0300
commit    2b463a3499efba0c6a2ddfac761e0f4aa2707303 (patch)
tree      02ce9299be5c936f3bc1131eb124b4f4fd6a6ac9
parent    82f48d368b41d8bc4286e1375419daacbd10dbca (diff)
quantizing initial state with rdovae too
-rw-r--r-- dnn/torch/rdovae/rdovae/rdovae.py  45
-rw-r--r-- dnn/torch/rdovae/train_rdovae.py   15
2 files changed, 38 insertions(+), 22 deletions(-)
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 1eec42c1..0dc943ec 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -372,7 +372,7 @@ class CoreDecoder(nn.Module):
class StatisticalModel(nn.Module):
- def __init__(self, quant_levels, latent_dim):
+ def __init__(self, quant_levels, latent_dim, state_dim):
""" Statistical model for latent space
Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@ class StatisticalModel(nn.Module):
# copy parameters
self.latent_dim = latent_dim
+ self.state_dim = state_dim
+ self.total_dim = latent_dim + state_dim
self.quant_levels = quant_levels
- self.embedding_dim = 6 * latent_dim
+ self.embedding_dim = 6 * self.total_dim
# quantization embedding
self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@ class StatisticalModel(nn.Module):
x = self.quant_embedding(quant_ids)
# CAVE: theta_soft is not used anymore. Kick it out?
- quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
- dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
- theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
- r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
- theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
- r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+ quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+ dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+ theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+ r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+ theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+ r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
return {
@@ -445,7 +447,7 @@ class RDOVAE(nn.Module):
self.state_dropout_rate = state_dropout_rate
# submodules encoder and decoder share the statistical model
- self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+ self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
@@ -522,13 +524,18 @@ class RDOVAE(nn.Module):
z, states = self.core_encoder(features)
# scaling, dead-zone and quantization
- z = z * statistical_model['quant_scale']
- z = soft_dead_zone(z, statistical_model['dead_zone'])
+ z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
# quantization
- z_q = hard_quantize(z) / statistical_model['quant_scale']
- z_n = noise_quantize(z) / statistical_model['quant_scale']
- states_q = soft_pvq(states, self.pvq_num_pulses)
+ z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ #states_q = soft_pvq(states, self.pvq_num_pulses)
+ states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
+
+ states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
if self.state_dropout_rate > 0:
drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
@@ -551,6 +558,7 @@ class RDOVAE(nn.Module):
# decoder with soft quantized input
z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+ dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
@@ -558,6 +566,7 @@ class RDOVAE(nn.Module):
'outputs_hard_quant' : outputs_hq,
'outputs_soft_quant' : outputs_sq,
'z' : z,
+ 'states' : states,
'statistical_model' : statistical_model
}
@@ -586,11 +595,11 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- zq = z * stats['quant_scale']
- zq = soft_dead_zone(zq, stats['dead_zone'])
+ zq = z * stats['quant_scale'][:self.latent_dim]
+ zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
zq = torch.round(zq)
- sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+ sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
return zq, sizes
@@ -599,7 +608,7 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- z = zq / stats['quant_scale']
+ z = zq / stats['quant_scale'][:,:,:self.latent_dim]
return z
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
index f29ed98f..3f8484e1 100644
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -172,6 +172,7 @@ if __name__ == '__main__':
running_soft_rate_loss = 0
running_total_loss = 0
running_rate_metric = 0
+ running_states_rate_metric = 0
previous_total_loss = 0
running_first_frame_loss = 0
@@ -194,17 +195,21 @@ if __name__ == '__main__':
# collect outputs
z = model_output['z']
+ states = model_output['states']
outputs_hard_quant = model_output['outputs_hard_quant']
outputs_soft_quant = model_output['outputs_soft_quant']
statistical_model = model_output['statistical_model']
# rate loss
- hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
- soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
- soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
- hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+ hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+ soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+ states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+ states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+ soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+ hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
hard_rate_metric = torch.mean(hard_rate)
+ states_rate_metric = torch.mean(states_hard_rate)
## distortion losses
@@ -242,6 +247,7 @@ if __name__ == '__main__':
running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
running_rate_loss += float(rate_loss.detach().cpu())
running_rate_metric += float(hard_rate_metric.detach().cpu())
+ running_states_rate_metric += float(states_rate_metric.detach().cpu())
running_total_loss += float(total_loss.detach().cpu())
running_first_frame_loss += float(first_frame_loss.detach().cpu())
running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@ if __name__ == '__main__':
dist_sq=running_soft_dist_loss / (i + 1),
rate_loss=running_rate_loss / (i + 1),
rate=running_rate_metric / (i + 1),
+ states_rate=running_states_rate_metric / (i + 1),
ffloss=running_first_frame_loss / (i + 1),
rateloss_hard=running_hard_rate_loss / (i + 1),
rateloss_soft=running_soft_rate_loss / (i + 1)