diff options
-rw-r--r-- | transquest/algo/transformers/run_model.py | 10 |
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/transquest/algo/transformers/run_model.py b/transquest/algo/transformers/run_model.py
index 75dbe15..43c454e 100644
--- a/transquest/algo/transformers/run_model.py
+++ b/transquest/algo/transformers/run_model.py
@@ -23,7 +23,8 @@ from sklearn.metrics import (
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from tqdm.auto import trange, tqdm
-from transformers import AdamW, get_linear_schedule_with_warmup, FlaubertForSequenceClassification
+from transformers import AdamW, get_linear_schedule_with_warmup, FlaubertForSequenceClassification, \
+    get_cosine_with_hard_restarts_schedule_with_warmup
 from transformers import (
     BertConfig,
     BertTokenizer,
@@ -292,8 +293,11 @@ class QuestModel:
         args["warmup_steps"] = warmup_steps if args["warmup_steps"] == 0 else args["warmup_steps"]

         optimizer = AdamW(optimizer_grouped_parameters, lr=args["learning_rate"], eps=args["adam_epsilon"])
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total
-        )
+        # scheduler = get_linear_schedule_with_warmup(
+        #     optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total
+        # )
+        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total, num_cycles=2
+        )

         if args["fp16"]: