diff --git a/inference_cpu.py b/inference_cpu.py index cc6acda..017a2fe 100644 --- a/inference_cpu.py +++ b/inference_cpu.py @@ -21,7 +21,7 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=['CPUExecutionProvider']) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider']) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) diff --git a/inference_gpu.py b/inference_gpu.py index 833b16f..c46b2a6 100644 --- a/inference_gpu.py +++ b/inference_gpu.py @@ -21,7 +21,7 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) diff --git a/inference_iterative.py b/inference_iterative.py index 39037bd..4a9c767 100644 --- a/inference_iterative.py +++ b/inference_iterative.py @@ -22,8 +22,8 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) -ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)