diff options
author | John Bauer <horatio@gmail.com> | 2022-11-13 01:16:44 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-13 03:11:37 +0300 |
commit | c16bd4e42181aad386ef43e41d8595d0d41ff4ca (patch) | |
tree | efdb0932c22c48ccb1bb06cb3a98a72200b24933 | |
parent | 33685a1f976a7e3c6024b2a23188bda336b657c5 (diff) |
Add AdamW with amsgrad as an option
-rw-r--r-- | stanza/models/common/utils.py | 2 |
1 file changed, 2 insertions(+), 0 deletions(-)
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py index 88f37bea..435538df 100644 --- a/stanza/models/common/utils.py +++ b/stanza/models/common/utils.py @@ -146,6 +146,8 @@ def get_optimizer(name, parameters, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0 extra_args["weight_decay"] = weight_decay if name == 'amsgrad': return torch.optim.Adam(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args) + elif name == 'amsgradw': + return torch.optim.AdamW(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args) elif name == 'sgd': return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args) elif name == 'adagrad': |