github.com/OpenNMT/OpenNMT-py.git
author     Vincent Nguyen <vince62s@yahoo.com>  2022-09-08 17:15:06 +0300
committer  GitHub <noreply@github.com>          2022-09-08 17:15:06 +0300
commit     8029c4028b1acacef20f874c477ed4795229ba5a (patch)
tree       f2b2c086b109a4191adc1b7b2dee4fd10e402e9e
parent     5742168816e236c627db58abdabfea3febe0b6c7 (diff)
Include wider scope in try / exception training loop (#2195)
* Widen the scope of the try block in the training loop and make the except clause more specific, so that only CUDA out-of-memory errors are skipped.
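
In isolation, the pattern this commit moves to looks roughly like the sketch below: run the forward pass and loss computation inside the try block, then inspect the formatted traceback and drop the batch only when the failure is a CUDA out-of-memory error, re-raising everything else. This is a minimal standalone sketch under stated assumptions, not the Trainer code itself; model, criterion, batch, optimizer, and train_on_batch are hypothetical stand-ins, not OpenNMT-py APIs.

    import logging
    import traceback

    import torch

    logger = logging.getLogger(__name__)

    def train_on_batch(model, criterion, batch, optimizer, step):
        # Hypothetical stand-ins: model, criterion, batch, and optimizer
        # are placeholders for illustration, not OpenNMT-py objects.
        try:
            with torch.cuda.amp.autocast():
                outputs = model(batch["src"])
                loss = criterion(outputs, batch["tgt"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        except Exception:
            if "CUDA out of memory" in traceback.format_exc():
                # OOM on this batch only: log it, drop the batch,
                # and keep training.
                logger.info("Step %d, cuda OOM - batch removed", step)
                optimizer.zero_grad()
            else:
                # Any other failure is a real bug and should surface.
                traceback.print_exc()
                raise
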
-rw-r--r--  onmt/trainer.py  44
1 file changed, 24 insertions(+), 20 deletions(-)
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 50b4a781..ce1c58cb 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -368,33 +368,37 @@ class Trainer(object):
                 if self.accum_count == 1:
                     self.optim.zero_grad()
-                with torch.cuda.amp.autocast(enabled=self.optim.amp):
-                    outputs, attns = self.model(
-                        src, tgt, src_lengths, bptt=bptt,
-                        with_align=self.with_align)
-                    bptt = True
-
-                    # 3. Compute loss.
-                    loss, batch_stats = self.train_loss(
-                        batch,
-                        outputs,
-                        attns,
-                        normalization=normalization,
-                        shard_size=self.shard_size,
-                        trunc_start=j,
-                        trunc_size=trunc_size)
-
                 try:
+                    with torch.cuda.amp.autocast(enabled=self.optim.amp):
+                        outputs, attns = self.model(
+                            src, tgt, src_lengths, bptt=bptt,
+                            with_align=self.with_align)
+                        bptt = True
+
+                        # 3. Compute loss.
+                        loss, batch_stats = self.train_loss(
+                            batch,
+                            outputs,
+                            attns,
+                            normalization=normalization,
+                            shard_size=self.shard_size,
+                            trunc_start=j,
+                            trunc_size=trunc_size)
+
                     if loss is not None:
                         self.optim.backward(loss)
                     total_stats.update(batch_stats)
                     report_stats.update(batch_stats)
-                except Exception:
-                    traceback.print_exc()
-                    logger.info("At step %d, we removed a batch - accum %d",
-                                self.optim.training_step, k)
+                except Exception as exc:
+                    trace_content = traceback.format_exc()
+                    if "CUDA out of memory" in trace_content:
+                        logger.info("Step %d, cuda OOM - batch removed",
+                                    self.optim.training_step)
+                    else:
+                        traceback.print_exc()
+                        raise exc
                 # 4. Update the parameters and statistics.
                 if self.accum_count == 1:
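
A note on the design: matching the string "CUDA out of memory" in the formatted traceback works across PyTorch versions, because at the time allocator OOMs surfaced as a plain RuntimeError whose message carries that text. Newer PyTorch releases (1.13+, after this commit) expose a dedicated torch.cuda.OutOfMemoryError subclass, so on recent versions the same filter can be written without string matching; a short sketch, assuming PyTorch >= 1.13 and the same hypothetical stand-ins as above:

    try:
        outputs = model(batch["src"])
    except torch.cuda.OutOfMemoryError:
        # Catches only allocator OOM; other errors propagate naturally.
        logger.info("cuda OOM - batch removed")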