fixed bugs in inference code
This commit is contained in:
+1
-1
@@ -21,7 +21,7 @@ options.intra_op_num_threads = 1
|
|||||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||||
|
|
||||||
# Initialize onnxruntime session for Pangu-Weather Models
|
# Initialize onnxruntime session for Pangu-Weather Models
|
||||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=['CPUExecutionProvider'])
|
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider'])
|
||||||
|
|
||||||
# Load the upper-air numpy arrays
|
# Load the upper-air numpy arrays
|
||||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||||
|
|||||||
+1
-1
@@ -21,7 +21,7 @@ options.intra_op_num_threads = 1
|
|||||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||||
|
|
||||||
# Initialize onnxruntime session for Pangu-Weather Models
|
# Initialize onnxruntime session for Pangu-Weather Models
|
||||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||||
|
|
||||||
# Load the upper-air numpy arrays
|
# Load the upper-air numpy arrays
|
||||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ options.intra_op_num_threads = 1
|
|||||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||||
|
|
||||||
# Initialize onnxruntime session for Pangu-Weather Models
|
# Initialize onnxruntime session for Pangu-Weather Models
|
||||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||||
ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||||
|
|
||||||
# Load the upper-air numpy arrays
|
# Load the upper-air numpy arrays
|
||||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||||
|
|||||||
Reference in New Issue
Block a user