github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-11-03 20:26:10 +0300
committer John Bauer <horatio@gmail.com>  2022-11-04 18:40:33 +0300
commit    0efa21be676593e6b96893a5dbe60b2994fe69c9 (patch)
tree      d941185bd20afda8a8e4fd2d1fb2271044150e8e
parent    ad26679ea5297860c318b64084fb56ec35bc8bdb (diff)
AddSinusoidalEncoding as a module
Allow 2d tensors (no batch) in the SinusoidalEncoding modules
-rw-r--r--  stanza/models/constituency/positional_encoding.py     | 38
-rw-r--r--  stanza/tests/constituency/test_positional_encoding.py | 18
2 files changed, 52 insertions(+), 4 deletions(-)
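
As a quick orientation before the diff: a minimal usage sketch of the new module (import path taken from the diff; the d_model/max_len values are illustrative):

    import torch
    from stanza.models.constituency.positional_encoding import AddSinusoidalEncoding

    enc = AddSinusoidalEncoding(d_model=10, max_len=16)

    # batched input, shape (B, N, D): the encoding is added to every batch element
    y3 = enc(torch.zeros(2, 4, 10))   # shape (2, 4, 10)

    # unbatched input, shape (N, D): the 2d case this commit adds
    y2 = enc(torch.zeros(4, 10))      # shape (4, 10)

Per the forward() docstring below, masking padded positions in the output is left to the caller.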
diff --git a/stanza/models/constituency/positional_encoding.py b/stanza/models/constituency/positional_encoding.py
index ec641992..cce1c5b3 100644
--- a/stanza/models/constituency/positional_encoding.py
+++ b/stanza/models/constituency/positional_encoding.py
@@ -9,6 +9,9 @@ import torch
from torch import nn
class SinusoidalEncoding(nn.Module):
+ """
+ Uses sine & cosine to represent position
+ """
def __init__(self, model_dim, max_len):
super().__init__()
self.register_buffer('pe', self.build_position(model_dim, max_len))
@@ -31,6 +34,8 @@ class SinusoidalEncoding(nn.Module):
device = self.pe.device
shape = self.pe.shape[1]
self.register_buffer('pe', None)
+ # TODO: this may result in very poor performance
+ # in the event of a model that increases size one at a time
self.register_buffer('pe', self.build_position(shape, max(x)+1, device=device))
return self.pe[x]
@@ -38,6 +43,30 @@ class SinusoidalEncoding(nn.Module):
return self.pe.shape[0]
+class AddSinusoidalEncoding(nn.Module):
+ """
+ Uses sine & cosine to represent position
+ """
+ def __init__(self, d_model=256, max_len=512):
+ super().__init__()
+ self.encoding = SinusoidalEncoding(d_model, max_len)
+
+ def forward(self, x):
+ """
+ Adds the positional encoding to the input tensor
+
+ The tensor is expected to be of the shape B, N, D
+ Properly masking the output tensor is up to the caller
+ """
+ if len(x.shape) == 3:
+ timing = self.encoding(torch.arange(x.shape[1], device=x.device))
+ timing = timing.expand(x.shape[0], -1, -1)
+ return x + timing
+ elif len(x.shape) == 2:
+ timing = self.encoding(torch.arange(x.shape[0], device=x.device))
+ return x + timing
+
+
class ConcatSinusoidalEncoding(nn.Module):
"""
Uses sine & cosine to represent position
@@ -47,8 +76,11 @@ class ConcatSinusoidalEncoding(nn.Module):
self.encoding = SinusoidalEncoding(d_model, max_len)
def forward(self, x):
- timing = self.encoding(torch.arange(x.shape[1], device=x.device))
- timing = timing.expand(x.shape[0], -1, -1)
+ if len(x.shape) == 3:
+ timing = self.encoding(torch.arange(x.shape[1], device=x.device))
+ timing = timing.expand(x.shape[0], -1, -1)
+ else:
+ timing = self.encoding(torch.arange(x.shape[0], device=x.device))
+
out = torch.cat((x, timing), dim=-1)
return out
-
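
Note: build_position is not shown in this diff, but SinusoidalEncoding presumably builds the standard transformer-style table of interleaved sines and cosines. A self-contained sketch of that construction (the function name here is illustrative, not the repository's):

    import math
    import torch

    def build_sinusoidal_table(model_dim, max_len):
        # pe[pos, 2i]   = sin(pos / 10000^(2i / model_dim))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))
        position = torch.arange(max_len).unsqueeze(1)          # (max_len, 1)
        div_term = torch.exp(torch.arange(0, model_dim, 2)
                             * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe                                              # (max_len, model_dim)

This matches the shapes used above: forward() indexes the table by position, so self.pe[x] for x = arange(N) yields an (N, model_dim) slab to add or concatenate.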
diff --git a/stanza/tests/constituency/test_positional_encoding.py b/stanza/tests/constituency/test_positional_encoding.py
index ede2f9a7..864219ea 100644
--- a/stanza/tests/constituency/test_positional_encoding.py
+++ b/stanza/tests/constituency/test_positional_encoding.py
@@ -3,7 +3,7 @@ import pytest
import torch
from stanza import Pipeline
-from stanza.models.constituency.positional_encoding import SinusoidalEncoding
+from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding
from stanza.tests import *
@@ -27,3 +27,19 @@ def test_arange():
foo = encoding(torch.arange(4))
assert foo.shape == (4, 10)
assert encoding.max_len() == 4
+
+def test_add():
+ encoding = AddSinusoidalEncoding(d_model=10, max_len=4)
+ x = torch.zeros(1, 4, 10)
+ y = encoding(x)
+
+ r = torch.randn(1, 4, 10)
+ r2 = encoding(r)
+
+ assert torch.allclose(r2 - r, y, atol=1e-07)
+
+ r = torch.randn(2, 4, 10)
+ r2 = encoding(r)
+
+ assert torch.allclose(r2[0] - r[0], y, atol=1e-07)
+ assert torch.allclose(r2[1] - r[1], y, atol=1e-07)
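
The new test can be run on its own with a standard pytest invocation, e.g.:

    pytest stanza/tests/constituency/test_positional_encoding.py::test_add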