From cf3547114d84b941d62b120281e8e864f76b035f Mon Sep 17 00:00:00 2001 From: Lingxi Xie <198808xc@gmail.com> Date: Thu, 2 Mar 2023 22:37:27 +0800 Subject: [PATCH] fixed bugs in inference code --- inference_cpu.py | 2 +- inference_gpu.py | 2 +- inference_iterative.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/inference_cpu.py b/inference_cpu.py index cc6acda..017a2fe 100644 --- a/inference_cpu.py +++ b/inference_cpu.py @@ -21,7 +21,7 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=['CPUExecutionProvider']) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider']) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) diff --git a/inference_gpu.py b/inference_gpu.py index 833b16f..c46b2a6 100644 --- a/inference_gpu.py +++ b/inference_gpu.py @@ -21,7 +21,7 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32) diff --git a/inference_iterative.py b/inference_iterative.py index 39037bd..4a9c767 100644 --- a/inference_iterative.py +++ b/inference_iterative.py @@ -22,8 +22,8 @@ options.intra_op_num_threads = 1 cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',} # Initialize onnxruntime session for Pangu-Weather Models -ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) -ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) +ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) # Load the upper-air numpy arrays input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)