diff options
author | John Bauer <horatio@gmail.com> | 2022-11-01 23:58:28 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-01 23:58:28 +0300 |
commit | 8d9ab720182cd3f0e60be473772c79f3ed571128 (patch) | |
tree | f02ad967f5236336f0b7e3a33e49628c276c909e | |
parent | 1a5d64667e5413626b0f3eb2ff8875ff4bbfe0a1 (diff) |
slice in a more generic manner when copying model. makes it easier to make future changes
-rw-r--r-- | stanza/models/constituency/lstm_model.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 42e8fd83..6c10c472 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -542,10 +542,11 @@ class LSTMModel(BaseModel, nn.Module):
             elif name.startswith('word_lstm.weight_ih_l0'):
                 # bottom layer shape may have changed from adding a new pattn / lattn block
                 my_parameter = self.get_parameter(name)
-                copy_size = min(other_parameter.data.shape[1], my_parameter.data.shape[1])
+                # -1 so that it can be converted easier to a different parameter
+                copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])
                 #new_values = my_parameter.data.clone().detach()
                 new_values = torch.zeros_like(my_parameter.data)
-                new_values[:, :copy_size] = other_parameter.data[:, :copy_size]
+                new_values[..., :copy_size] = other_parameter.data[..., :copy_size]
                 my_parameter.data.copy_(new_values)
             else:
                 self.get_parameter(name).data.copy_(other_parameter.data)