# Exactly one process materialises the files: either torch.distributed is not
# initialised, or this process is rank 0. Every other rank blocks in the
# `else` branch until the offsets file shows up on disk.
if not torch.distributed._initialized or torch.distributed.get_rank() == 0:
    with open(datapath, "w") as f:
        f.write("".join(strs))
    # Running total of len(s): the end offset of each string, counted in
    # characters of the concatenated text written above.
    str_ends = list(accumulate(map(len, strs)))
    # Publish the offsets atomically. The waiting ranks only test for the
    # existence of `lenpath`, and a plain open(lenpath, "wb") creates the
    # (still empty) file before the pickle is written. Dumping to a sibling
    # temp file, closing it, and renaming guarantees that `lenpath` never
    # exists in a partially written state.
    tmp_lenpath = f"{lenpath}.tmp"
    with open(tmp_lenpath, "wb") as len_file:
        pkl.dump(str_ends, len_file)
    os.replace(tmp_lenpath, lenpath)
else:
    # Poll once per second until the writer has published the offsets file.
    while not os.path.exists(lenpath):
        time.sleep(1)
After Change
str_cnt = 0
for s in strs:
f.write(s.encode("utf-8"))
str_cnt += len(s)
str_ends.append(str_cnt)
pkl.dump(str_ends, open(lenpath, "wb"))
else:
while not os.path.exists(lenpath):