diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-30 10:56:41 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-30 10:57:38 +0300 |
commit | 750bcff15355c63a7d596cc653da9f7b629555c8 (patch) | |
tree | ff5f0d0ee682fe9891c2cea37fa09af92b8ebfdc | |
parent | 4af5f57415ee4d6eca4a757afc4a503d54ff3abf (diff) |
Increase size, init convolutions with zerosexp_densenet3x
-rw-r--r-- | dnn/dred_rdovae_dec.c | 2 | ||||
-rw-r--r-- | dnn/dred_rdovae_enc.c | 2 | ||||
-rw-r--r-- | dnn/torch/rdovae/rdovae/rdovae.py | 58 |
3 files changed, 31 insertions, 31 deletions
diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c index 6cfd6577..96c933bc 100644 --- a/dnn/dred_rdovae_dec.c +++ b/dnn/dred_rdovae_dec.c @@ -37,7 +37,7 @@ static void conv1_cond_init(float *mem, const float *input, int len, int dilatio { if (!*init) { int i; - for (i=0;i<dilation;i++) OPUS_COPY(&mem[i*len], input, len); + for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len); } *init = 1; } diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c index b34ac40e..ee55e7cc 100644 --- a/dnn/dred_rdovae_enc.c +++ b/dnn/dred_rdovae_enc.c @@ -39,7 +39,7 @@ static void conv1_cond_init(float *mem, const float *input, int len, int dilatio { if (!*init) { int i; - for (i=0;i<dilation;i++) OPUS_COPY(&mem[i*len], input, len); + for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len); } *init = 1; } diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index ad118b99..1a16f3e4 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -233,7 +233,7 @@ class MyConv(nn.Module): self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation) def forward(self, x, state=None): device = x.device - conv_in = torch.cat([x[:,0:1,:].repeat(1,self.dilation,1), x], -2).permute(0, 2, 1) + conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1) return torch.tanh(self.conv(conv_in)).permute(0, 2, 1) class CoreEncoder(nn.Module): @@ -263,20 +263,20 @@ class CoreEncoder(nn.Module): # layers self.dense_1 = nn.Linear(self.input_dim, 64) self.gru1 = nn.GRU(64, 64, batch_first=True) - self.conv1 = MyConv(128, 32) - self.gru2 = nn.GRU(160, 64, batch_first=True) - self.conv2 = MyConv(224, 32, dilation=2) - self.gru3 = nn.GRU(256, 64, batch_first=True) - self.conv3 = MyConv(320, 32, dilation=2) - self.gru4 = nn.GRU(352, 64, batch_first=True) - self.conv4 = MyConv(416, 32, dilation=2) - self.gru5 = nn.GRU(448, 64, batch_first=True) - self.conv5 = MyConv(512, 32, dilation=2) + self.conv1 = MyConv(128, 96) + self.gru2 = nn.GRU(224, 64, batch_first=True) + self.conv2 = MyConv(288, 96, dilation=2) + self.gru3 = nn.GRU(384, 64, batch_first=True) + self.conv3 = MyConv(448, 96, dilation=2) + self.gru4 = nn.GRU(544, 64, batch_first=True) + self.conv4 = MyConv(608, 96, dilation=2) + self.gru5 = nn.GRU(704, 64, batch_first=True) + self.conv5 = MyConv(768, 96, dilation=2) - self.z_dense = nn.Linear(544, self.output_dim) + self.z_dense = nn.Linear(864, self.output_dim) - self.state_dense_1 = nn.Linear(544, self.STATE_HIDDEN) + self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN) self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size) nb_params = sum(p.numel() for p in self.parameters()) @@ -341,20 +341,20 @@ class CoreDecoder(nn.Module): # layers self.dense_1 = nn.Linear(self.input_size, 96) - self.gru1 = nn.GRU(96, 64, batch_first=True) - self.conv1 = MyConv(160, 32) - self.gru2 = nn.GRU(192, 64, batch_first=True) - self.conv2 = MyConv(256, 32) - self.gru3 = nn.GRU(288, 64, batch_first=True) - self.conv3 = MyConv(352, 32) - self.gru4 = nn.GRU(384, 64, batch_first=True) - self.conv4 = MyConv(448, 32) - self.gru5 = nn.GRU(480, 64, batch_first=True) - self.conv5 = MyConv(544, 32) - self.output = nn.Linear(576, self.FRAMES_PER_STEP * self.output_dim) + self.gru1 = nn.GRU(96, 96, batch_first=True) + self.conv1 = MyConv(192, 32) + self.gru2 = nn.GRU(224, 96, batch_first=True) + self.conv2 = MyConv(320, 32) + self.gru3 = nn.GRU(352, 96, batch_first=True) + self.conv3 = MyConv(448, 32) + self.gru4 = nn.GRU(480, 96, batch_first=True) + self.conv4 = MyConv(576, 32) + self.gru5 = nn.GRU(608, 96, batch_first=True) + self.conv5 = MyConv(704, 32) + self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim) self.hidden_init = nn.Linear(self.state_size, 128) - self.gru_init = nn.Linear(128, 320) + self.gru_init = nn.Linear(128, 480) nb_params = sum(p.numel() for p in self.parameters()) print(f"decoder: {nb_params} weights") @@ -365,11 +365,11 @@ class CoreDecoder(nn.Module): hidden = torch.tanh(self.hidden_init(initial_state)) gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2)) - h1_state = gru_state[:,:,:64].contiguous() - h2_state = gru_state[:,:,64:128].contiguous() - h3_state = gru_state[:,:,128:192].contiguous() - h4_state = gru_state[:,:,192:256].contiguous() - h5_state = gru_state[:,:,256:].contiguous() + h1_state = gru_state[:,:,:96].contiguous() + h2_state = gru_state[:,:,96:192].contiguous() + h3_state = gru_state[:,:,192:288].contiguous() + h4_state = gru_state[:,:,288:384].contiguous() + h5_state = gru_state[:,:,384:].contiguous() # run decoding layer stack x = torch.tanh(self.dense_1(z)) |