From 518524f4580c6fb410044aa347b1ede5fd157a65 Mon Sep 17 00:00:00 2001 From: Daniel Gross Date: Sun, 23 Jul 2023 10:41:03 -0700 Subject: [PATCH] default to whatever system has --- sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample.py b/sample.py index 2038a63..040bc14 100644 --- a/sample.py +++ b/sample.py @@ -17,7 +17,7 @@ max_new_tokens = 100 # number of tokens generated in each sample temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability seed = 1337 -device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. +device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' dtype = "float32" compile = False # use PyTorch 2.0 to compile the model to be faster