github.com/TharinduDR/TransQuest.git
 transquest/algo/transformers/run_model.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/transquest/algo/transformers/run_model.py b/transquest/algo/transformers/run_model.py
index 75dbe15..43c454e 100644
--- a/transquest/algo/transformers/run_model.py
+++ b/transquest/algo/transformers/run_model.py
@@ -23,7 +23,8 @@ from sklearn.metrics import (
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from tqdm.auto import trange, tqdm
-from transformers import AdamW, get_linear_schedule_with_warmup, FlaubertForSequenceClassification
+from transformers import AdamW, get_linear_schedule_with_warmup, FlaubertForSequenceClassification, \
+    get_cosine_with_hard_restarts_schedule_with_warmup
 from transformers import (
     BertConfig,
     BertTokenizer,
@@ -292,8 +293,11 @@ class QuestModel:
         args["warmup_steps"] = warmup_steps if args["warmup_steps"] == 0 else args["warmup_steps"]

         optimizer = AdamW(optimizer_grouped_parameters, lr=args["learning_rate"], eps=args["adam_epsilon"])
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total
+        # scheduler = get_linear_schedule_with_warmup(
+        #     optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total
+        # )
+        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total, num_cycles=2
         )

         if args["fp16"]:
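For context on what this commit changes: `get_cosine_with_hard_restarts_schedule_with_warmup` is a standard `transformers` scheduler that, after linear warmup, decays the learning rate along a cosine curve and resets it to the peak `num_cycles - 1` times, instead of the single linear decay used before. The following is a minimal standalone sketch of that behavior, not TransQuest code; the dummy parameter, learning rate, and step counts are illustrative assumptions, and torch's `AdamW` stands in for the optimizer built in the diff.

    # Sketch of the scheduler this commit switches to. All values here are
    # illustrative placeholders, not taken from the TransQuest codebase.
    import torch
    from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

    param = torch.nn.Parameter(torch.zeros(1))       # stand-in for model parameters
    optimizer = torch.optim.AdamW([param], lr=2e-5)  # same role as AdamW in the diff

    t_total = 1000       # plays the role of t_total in run_model.py
    warmup_steps = 100   # plays the role of args["warmup_steps"]

    # Linear warmup to the peak LR, then cosine decay toward 0 that restarts
    # from the peak once halfway through training (num_cycles=2).
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total, num_cycles=2
    )

    for step in range(t_total):
        optimizer.step()   # update weights first ...
        scheduler.step()   # ... then advance the LR schedule
        if step in (99, 548, 549):
            # step 99: warmup ends at the peak LR; by step 548 the LR has
            # decayed to ~0; at step 549 it jumps back to the peak (the restart).
            print(step, scheduler.get_last_lr())

With `num_cycles=2` and these step counts, the learning rate completes one full cosine cycle over the first half of the post-warmup steps, then restarts at the peak for a second cycle, which is the qualitative difference from the commented-out linear schedule.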