removed transformers from requirements.txt, added error message
This commit is contained in:
@@ -280,7 +280,12 @@ def load_checkpoint(checkpoint):
|
|||||||
|
|
||||||
def load_hf_model(model_path):
|
def load_hf_model(model_path):
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
try:
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
except ImportError:
|
||||||
|
print("Error: transformers package is required to load huggingface models")
|
||||||
|
print("Please run `pip install transformers` to install it")
|
||||||
|
return None
|
||||||
|
|
||||||
# load HF model
|
# load HF model
|
||||||
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
|
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
|
||||||
@@ -357,5 +362,8 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
parser.error("Input model missing: --checkpoint or --hf is required")
|
parser.error("Input model missing: --checkpoint or --hf is required")
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
parser.error("Can't load input model!")
|
||||||
|
|
||||||
# export
|
# export
|
||||||
model_export(model, args.filepath, args.version)
|
model_export(model, args.filepath, args.version)
|
||||||
|
|||||||
@@ -5,4 +5,3 @@ sentencepiece==0.1.99
|
|||||||
torch==2.0.1
|
torch==2.0.1
|
||||||
tqdm==4.64.1
|
tqdm==4.64.1
|
||||||
wandb==0.15.5
|
wandb==0.15.5
|
||||||
transformers==4.31.0
|
|
||||||
|
|||||||
Reference in New Issue
Block a user