Merge pull request #37 from awgu/pt2
Have DDP ignore `freqs_cis` to avoid broadcast
This commit is contained in:
@@ -191,6 +191,10 @@ if compile:
|
||||
|
||||
# wrap model into DDP container
|
||||
if ddp:
|
||||
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
|
||||
# construction time since NCCL does not support `ComplexFloat`
|
||||
prefix = "_orig_mod." if compile else ""
|
||||
model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
|
||||
model = DDP(model, device_ids=[ddp_local_rank])
|
||||
|
||||
# helps estimate an arbitrarily accurate loss over either split using many batches
|
||||
|
||||
Reference in New Issue
Block a user