diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-16 00:27:44 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-16 00:28:59 +0300 |
commit | 2b463a3499efba0c6a2ddfac761e0f4aa2707303 (patch) | |
tree | 02ce9299be5c936f3bc1131eb124b4f4fd6a6ac9 | |
parent | 82f48d368b41d8bc4286e1375419daacbd10dbca (diff) |
quantizing initial state with rdovae too
-rw-r--r-- | dnn/torch/rdovae/rdovae/rdovae.py | 45 | ||||
-rw-r--r-- | dnn/torch/rdovae/train_rdovae.py | 15 |
2 files changed, 38 insertions, 22 deletions
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 1eec42c1..0dc943ec 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -372,7 +372,7 @@ class CoreDecoder(nn.Module):
 
 class StatisticalModel(nn.Module):
-    def __init__(self, quant_levels, latent_dim):
+    def __init__(self, quant_levels, latent_dim, state_dim):
         """ Statistical model for latent space
 
             Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@ class StatisticalModel(nn.Module):
 
         # copy parameters
         self.latent_dim = latent_dim
+        self.state_dim = state_dim
+        self.total_dim = latent_dim + state_dim
         self.quant_levels = quant_levels
-        self.embedding_dim = 6 * latent_dim
+        self.embedding_dim = 6 * self.total_dim
 
         # quantization embedding
         self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@ class StatisticalModel(nn.Module):
         x = self.quant_embedding(quant_ids)
 
         # CAVE: theta_soft is not used anymore. Kick it out?
-        quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
-        dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
-        theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
-        r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
-        theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
-        r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+        dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+        theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+        r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+        theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+        r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
 
 
         return {
@@ -445,7 +447,7 @@ class RDOVAE(nn.Module):
         self.state_dropout_rate = state_dropout_rate
 
         # submodules encoder and decoder share the statistical model
-        self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
         self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
         self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
 
@@ -522,13 +524,18 @@ class RDOVAE(nn.Module):
         z, states = self.core_encoder(features)
 
         # scaling, dead-zone and quantization
-        z = z * statistical_model['quant_scale']
-        z = soft_dead_zone(z, statistical_model['dead_zone'])
+        z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
 
         # quantization
-        z_q = hard_quantize(z) / statistical_model['quant_scale']
-        z_n = noise_quantize(z) / statistical_model['quant_scale']
-        states_q = soft_pvq(states, self.pvq_num_pulses)
+        z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        #states_q = soft_pvq(states, self.pvq_num_pulses)
+        states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
+
+        states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
 
         if self.state_dropout_rate > 0:
             drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
@@ -551,6 +558,7 @@ class RDOVAE(nn.Module):
 
             # decoder with soft quantized input
             z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+            dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
             outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
 
@@ -558,6 +566,7 @@ class RDOVAE(nn.Module):
             'outputs_hard_quant' : outputs_hq,
             'outputs_soft_quant' : outputs_sq,
             'z' : z,
+            'states' : states,
             'statistical_model' : statistical_model
         }
 
@@ -586,11 +595,11 @@ class RDOVAE(nn.Module):
 
         stats = self.statistical_model(q_ids)
 
-        zq = z * stats['quant_scale']
-        zq = soft_dead_zone(zq, stats['dead_zone'])
+        zq = z * stats['quant_scale'][:self.latent_dim]
+        zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
         zq = torch.round(zq)
 
-        sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+        sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
 
         return zq, sizes
 
@@ -599,7 +608,7 @@ class RDOVAE(nn.Module):
 
         stats = self.statistical_model(q_ids)
 
-        z = zq / stats['quant_scale']
+        z = zq / stats['quant_scale'][:,:,:self.latent_dim]
 
         return z
 
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
index f29ed98f..3f8484e1 100644
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -172,6 +172,7 @@ if __name__ == '__main__':
         running_soft_rate_loss = 0
         running_total_loss = 0
         running_rate_metric = 0
+        running_states_rate_metric = 0
         previous_total_loss = 0
         running_first_frame_loss = 0
 
@@ -194,17 +195,21 @@ if __name__ == '__main__':
 
                 # collect outputs
                 z = model_output['z']
+                states = model_output['states']
                 outputs_hard_quant = model_output['outputs_hard_quant']
                 outputs_soft_quant = model_output['outputs_soft_quant']
                 statistical_model = model_output['statistical_model']
 
                 # rate loss
-                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
-                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
-                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
-                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+                states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+                states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
                 rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                 hard_rate_metric = torch.mean(hard_rate)
+                states_rate_metric = torch.mean(states_hard_rate)
 
                 ## distortion losses
 
@@ -242,6 +247,7 @@ if __name__ == '__main__':
                 running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
                 running_rate_loss += float(rate_loss.detach().cpu())
                 running_rate_metric += float(hard_rate_metric.detach().cpu())
+                running_states_rate_metric += float(states_rate_metric.detach().cpu())
                 running_total_loss += float(total_loss.detach().cpu())
                 running_first_frame_loss += float(first_frame_loss.detach().cpu())
                 running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@ if __name__ == '__main__':
                     dist_sq=running_soft_dist_loss / (i + 1),
                     rate_loss=running_rate_loss / (i + 1),
                     rate=running_rate_metric / (i + 1),
+                    states_rate=running_states_rate_metric / (i + 1),
                     ffloss=running_first_frame_loss / (i + 1),
                     rateloss_hard=running_hard_rate_loss / (i + 1),
                     rateloss_soft=running_soft_rate_loss / (i + 1)