bigchange: add multiquery support in run.c. we can now train and inference multiquery models (where n_kv_heads < n_heads). this also means that we, in principle, support Llama 2 34B and 70B models, which are multiquery
This commit is contained in:
@@ -94,6 +94,7 @@ class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
model_parallel_size = 1
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
|
||||
Reference in New Issue
Block a user