Fix WQ and WK permutation in huggingface models
This commit is contained in:
@@ -303,11 +303,15 @@ def load_hf_model(model_path):
|
||||
model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
|
||||
model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
|
||||
|
||||
# Hugging Face checkpoints store WQ and WK with their rows permuted; this function reverses that permutation
|
||||
def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
    """Undo the row permutation applied to a WQ/WK projection weight.

    The weight is viewed as ``(n_heads, 2, dim1 // n_heads // 2, dim2)``,
    axes 1 and 2 are swapped, and the result is flattened back to
    ``(dim1, dim2)``.

    Args:
        w: Weight tensor holding ``dim1 * dim2`` elements; it must be
            viewable with the shape above (``.view`` does not copy).
        n_heads: Number of head groups the rows are split into.
        dim1: Number of rows in the returned weight.
        dim2: Number of columns in the returned weight.

    Returns:
        A ``(dim1, dim2)`` tensor with the rows reordered.

    Note:
        The defaults are read from the enclosing ``config`` once, when this
        ``def`` statement runs, not on every call.
    """
    # Each head's rows are treated as two halves of this many rows each.
    rows_per_half = dim1 // n_heads // 2
    split_by_half = w.view(n_heads, 2, rows_per_half, dim2)
    # Swapping the "half" axis with the "row within half" axis is the
    # actual reordering step.
    reordered = split_by_half.transpose(1, 2)
    return reordered.reshape(dim1, dim2)
|
||||
|
||||
for layer in model.layers:
|
||||
i = layer.layer_id
|
||||
layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
|
||||
layer.attention.wq.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight'])
|
||||
layer.attention.wk.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'])
|
||||
layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
|
||||
layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
|
||||
layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
|
||||
layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
|
||||
layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
|
||||
|
||||
Reference in New Issue
Block a user