Have DDP ignore freqs_cis to avoid broadcast

This commit is contained in:
Andrew Gu
2023-07-24 13:58:09 +00:00
parent d548245321
commit 25494f9cbc
+4
View File
@@ -191,6 +191,10 @@ if compile:
# wrap model into DDP container
if ddp:
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
# construction time since NCCL does not support `ComplexFloat`
prefix = "_orig_mod." if compile else ""
model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
model = DDP(model, device_ids=[ddp_local_rank])
# helps estimate an arbitrarily accurate loss over either split using many batches