From 25494f9cbceeb8cdefcf834085e4f1cb9a8c3a75 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 24 Jul 2023 13:58:09 +0000 Subject: [PATCH] Have DDP ignore `freqs_cis` to avoid broadcast --- train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train.py b/train.py index 7aa46c4..34248b8 100644 --- a/train.py +++ b/train.py @@ -191,6 +191,10 @@ if compile: # wrap model into DDP container if ddp: + # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at + # construction time since NCCL does not support `ComplexFloat` + prefix = "_orig_mod." if compile else "" + model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"} model = DDP(model, device_ids=[ddp_local_rank]) # helps estimate an arbitrarily accurate loss over either split using many batches