Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-09-14 22:12:54 +0300
committerJohn Bauer <horatio@gmail.com>2022-09-15 01:21:18 +0300
commit6a90ad4bacf923c88438da53219c48355b847ed3 (patch)
treee2b4909229c9f24664f93cdff4d05643f65205c0
parent788b2a9cf3d32b79149480073c446a00c729b5be (diff)
Hide the imports of SiLU and Mish from older versions of torch. #1120
-rw-r--r--stanza/models/constituency/utils.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 5ae19f01..a11581dc 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -107,10 +107,17 @@ NONLINEARITY = {
'relu': nn.ReLU,
'gelu': nn.GELU,
'leaky_relu': nn.LeakyReLU,
- 'silu': nn.SiLU,
- 'mish': nn.Mish,
}
+# separating these out allows for backwards compatibility with earlier versions of pytorch
+# NOTE torch compatibility: if we ever *release* models with these
+# activation functions, we will need to break that compatibility
+if hasattr(nn, 'SiLU'):
+ NONLINEARITY['silu'] = nn.SiLU
+
+if hasattr(nn, 'Mish'):
+ NONLINEARITY['mish'] = nn.Mish
+
def build_nonlinearity(nonlinearity):
"""
Look up "nonlinearity" in a map from function name to function, build the appropriate layer.