Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-11-01 23:58:28 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-01 23:58:28 +0300
commit8d9ab720182cd3f0e60be473772c79f3ed571128 (patch)
treef02ad967f5236336f0b7e3a33e49628c276c909e
parent1a5d64667e5413626b0f3eb2ff8875ff4bbfe0a1 (diff)
slice in a more generic manner when copying model. makes it easier to make future changes
-rw-r--r--stanza/models/constituency/lstm_model.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 42e8fd83..6c10c472 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -542,10 +542,11 @@ class LSTMModel(BaseModel, nn.Module):
elif name.startswith('word_lstm.weight_ih_l0'):
# bottom layer shape may have changed from adding a new pattn / lattn block
my_parameter = self.get_parameter(name)
- copy_size = min(other_parameter.data.shape[1], my_parameter.data.shape[1])
+ # -1 so that it can be converted easier to a different parameter
+ copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])
#new_values = my_parameter.data.clone().detach()
new_values = torch.zeros_like(my_parameter.data)
- new_values[:, :copy_size] = other_parameter.data[:, :copy_size]
+ new_values[..., :copy_size] = other_parameter.data[..., :copy_size]
my_parameter.data.copy_(new_values)
else:
self.get_parameter(name).data.copy_(other_parameter.data)