Diffstat (limited to 'transquest/model_args.py')
-rw-r--r-- | transquest/model_args.py | 120
1 files changed, 120 insertions, 0 deletions
diff --git a/transquest/model_args.py b/transquest/model_args.py
new file mode 100644
index 0000000..af8be51
--- /dev/null
+++ b/transquest/model_args.py
@@ -0,0 +1,120 @@
+import json
+import os
+import sys
+from dataclasses import dataclass, field, asdict
+from multiprocessing import cpu_count
+
+
+def get_default_process_count():
+    process_count = cpu_count() - 2 if cpu_count() > 2 else 1
+    if sys.platform == "win32":
+        process_count = min(process_count, 61)
+
+    return process_count
+
+
+def get_special_tokens():
+    return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
+
+
+@dataclass
+class TransQuestArgs:
+    adam_epsilon: float = 1e-8
+    best_model_dir: str = "outputs/best_model"
+    cache_dir: str = "cache_dir/"
+    config: dict = field(default_factory=dict)
+    cosine_schedule_num_cycles: float = 0.5
+    custom_layer_parameters: list = field(default_factory=list)
+    custom_parameter_groups: list = field(default_factory=list)
+    dataloader_num_workers: int = 0
+    do_lower_case: bool = False
+    dynamic_quantize: bool = False
+    early_stopping_consider_epochs: bool = False
+    early_stopping_delta: float = 0
+    early_stopping_metric: str = "eval_loss"
+    early_stopping_metric_minimize: bool = True
+    early_stopping_patience: int = 3
+    encoding: str = None
+    adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
+    adafactor_clip_threshold: float = 1.0
+    adafactor_decay_rate: float = -0.8
+    adafactor_beta1: float = None
+    adafactor_scale_parameter: bool = True
+    adafactor_relative_step: bool = True
+    adafactor_warmup_init: bool = True
+    eval_batch_size: int = 8
+    evaluate_during_training: bool = False
+    evaluate_during_training_silent: bool = True
+    evaluate_during_training_steps: int = 2000
+    evaluate_during_training_verbose: bool = False
+    evaluate_each_epoch: bool = True
+    fp16: bool = True
+    gradient_accumulation_steps: int = 1
+    learning_rate: float = 4e-5
+    local_rank: int = -1
+    logging_steps: int = 50
+    manual_seed: int = None
+    max_grad_norm: float = 1.0
+    max_seq_length: int = 128
+    model_name: str = None
+    model_type: str = None
+    multiprocessing_chunksize: int = 500
+    n_gpu: int = 1
+    no_cache: bool = False
+    no_save: bool = False
+    not_saved_args: list = field(default_factory=list)
+    num_train_epochs: int = 1
+    optimizer: str = "AdamW"
+    output_dir: str = "outputs/"
+    overwrite_output_dir: bool = False
+    process_count: int = field(default_factory=get_default_process_count)
+    polynomial_decay_schedule_lr_end: float = 1e-7
+    polynomial_decay_schedule_power: float = 1.0
+    quantized_model: bool = False
+    reprocess_input_data: bool = True
+    save_best_model: bool = True
+    save_eval_checkpoints: bool = True
+    save_model_every_epoch: bool = True
+    save_optimizer_and_scheduler: bool = True
+    save_recent_only: bool = True
+    save_steps: int = 2000
+    scheduler: str = "linear_schedule_with_warmup"
+    silent: bool = False
+    skip_special_tokens: bool = True
+    tensorboard_dir: str = None
+    thread_count: int = None
+    train_batch_size: int = 8
+    train_custom_parameters_only: bool = False
+    use_cached_eval_features: bool = False
+    use_early_stopping: bool = False
+    use_multiprocessing: bool = True
+    wandb_kwargs: dict = field(default_factory=dict)
+    wandb_project: str = None
+    warmup_ratio: float = 0.06
+    warmup_steps: int = 0
+    weight_decay: float = 0.0
+
+    def update_from_dict(self, new_values):
+        if isinstance(new_values, dict):
+            for key, value in new_values.items():
+                setattr(self, key, value)
+        else:
+            raise (TypeError(f"{new_values} is not a Python dict."))
+
+    def get_args_for_saving(self):
+        args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args}
+        return args_for_saving
+
+    def save(self, output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
+            json.dump(self.get_args_for_saving(), f)
+
+    def load(self, input_dir):
+        if input_dir:
+            model_args_file = os.path.join(input_dir, "model_args.json")
+            if os.path.isfile(model_args_file):
+                with open(model_args_file, "r") as f:
+                    model_args = json.load(f)
+
+                self.update_from_dict(model_args)
\ No newline at end of file
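For reference, a minimal sketch (not part of the commit) of how the new `TransQuestArgs` dataclass can be exercised, assuming the `transquest` package shown in the diff is on the import path:

    # Usage sketch for TransQuestArgs as defined in this commit.
    from transquest.model_args import TransQuestArgs

    args = TransQuestArgs()
    print(args.learning_rate)  # 4e-5, the default set above

    # Override defaults in bulk from a plain dict; anything that is not a
    # dict makes update_from_dict() raise a TypeError.
    args.update_from_dict({"learning_rate": 2e-5, "num_train_epochs": 3})

    # save() writes <output_dir>/model_args.json, omitting any keys listed
    # in not_saved_args via get_args_for_saving().
    args.save("outputs/")

    # load() restores from model_args.json and is silently a no-op when the
    # file does not exist.
    restored = TransQuestArgs()
    restored.load("outputs/")
    assert restored.num_train_epochs == 3

Note that because the round trip goes through JSON, tuple-valued defaults such as `adafactor_eps` come back as lists after `load()`.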