
github.com/TharinduDR/TransQuest.git
Diffstat (limited to 'transquest/model_args.py')
-rw-r--r--  transquest/model_args.py  120
1 file changed, 120 insertions(+), 0 deletions(-)
diff --git a/transquest/model_args.py b/transquest/model_args.py
new file mode 100644
index 0000000..af8be51
--- /dev/null
+++ b/transquest/model_args.py
@@ -0,0 +1,120 @@
+import json
+import os
+import sys
+from dataclasses import dataclass, field, asdict
+from multiprocessing import cpu_count
+
+
+def get_default_process_count():
+    process_count = cpu_count() - 2 if cpu_count() > 2 else 1
+    if sys.platform == "win32":
+        process_count = min(process_count, 61)
+
+    return process_count
+
+
+def get_special_tokens():
+    return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
+
+
+@dataclass
+class TransQuestArgs:
+    adam_epsilon: float = 1e-8
+    best_model_dir: str = "outputs/best_model"
+    cache_dir: str = "cache_dir/"
+    config: dict = field(default_factory=dict)
+    cosine_schedule_num_cycles: float = 0.5
+    custom_layer_parameters: list = field(default_factory=list)
+    custom_parameter_groups: list = field(default_factory=list)
+    dataloader_num_workers: int = 0
+    do_lower_case: bool = False
+    dynamic_quantize: bool = False
+    early_stopping_consider_epochs: bool = False
+    early_stopping_delta: float = 0
+    early_stopping_metric: str = "eval_loss"
+    early_stopping_metric_minimize: bool = True
+    early_stopping_patience: int = 3
+    encoding: str = None
+    adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
+    adafactor_clip_threshold: float = 1.0
+    adafactor_decay_rate: float = -0.8
+    adafactor_beta1: float = None
+    adafactor_scale_parameter: bool = True
+    adafactor_relative_step: bool = True
+    adafactor_warmup_init: bool = True
+    eval_batch_size: int = 8
+    evaluate_during_training: bool = False
+    evaluate_during_training_silent: bool = True
+    evaluate_during_training_steps: int = 2000
+    evaluate_during_training_verbose: bool = False
+    evaluate_each_epoch: bool = True
+    fp16: bool = True
+    gradient_accumulation_steps: int = 1
+    learning_rate: float = 4e-5
+    local_rank: int = -1
+    logging_steps: int = 50
+    manual_seed: int = None
+    max_grad_norm: float = 1.0
+    max_seq_length: int = 128
+    model_name: str = None
+    model_type: str = None
+    multiprocessing_chunksize: int = 500
+    n_gpu: int = 1
+    no_cache: bool = False
+    no_save: bool = False
+    not_saved_args: list = field(default_factory=list)
+    num_train_epochs: int = 1
+    optimizer: str = "AdamW"
+    output_dir: str = "outputs/"
+    overwrite_output_dir: bool = False
+    process_count: int = field(default_factory=get_default_process_count)
+    polynomial_decay_schedule_lr_end: float = 1e-7
+    polynomial_decay_schedule_power: float = 1.0
+    quantized_model: bool = False
+    reprocess_input_data: bool = True
+    save_best_model: bool = True
+    save_eval_checkpoints: bool = True
+    save_model_every_epoch: bool = True
+    save_optimizer_and_scheduler: bool = True
+    save_recent_only: bool = True
+    save_steps: int = 2000
+    scheduler: str = "linear_schedule_with_warmup"
+    silent: bool = False
+    skip_special_tokens: bool = True
+    tensorboard_dir: str = None
+    thread_count: int = None
+    train_batch_size: int = 8
+    train_custom_parameters_only: bool = False
+    use_cached_eval_features: bool = False
+    use_early_stopping: bool = False
+    use_multiprocessing: bool = True
+    wandb_kwargs: dict = field(default_factory=dict)
+    wandb_project: str = None
+    warmup_ratio: float = 0.06
+    warmup_steps: int = 0
+    weight_decay: float = 0.0
+
+    def update_from_dict(self, new_values):
+        if isinstance(new_values, dict):
+            for key, value in new_values.items():
+                setattr(self, key, value)
+        else:
+            raise TypeError(f"{new_values} is not a Python dict.")
+
+    def get_args_for_saving(self):
+        args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args}
+        return args_for_saving
+
+    def save(self, output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
+            json.dump(self.get_args_for_saving(), f)
+
+    def load(self, input_dir):
+        if input_dir:
+            model_args_file = os.path.join(input_dir, "model_args.json")
+            if os.path.isfile(model_args_file):
+                with open(model_args_file, "r") as f:
+                    model_args = json.load(f)
+
+                self.update_from_dict(model_args)
\ No newline at end of file
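
For context, the sketch below shows one way this dataclass might be used: take the defaults declared above, override a few fields through update_from_dict, and round-trip the arguments with save and load. The import path follows the file location in this commit; the overridden values and the "outputs/" directory are purely illustrative.

    from transquest.model_args import TransQuestArgs, get_default_process_count

    # Defaults come from the field definitions above, including the
    # CPU-count-based default for process_count.
    args = TransQuestArgs()
    print(args.learning_rate, args.process_count == get_default_process_count())

    # Override selected hyperparameters from a plain dict; passing anything
    # other than a dict raises TypeError.
    args.update_from_dict({"learning_rate": 2e-5, "num_train_epochs": 3})

    # Persist everything except not_saved_args to outputs/model_args.json,
    # then restore it into a fresh instance.
    args.save("outputs/")
    restored = TransQuestArgs()
    restored.load("outputs/")
    assert restored.num_train_epochs == 3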