diff options
author | Vincent Nguyen <vince62s@yahoo.com> | 2022-09-08 17:15:06 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-08 17:15:06 +0300 |
commit | 8029c4028b1acacef20f874c477ed4795229ba5a (patch) | |
tree | f2b2c086b109a4191adc1b7b2dee4fd10e402e9e | |
parent | 5742168816e236c627db58abdabfea3febe0b6c7 (diff) |
Include wider scope in try / exception training loop (#2195)
* include wider scope in try / exception
Be more specific in the try / except to only pass on CUDA OOM
-rw-r--r-- | onmt/trainer.py | 44 |
1 file changed, 24 insertions, 20 deletions
diff --git a/onmt/trainer.py b/onmt/trainer.py index 50b4a781..ce1c58cb 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -368,33 +368,37 @@ class Trainer(object): if self.accum_count == 1: self.optim.zero_grad() - with torch.cuda.amp.autocast(enabled=self.optim.amp): - outputs, attns = self.model( - src, tgt, src_lengths, bptt=bptt, - with_align=self.with_align) - bptt = True - - # 3. Compute loss. - loss, batch_stats = self.train_loss( - batch, - outputs, - attns, - normalization=normalization, - shard_size=self.shard_size, - trunc_start=j, - trunc_size=trunc_size) - try: + with torch.cuda.amp.autocast(enabled=self.optim.amp): + outputs, attns = self.model( + src, tgt, src_lengths, bptt=bptt, + with_align=self.with_align) + bptt = True + + # 3. Compute loss. + loss, batch_stats = self.train_loss( + batch, + outputs, + attns, + normalization=normalization, + shard_size=self.shard_size, + trunc_start=j, + trunc_size=trunc_size) + if loss is not None: self.optim.backward(loss) total_stats.update(batch_stats) report_stats.update(batch_stats) - except Exception: - traceback.print_exc() - logger.info("At step %d, we removed a batch - accum %d", - self.optim.training_step, k) + except Exception as exc: + trace_content = traceback.format_exc() + if "CUDA out of memory" in trace_content: + logger.info("Step %d, cuda OOM - batch removed", + self.optim.training_step) + else: + traceback.print_exc() + raise exc # 4. Update the parameters and statistics. if self.accum_count == 1: |