diff options
author | John Bauer <horatio@gmail.com> | 2022-11-04 10:39:25 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-04 10:39:25 +0300 |
commit | 2c83f417e4f487453657bfce3ffbb0378c30c95c (patch) | |
tree | ca7ab7080aad80013873d64b0f080e1bf64d9f56 | |
parent | 78daeae46b78bd875e08a56e9f3397f63def00c2 (diff) |
Adjust learning rate, don't print out infinite ppltrans_lm
-rw-r--r-- | stanza/models/constituency/trans_lm.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/stanza/models/constituency/trans_lm.py b/stanza/models/constituency/trans_lm.py index 4b1de1ba..6e8de99d 100644 --- a/stanza/models/constituency/trans_lm.py +++ b/stanza/models/constituency/trans_lm.py @@ -240,10 +240,14 @@ def train(optimizer, scheduler, epoch, device, train_data, model: nn.Module) -> lr = scheduler.get_last_lr()[0] ms_per_batch = (time.time() - start_time) * 1000 / log_interval cur_loss = total_loss / log_interval - ppl = math.exp(cur_loss) + try: + ppl = math.exp(cur_loss) + ppl = f"{ppl:8.2f}" + except OverflowError: + ppl = "inf" print(f'| epoch {epoch:3d} | {batch_idx:5d}/{num_batches:5d} batches | ' f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | ' - f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}') + f'loss {cur_loss:5.2f} | ppl {ppl}') total_loss = 0 start_time = time.time() @@ -286,7 +290,7 @@ DEFAULT_FILES = { def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--epochs', default=5, type=int, help='Num epochs to run') - parser.add_argument('--learning_rate', default=5.0, type=float, help='Initial learning rate') + parser.add_argument('--learning_rate', default=0.2, type=float, help='Initial learning rate') parser.add_argument('--data_dir', default="data/trans_lm", help='Where to find the data') parser.add_argument('--lang', default='vi', help='Which language to train - sets defaults for the data files, for example') |