Have DDP ignore freqs_cis to avoid broadcast
This commit is contained in:
@@ -191,6 +191,10 @@ if compile:
|
|||||||
|
|
||||||
# wrap model into DDP container
|
# wrap model into DDP container
|
||||||
if ddp:
|
if ddp:
|
||||||
|
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
|
||||||
|
# construction time since NCCL does not support `ComplexFloat`
|
||||||
|
prefix = "_orig_mod." if compile else ""
|
||||||
|
model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
|
||||||
model = DDP(model, device_ids=[ddp_local_rank])
|
model = DDP(model, device_ids=[ddp_local_rank])
|
||||||
|
|
||||||
# helps estimate an arbitrarily accurate loss over either split using many batches
|
# helps estimate an arbitrarily accurate loss over either split using many batches
|
||||||
|
|||||||
Reference in New Issue
Block a user