| author | John Bauer <horatio@gmail.com> | 2022-11-03 20:26:10 +0300 |
|---|---|---|
| committer | John Bauer <horatio@gmail.com> | 2022-11-04 18:40:33 +0300 |
| commit | 0efa21be676593e6b96893a5dbe60b2994fe69c9 (patch) | |
| tree | d941185bd20afda8a8e4fd2d1fb2271044150e8e | |
| parent | ad26679ea5297860c318b64084fb56ec35bc8bdb (diff) | |
AddSinusoidalEncoding as a module
Allow 2d tensors (no batch) in the SinusoidalEncoding modules
| -rw-r--r-- | stanza/models/constituency/positional_encoding.py | 38 |
| -rw-r--r-- | stanza/tests/constituency/test_positional_encoding.py | 18 |

2 files changed, 52 insertions, 4 deletions
```diff
diff --git a/stanza/models/constituency/positional_encoding.py b/stanza/models/constituency/positional_encoding.py
index ec641992..cce1c5b3 100644
--- a/stanza/models/constituency/positional_encoding.py
+++ b/stanza/models/constituency/positional_encoding.py
@@ -9,6 +9,9 @@ import torch
 from torch import nn
 
 class SinusoidalEncoding(nn.Module):
+    """
+    Uses sine & cosine to represent position
+    """
     def __init__(self, model_dim, max_len):
         super().__init__()
         self.register_buffer('pe', self.build_position(model_dim, max_len))
@@ -31,6 +34,8 @@ class SinusoidalEncoding(nn.Module):
             device = self.pe.device
             shape = self.pe.shape[1]
             self.register_buffer('pe', None)
+            # TODO: this may result in very poor performance
+            # in the event of a model that increases size one at a time
             self.register_buffer('pe', self.build_position(shape, max(x)+1, device=device))
         return self.pe[x]
 
@@ -38,6 +43,30 @@ class SinusoidalEncoding(nn.Module):
         return self.pe.shape[0]
 
 
+class AddSinusoidalEncoding(nn.Module):
+    """
+    Uses sine & cosine to represent position
+    """
+    def __init__(self, d_model=256, max_len=512):
+        super().__init__()
+        self.encoding = SinusoidalEncoding(d_model, max_len)
+
+    def forward(self, x):
+        """
+        Adds the positional encoding to the input tensor
+
+        The tensor is expected to be of the shape B, N, D
+        Properly masking the output tensor is up to the caller
+        """
+        if len(x.shape) == 3:
+            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
+            timing = timing.expand(x.shape[0], -1, -1)
+            return x + timing
+        elif len(x.shape) == 2:
+            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
+            return x + timing
+
+
 class ConcatSinusoidalEncoding(nn.Module):
     """
     Uses sine & cosine to represent position
@@ -47,8 +76,11 @@ class ConcatSinusoidalEncoding(nn.Module):
         self.encoding = SinusoidalEncoding(d_model, max_len)
 
     def forward(self, x):
-        timing = self.encoding(torch.arange(x.shape[1], device=x.device))
-        timing = timing.expand(x.shape[0], -1, -1)
+        if len(x.shape) == 3:
+            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
+            timing = timing.expand(x.shape[0], -1, -1)
+        else:
+            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
+
         out = torch.cat((x, timing), dim=-1)
         return out
-
diff --git a/stanza/tests/constituency/test_positional_encoding.py b/stanza/tests/constituency/test_positional_encoding.py
index ede2f9a7..864219ea 100644
--- a/stanza/tests/constituency/test_positional_encoding.py
+++ b/stanza/tests/constituency/test_positional_encoding.py
@@ -3,7 +3,7 @@ import pytest
 import torch
 
 from stanza import Pipeline
-from stanza.models.constituency.positional_encoding import SinusoidalEncoding
+from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding
 
 from stanza.tests import *
 
@@ -27,3 +27,19 @@ def test_arange():
     foo = encoding(torch.arange(4))
     assert foo.shape == (4, 10)
     assert encoding.max_len() == 4
+
+def test_add():
+    encoding = AddSinusoidalEncoding(d_model=10, max_len=4)
+    x = torch.zeros(1, 4, 10)
+    y = encoding(x)
+
+    r = torch.randn(1, 4, 10)
+    r2 = encoding(r)
+
+    assert torch.allclose(r2 - r, y, atol=1e-07)
+
+    r = torch.randn(2, 4, 10)
+    r2 = encoding(r)
+
+    assert torch.allclose(r2[0] - r[0], y, atol=1e-07)
+    assert torch.allclose(r2[1] - r[1], y, atol=1e-07)
```
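For reference, a minimal usage sketch of the new module. The import path and class come from the diff above; the tensor shapes and variable names are illustrative only:

```python
import torch

from stanza.models.constituency.positional_encoding import AddSinusoidalEncoding

# d_model must match the trailing dimension of the input tensors
enc = AddSinusoidalEncoding(d_model=10, max_len=4)

# Batched input of shape (B, N, D): the same encoding is broadcast over the batch
batched = torch.zeros(2, 4, 10)
out_batched = enc(batched)   # shape (2, 4, 10)

# Unbatched input of shape (N, D), the case this commit enables
single = torch.zeros(4, 10)
out_single = enc(single)     # shape (4, 10)

# Both paths add the identical positional signal
assert torch.allclose(out_batched[0], out_single)
```

Sequences longer than `max_len` still work: `SinusoidalEncoding.forward` rebuilds its `pe` buffer on the fly, at the cost flagged in the TODO above when the needed length keeps growing one step at a time.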