author     John Bauer <horatio@gmail.com>    2022-01-28 21:43:11 +0300
committer  John Bauer <horatio@gmail.com>    2022-06-17 06:46:25 +0300
commit     664d86111bd732e98468248f7e3c461b58d87180 (patch)
tree       0491a2223ff85e6db4ae97e9521bec35f9ab682e
parent     7d2e45a8626706db26ae104d64b0160466c81b6c (diff)
low weight decay for the norms (con_warmup)
-rw-r--r--  stanza/models/constituency/lstm_model.py  13
-rw-r--r--  stanza/models/constituency/trainer.py     11
-rw-r--r--  stanza/models/constituency/utils.py        18
3 files changed, 31 insertions, 11 deletions
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 73169f38..9e807ef1 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -488,6 +488,19 @@ class LSTMModel(BaseModel, nn.Module):
                 lines.append("%s %.6g" % (name, torch.norm(param).item()))
         logger.info("\n".join(lines))
 
+    def is_low_decay_parameter(self, name):
+        if name.find("embedding") >= 0:
+            return False
+        return True
+
+    def base_parameters(self):
+        params = [param for name, param in self.named_parameters() if not self.is_low_decay_parameter(name)]
+        return params
+
+    def low_decay_parameters(self):
+        params = [param for name, param in self.named_parameters() if self.is_low_decay_parameter(name)]
+        return params
+
     def initial_word_queues(self, tagged_word_lists):
         """
         Produce initial word queues out of the model's LSTMs for use in the tagged word lists.
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index cd2b8a6a..00edeeef 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -339,9 +339,9 @@ def build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_cha
         # run adadelta over the model for a few iterations
         # then use just the embeddings from the temp model
         logger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['adadelta_warmup'])
-        temp_model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
+        model = LSTMModel(pt, forward_charlm, backward_charlm, bert_model, bert_tokenizer, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, args)
         if args['cuda']:
-            temp_model.cuda()
+            model.cuda()
         temp_args = dict(args)
         temp_args['optim'] = 'adadelta'
         temp_args['learning_rate'] = DEFAULT_LEARNING_RATES['adadelta']
@@ -349,11 +349,10 @@ def build_trainer(args, train_trees, dev_trees, pt, forward_charlm, backward_cha
         temp_args['learning_rho'] = DEFAULT_LEARNING_RHO
         temp_args['weight_decay'] = DEFAULT_WEIGHT_DECAY['adadelta']
         temp_args['epochs'] = args['adadelta_warmup']
-        temp_optim = build_optimizer(temp_args, temp_model)
-        temp_trainer = Trainer(temp_args, temp_model, temp_optim)
+        temp_args['separate_learning'] = True # find some other way to pass this in when making an optimizer
+        temp_optim = build_optimizer(temp_args, model)
+        temp_trainer = Trainer(temp_args, model, temp_optim)
         iterate_training(temp_trainer, train_trees, train_sequences, train_transitions, dev_trees, temp_args, None, None, evaluator)
-        logger.info("Using embedding weights from initial training to train full model")
-        model.init_embeddings_from_other(temp_model)
 
     optimizer = build_optimizer(args, model)
     scheduler = build_scheduler(args, optimizer)
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 5e896989..440c9543 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -99,25 +99,33 @@ def build_optimizer(args, model):
     """
     Build an optimizer based on the arguments given
     """
+    if args.get('separate_learning', False):
+        parameters = [
+            {'params': model.base_parameters()},
+            {'params': model.low_decay_parameters(), 'weight_decay': args['weight_decay'] * 0.01, 'lr': args['learning_rate'] * 0.01}
+        ]
+    else:
+        parameters = model.parameters()
+
     if args['optim'].lower() == 'sgd':
-        optimizer = optim.SGD(model.parameters(), lr=args['learning_rate'], momentum=0.9, weight_decay=args['weight_decay'])
+        optimizer = optim.SGD(parameters, lr=args['learning_rate'], momentum=0.9, weight_decay=args['weight_decay'])
     elif args['optim'].lower() == 'adadelta':
-        optimizer = optim.Adadelta(model.parameters(), lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], rho=args['learning_rho'])
+        optimizer = optim.Adadelta(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], rho=args['learning_rho'])
     elif args['optim'].lower() == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), lr=args['learning_rate'], betas=(0.9, args['learning_beta2']), eps=args['learning_eps'], weight_decay=args['weight_decay'])
+        optimizer = optim.AdamW(parameters, lr=args['learning_rate'], betas=(0.9, args['learning_beta2']), eps=args['learning_eps'], weight_decay=args['weight_decay'])
     elif args['optim'].lower() == 'adabelief':
         try:
             from adabelief_pytorch import AdaBelief
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
         # TODO: make these args
-        optimizer = AdaBelief(model.parameters(), lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], weight_decouple=False, rectify=False)
+        optimizer = AdaBelief(parameters, lr=args['learning_rate'], eps=args['learning_eps'], weight_decay=args['weight_decay'], weight_decouple=False, rectify=False)
     elif args['optim'].lower() == 'madgrad':
         try:
             import madgrad
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
-        optimizer = madgrad.MADGRAD(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay'])
+        optimizer = madgrad.MADGRAD(parameters, lr=args['learning_rate'], weight_decay=args['weight_decay'])
     else:
         raise ValueError("Unknown optimizer: %s" % args['optim'])
     return optimizer
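For context, and not part of the commit itself: the patch leans on torch.optim's parameter-group mechanism, where the optimizer takes a list of dictionaries and any key supplied inside a group ('lr', 'weight_decay', ...) overrides the optimizer-wide default for that group only. The sketch below is a minimal illustration of that mechanism; the toy ModuleDict and the lr / weight_decay numbers are made up, but the group construction mirrors what build_optimizer does when 'separate_learning' is set.

import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for the parser model; only the parameter names matter here.
model = nn.ModuleDict({
    "embedding": nn.Embedding(100, 16),   # goes in the "base" group (full decay / lr)
    "output": nn.Linear(16, 8),           # everything else goes in the "low decay" group
})

base = [p for name, p in model.named_parameters() if "embedding" in name]
low_decay = [p for name, p in model.named_parameters() if "embedding" not in name]

# Illustrative defaults; the real values come from args['learning_rate'] / args['weight_decay'].
lr, weight_decay = 1.0, 1e-2
optimizer = optim.Adadelta(
    [
        {"params": base},  # inherits the defaults passed below
        {"params": low_decay, "weight_decay": weight_decay * 0.01, "lr": lr * 0.01},
    ],
    lr=lr,
    weight_decay=weight_decay,
)

for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])
# one group trains at lr=1.0, weight_decay=0.01; the other at lr=0.01, weight_decay=0.0001

Since AdaBelief and MADGRAD also subclass torch.optim.Optimizer, the same group list should be accepted by every branch of build_optimizer, which is presumably why a single parameters object can be handed to all of them.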