# printed out if another exception happens.
# NB(jerry): added a flush to mitigate this
print(msg, file=sys.stderr)
# Only dump per-device memory statistics when CUDA is usable and this torch
# build actually exposes `torch.cuda.memory_summary`.
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
    for device_idx in range(torch.cuda.device_count()):
        print(torch.cuda.memory_summary(device=device_idx),
              file=sys.stderr)
# Flush explicitly so the messages above are emitted right away (see the
# note at the top of this section).
sys.stderr.flush()
After Change
# Handle an out-of-memory error, detected by matching the exception text.
if "out of memory" in str(e):
    self._log_oom(e)
    # The caller asked for OOM errors to propagate instead of being
    # recovered from.
    if raise_oom:
        raise e
    print("| WARNING: attempting to recover from OOM in forward/backward pass",
          file=sys.stderr)
    # Count the recovered OOM, then reset gradients via zero_grad()
    # (presumably discarding partial results of the failed pass -- confirm
    # against zero_grad's definition).
    ooms += 1
    self.zero_grad()