Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-11-13 01:16:44 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-13 03:11:37 +0300
commitc16bd4e42181aad386ef43e41d8595d0d41ff4ca (patch)
treeefdb0932c22c48ccb1bb06cb3a98a72200b24933
parent33685a1f976a7e3c6024b2a23188bda336b657c5 (diff)
Add AdamW with amsgrad as an option
-rw-r--r--stanza/models/common/utils.py2
1 file changed, 2 insertions, 0 deletions
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py
index 88f37bea..435538df 100644
--- a/stanza/models/common/utils.py
+++ b/stanza/models/common/utils.py
@@ -146,6 +146,8 @@ def get_optimizer(name, parameters, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0
extra_args["weight_decay"] = weight_decay
if name == 'amsgrad':
return torch.optim.Adam(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
+ elif name == 'amsgradw':
+ return torch.optim.AdamW(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
elif name == 'sgd':
return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args)
elif name == 'adagrad':