diff --git a/train.py b/train.py
index 7aa46c4..34248b8 100644
--- a/train.py
+++ b/train.py
@@ -191,6 +191,10 @@ if compile:
 
 # wrap model into DDP container
 if ddp:
+    # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
+    # construction time since NCCL does not support `ComplexFloat`
+    prefix = "_orig_mod." if compile else ""
+    model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
     model = DDP(model, device_ids=[ddp_local_rank])
 
 # helps estimate an arbitrarily accurate loss over either split using many batches