Get vocab_size from token embeddings size
This commit is contained in:
@@ -323,9 +323,10 @@ def load_meta_model(model_path):
|
|||||||
config.multiple_of = params["multiple_of"]
|
config.multiple_of = params["multiple_of"]
|
||||||
config.norm_eps = params["norm_eps"]
|
config.norm_eps = params["norm_eps"]
|
||||||
|
|
||||||
config.vocab_size = 32000
|
config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
|
||||||
config.max_seq_len = 2048
|
config.max_seq_len = 2048
|
||||||
|
|
||||||
|
|
||||||
# create a new Transformer object and set weights
|
# create a new Transformer object and set weights
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user