Compare commits
26 Commits
Pre-release
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8642e9806b | |||
| fa4cec5805 | |||
| 4d0d1eaed7 | |||
| 1130d80c2d | |||
| a359fa116e | |||
| b5850da258 | |||
| d2b8ae962d | |||
| d9f2964faf | |||
| 62acbda2c4 | |||
| cf3547114d | |||
| 0387b09086 | |||
| 84a1e2448d | |||
| 597f90e3d8 | |||
| 27d322a56a | |||
| b7a32defaf | |||
| d289a9a310 | |||
| d2df335fb0 | |||
| bf7834f827 | |||
| 95bc883352 | |||
| 0be504e554 | |||
| 2d0082e16b | |||
| 9583d28ec8 | |||
| 240731f9f8 | |||
| ed94885bc1 | |||
| 85a25ba942 | |||
| 9b26464a4f |
@@ -1,27 +1,168 @@
|
||||
## Pangu-Weather
|
||||
|
||||
This is the official repository for the Pangu-Weather paper.
|
||||
This is the official repository for the Pangu-Weather papers.
|
||||
|
||||
Resources including pseudocode and pre-trained models will be updated. Stay tuned!
|
||||
[Accurate medium-range global weather forecasting with 3D neural networks](https://www.nature.com/articles/s41586-023-06185-3), Nature, Volume 619, Pages 533–538, 2023.
|
||||
|
||||
#### Policy of using the contents
|
||||
[Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast](https://arxiv.org/abs/2211.02556), arXiv preprint: 2211.02556, 2022.
|
||||
|
||||
All models are trained using the ERA5 dataset provided by ECMWF. Please follow their policy and note that commercial use of these models are forbidden.
|
||||
*by Kaifeng Bi, Lingxi Xie, Hengheng Zhang, Xin Chen, Xiaotao Gu and Qi Tian*
|
||||
|
||||
More policy to be updated.
|
||||
**Note: the arXiv version offers more technical details, and the Nature paper contains some new figures.**
|
||||
|
||||
#### Pseudocode and how to use
|
||||
Resources including pseudocode, pre-trained models, and inference code are released here.
|
||||
|
||||
To be updated.
|
||||
The slides used in a series of recent talks are attached here. [Baidu Netdisk](https://pan.baidu.com/s/14ZGywcr4XAK5dk75-8PUqA?pwd=9sco), extraction code: 9sco
|
||||
|
||||
#### Pre-trained models
|
||||
## News and Updates
|
||||
|
||||
To be updated.
|
||||
* [Jul 31 2023] We released the details of training the lite version of Pangu-Weather.
|
||||
* [Jul 19 2023] ECMWF released an official [technical report](https://arxiv.org/abs/2307.10128) for "the rise of data-driven weather forecasting". Pangu-Weather was mentioned and tested thoroughly in the paper. We thank ECMWF for testing our models in real-world scenarios.
|
||||
* [Jul 17 2023] Pangu-Weather was online as part of ECMWF's operational suite! Everyone can see 10-day global weather forecasting **without running code**. ECMWF has made use of the models released at this repository! [Please search the ECMWF charts website with the query of "PANGU".](https://charts.ecmwf.int/?query=pangu)
|
||||
* [Jul 05 2023] Pangu-Weather was published on [Nature](https://www.nature.com/articles/s41586-023-06185-3). It was made **Open Access**! We recommend the researchers to cite our Nature paper in the future.
|
||||
* [Jun 27 2023] Pangu-Weather was presented at [PASC 2023](https://pasc23.pasc-conference.org/program/schedule/).
|
||||
* [Jun 12 2023] Pangu-Weather was presented at [VALSE 2023](http://valser.org/2023/#/workshopde?id=15).
|
||||
* [May 27 2023] Pangu-Weather was presented at [the WMO Early Warning for All (EW4ALL) conference](https://community.wmo.int/en/news/exploring-possibilities-artificial-intelligence-areas-water-weather-and-climate).
|
||||
* [May 12 2023] ECMWF released a [repository](https://github.com/ecmwf-lab/ai-models-panguweather), offering a toolkit for running Pangu-Weather. We thank ECMWF for the efforts in easing everyone to test Pangu-Weather.
|
||||
* [May 09 2023] Pangu-Weather was accepted by Nature!
|
||||
|
||||
#### References
|
||||
## Installation
|
||||
|
||||
The downloaded files shall be organized as the following hierarchy:
|
||||
|
||||
```plain
|
||||
├── root
|
||||
│ ├── input_data
|
||||
│ │ ├── input_surface.npy
|
||||
│ │ ├── input_upper.npy
|
||||
│ ├── output_data
|
||||
│ ├── pangu_weather_1.onnx
|
||||
│ ├── pangu_weather_3.onnx
|
||||
│ ├── pangu_weather_6.onnx
|
||||
│ ├── pangu_weather_24.onnx
|
||||
│ ├── inference_cpu.py
|
||||
│ ├── inference_gpu.py
|
||||
│ ├── inference_iterative.py
|
||||
```
|
||||
|
||||
If you use a CPU environment, please run:
|
||||
```
|
||||
pip install -r requirements_cpu.txt
|
||||
```
|
||||
|
||||
If you use a GPU environment, please first confirm that the cuda version is 11.6 and the cudnn version is the 8.2.4 for Linux and 8.5.0.96 for Windows (please see [this page](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for details). Then, please run:
|
||||
```
|
||||
pip install -r requirements_gpu.txt
|
||||
```
|
||||
|
||||
## Global weather forecasting (inference) using the trained models
|
||||
|
||||
#### Downloading trained models
|
||||
|
||||
Please download the four pre-trained models (~1.1GB each) from Google drive or Baidu netdisk:
|
||||
|
||||
The 1-hour model (pangu_weather_1.onnx): [Google drive](https://drive.google.com/file/d/1fg5jkiN_5dHzKb-5H9Aw4MOmfILmeY-S/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1M7SAigVsCSH8hpw6DE8TDQ?pwd=ie0h)
|
||||
|
||||
The 3-hour model (pangu_weather_3.onnx): [Google drive](https://drive.google.com/file/d/1EdoLlAXqE9iZLt9Ej9i-JW9LTJ9Jtewt/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/197fZsoiCqZYzKwM7tyRrfg?pwd=gmcl)
|
||||
|
||||
The 6-hour model (pangu_weather_6.onnx): [Google drive](https://drive.google.com/file/d/1a4XTktkZa5GCtjQxDJb_fNaqTAUiEJu4/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1q7IB7tNjqIwoGC7KVMPn4w?pwd=vxq3)
|
||||
|
||||
The 24-hour model (pangu_weather_24.onnx): [Google drive](https://drive.google.com/file/d/1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/179q2gkz2BrsOR6g3yfTVQg?pwd=eajy)
|
||||
|
||||
These models are stored using the ONNX format, and thus can be used via different languages such as Python, C++, C#, Java, etc.
|
||||
|
||||
#### Input data preparation using Python
|
||||
|
||||
Please prepare the input data using [numpy](https://numpy.org/). There are two files that shall be put under the `input_data` folder, namely, `input_surface.npy` and `input_upper.npy`.
|
||||
|
||||
`input_surface.npy` stores the input surface variables. It is a numpy array shaped (4,721,1440) where the first dimension represents the 4 surface variables (MSLP, U10, V10, T2M **in the exact order**).
|
||||
|
||||
`input_upper.npy` stores the upper-air variables. It is a numpy array shaped (5,13,721,1440) where the first dimension represents the 5 surface variables (Z, Q, T, U and V **in the exact order**), and the second dimension represents the 13 pressure levels (1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa and 50hPa **in the exact order**).
|
||||
|
||||
In both cases, the dimensions of 721 and 1440 represent the size along the latitude and longitude, where the numerical range is [90,-90] degree and [0,359.75] degree, respectively, and the spacing is 0.25 degrees. For each 721x1440 slice, the data format is exactly the same as the `.nc` file download from the ERA5 official website.
|
||||
|
||||
Note that the numpy arrays should be in single precision (`.astype(np.float32)`), not in double precision.
|
||||
|
||||
We support ERA5 initial fields and ECMWF initial fields (e.g., the initial fields of the HRES forecast), where the latter often leads to a slight accuracy drop (mainly for T2M because the two fields are quite different in temperature). A `.nc` file of ERA5 can be transformed into a `.npy` file using the netCDF4 package, and a `.grib` file of the ECMWF initial fields can be transformed into a `.npy` file using the pygrib package. Note that Z represents geopotential, not geopotential height, so a factor of 9.80665 should be multiplied if the raw data contains the geopotential height.
|
||||
|
||||
We temporarily do not support other kinds of initial fields due to the possibly dramatic differences in the fields when Z<0.
|
||||
|
||||
We provide an example of transferred input files, `input_surface.npy` and `input_upper.npy`, which correspond to the ERA5 initial fields of at 12:00UTC, 2018/09/27. Please download them from Google drive or Baidu netdisk:
|
||||
|
||||
`input_surface.npy`: [Google drive](https://drive.google.com/file/d/1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1i4o5i8guAqmOus6PWncAlA?pwd=4z9s)
|
||||
|
||||
`input_upper.npy`: [Google drive](https://drive.google.com/file/d/1--7xEBJt79E3oixizr8oFmK_haDE77SS/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1mS8X5MqEdbVfF2u2Us62FQ?pwd=sgx6)
|
||||
|
||||
#### Inference
|
||||
|
||||
After the above steps are finished, please check `inference_cpu.py` for an example of making a 24-hour weather forecast on CPU with the 24-hour model, and `inference_gpu.py` for the GPU version.
|
||||
|
||||
For example, running the following command, one can get the 24-hour forecast in the `output_data` folder:
|
||||
```
|
||||
python inference_cpu.py # python inference_gpu.py for gpu environment
|
||||
```
|
||||
|
||||
Also, `inference_iterative.py` shows an example to generate per-6-hour forecast within a week.
|
||||
|
||||
## Pseudocode and how to use
|
||||
|
||||
`pseudocode.py` contains the pseudocode that elaborates our main algorithm. It is written in Python and can be implemented using any deep learning library, e.g. PyTorch and TensorFlow.
|
||||
|
||||
Note that one needs to download about 60TB of ERA5 data and prepare for computational resource of 3000 GPU-days (in V100) to train each model.
|
||||
|
||||
## Training a lite version
|
||||
|
||||
Recently, we found that Pangu-Weather can be trained efficiently using only 1% of data and GPU computation. We call the version Pangu-Weather-lite. Note that the lite models cannot rival the full models, but the lite version offers opportunities for researchers with limited resource to explore the AI methods for weather forecasting.
|
||||
|
||||
Here are the key implementation details.
|
||||
|
||||
* Data. We reduced the training data into 11 years (2007-2017) and only used the 00UTC time point (the full version used all 24 time points throughout the day). Also, only 00UTC data is used in the testing phase. The total amount of downloaded data shall be less than 1TB.
|
||||
* Model. We adjusted the down-sampling rate in the first stage from 2x4x4 to 2x8x8.
|
||||
* Training epochs. One can remain using 100 epochs or reduce the number to 50 (half); note that the cosine annealing schedule is adjusted accordingly.
|
||||
* Model set. We only trained one model (lead time is 24 hours), which means that the lite version can only perform daily weather forecasting.
|
||||
|
||||
Here are the results.
|
||||
|
||||
| Method | RMSE, Z500 | RMSE, T850 | RMSE, T2M | RMSE, U10 | Years | Down-sampling | Epochs | GPU x days |
|
||||
| ------------------- | ---------------------- | -------------------- | -------------------- | -------------------- | ----- | ------------- | -- | ---------- |
|
||||
| Operational IFS | 152.8 (3d), 333.7 (5d) | 1.37 (3d), 2.06 (5d) | 1.34 (3d), 1.75 (5d) | 1.94 (3d), 2.90 (5d) | -- | -- | -- | -- |
|
||||
| Pangu-Weather | 134.5 (3d), 296.7 (5d) | 1.14 (3d), 1.79 (5d) | 1.05 (3d), 1.53 (5d) | 1.61 (3d), 2.53 (5d) | 39 | 2 x 4 x 4 | 100 | 192 x 16 |
|
||||
| Pangu-Weather-Lite1 | 163.1 (3d), 338.2 (5d) | 1.29 (3d), 1.96 (5d) | 1.16 (3d), 1.64 (5d) | 1.80 (3d), 2.74 (5d) | 11 | 2 x 8 x 8 | 100 | 8 x 6 |
|
||||
| Pangu-Weather-Lite2 | 177.9 (3d), 357.5 (5d) | 1.36 (3d), 2.05 (5d) | 1.24 (3d), 1.71 (5d) | 1.90 (3d), 2.84 (5d) | 11 | 2 x 8 x 8 | 50 | 8 x 3 |
|
||||
|
||||
One can observe that the lite version can surpass operational IFS (*when tested only at 00UTC time points*) in T850 (850hPa temperature), T2M (2m temperature) and U10 (u-component of 10m wind speed), while requiring less than 1% of computational costs compared to the full version.
|
||||
|
||||
Please note that the lite version was only trained and tested in 00UTC data. This means that its performance on other time points is not guaranteed. Since whether variables are closely correlated to time-in-day, it is difficult to directly use the lite version for daily whether forecasting. Again, the lite version is to ease the researchers to explore the property of AI models.
|
||||
|
||||
## License
|
||||
|
||||
Pangu-Weather was released by Huawei Cloud.
|
||||
|
||||
The trained parameters of Pangu-Weather were made available under the terms of the BY-NC-SA 4.0 license. You can find details [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
||||
|
||||
**The commercial use of these models is forbidden.**
|
||||
|
||||
Also, please note that all models were trained using the ERA5 dataset provided by ECMWF. Please do follow [their policy](https://apps.ecmwf.int/datasets/licences/copernicus/).
|
||||
|
||||
## References
|
||||
|
||||
If you use the resource in your research, please cite our paper:
|
||||
|
||||
```
|
||||
@article{bi2023accurate,
|
||||
title={Accurate medium-range global weather forecasting with 3D neural networks},
|
||||
author={Bi, Kaifeng and Xie, Lingxi and Zhang, Hengheng and Chen, Xin and Gu, Xiaotao and Tian, Qi},
|
||||
journal={Nature},
|
||||
volume={619},
|
||||
number={7970},
|
||||
pages={533--538},
|
||||
year={2023},
|
||||
publisher={Nature Publishing Group}
|
||||
}
|
||||
```
|
||||
|
||||
We also offer the bibliography of the arXiv preprint version for your information.
|
||||
|
||||
```
|
||||
@article{bi2022pangu,
|
||||
title={Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast},
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
# The directory of your input and output data
|
||||
input_data_dir = 'input_data'
|
||||
output_data_dir = 'output_data'
|
||||
model_24 = onnx.load('pangu_weather_24.onnx')
|
||||
|
||||
# Set the behavier of onnxruntime
|
||||
options = ort.SessionOptions()
|
||||
options.enable_cpu_mem_arena=False
|
||||
options.enable_mem_pattern = False
|
||||
options.enable_mem_reuse = False
|
||||
# Increase the number for faster inference and more memory consumption
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
# Set the behavier of cuda provider
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider'])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
# Load the surface numpy arrays
|
||||
input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32)
|
||||
|
||||
# Run the inference session
|
||||
output, output_surface = ort_session_24.run(None, {'input':input, 'input_surface':input_surface})
|
||||
|
||||
# Save the results
|
||||
np.save(os.path.join(output_data_dir, 'output_upper'), output)
|
||||
np.save(os.path.join(output_data_dir, 'output_surface'), output_surface)
|
||||
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
# The directory of your input and output data
|
||||
input_data_dir = 'input_data'
|
||||
output_data_dir = 'output_data'
|
||||
model_24 = onnx.load('pangu_weather_24.onnx')
|
||||
|
||||
# Set the behavier of onnxruntime
|
||||
options = ort.SessionOptions()
|
||||
options.enable_cpu_mem_arena=False
|
||||
options.enable_mem_pattern = False
|
||||
options.enable_mem_reuse = False
|
||||
# Increase the number for faster inference and more memory consumption
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
# Set the behavier of cuda provider
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
# Load the surface numpy arrays
|
||||
input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32)
|
||||
|
||||
# Run the inference session
|
||||
output, output_surface = ort_session_24.run(None, {'input':input, 'input_surface':input_surface})
|
||||
# Save the results
|
||||
np.save(os.path.join(output_data_dir, 'output_upper'), output)
|
||||
np.save(os.path.join(output_data_dir, 'output_surface'), output_surface)
|
||||
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
# The directory of your input and output data
|
||||
input_data_dir = 'input_data'
|
||||
output_data_dir = 'output_data'
|
||||
model_24 = onnx.load('pangu_weather_24.onnx')
|
||||
model_6 = onnx.load('pangu_weather_6.onnx')
|
||||
|
||||
# Set the behavier of onnxruntime
|
||||
options = ort.SessionOptions()
|
||||
options.enable_cpu_mem_arena=False
|
||||
options.enable_mem_pattern = False
|
||||
options.enable_mem_reuse = False
|
||||
# Increase the number for faster inference and more memory consumption
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
# Set the behavier of cuda provider
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
# Load the surface numpy arrays
|
||||
input_surface = np.load(os.path.join(input_data_dir, 'input_surface.npy')).astype(np.float32)
|
||||
|
||||
# Run the inference session
|
||||
input_24, input_surface_24 = input, input_surface
|
||||
for i in range(28):
|
||||
if (i+1) % 4 == 0:
|
||||
output, output_surface = ort_session_24.run(None, {'input':input_24, 'input_surface':input_surface_24})
|
||||
input_24, input_surface_24 = output, output_surface
|
||||
else:
|
||||
output, output_surface = ort_session_6.run(None, {'input':input, 'input_surface':input_surface})
|
||||
input, input_surface = output, output_surface
|
||||
# Your can save the results here
|
||||
+599
@@ -0,0 +1,599 @@
|
||||
'''
|
||||
Pseudocode of Pangu-Weather
|
||||
'''
|
||||
# The pseudocode can be implemented using deep learning libraries, e.g., Pytorch and Tensorflow or other high-level APIs
|
||||
|
||||
# Basic operations used in our model, namely Linear, Conv3d, Conv2d, ConvTranspose3d and ConvTranspose2d
|
||||
# Linear: Linear transformation, available in all deep learning libraries
|
||||
# Conv3d and Con2d: Convolution with 2 or 3 dimensions, available in all deep learning libraries
|
||||
# ConvTranspose3d, ConvTranspose2d: transposed convolution with 2 or 3 dimensions, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d
|
||||
|
||||
# Functions in the networks, namely GeLU, DropOut, DropPath, LayerNorm, and SoftMax
|
||||
# GeLU: the GeLU activation function, see Pytorch API or Tensorflow API
|
||||
# DropOut: the dropout function, available in all deep learning libraries
|
||||
# DropPath: the DropPath function, see the implementation of vision-transformer, see timm pakage of Pytorch
|
||||
# A possible implementation of DropPath: from timm.models.layers import DropPath
|
||||
# LayerNorm: the layer normalization function, see Pytorch API or Tensorflow API
|
||||
# Softmax: softmax function, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import GeLU, DropOut, DropPath, LayerNorm, SoftMax
|
||||
|
||||
# Common functions for roll, pad, and crop, depends on the data structure of your software environment
|
||||
from Your_AI_Library import roll3D, pad3D, pad2D, Crop3D, Crop2D
|
||||
|
||||
# Common functions for reshaping and changing the order of dimensions
|
||||
# reshape: change the shape of the data with the order unchanged, see Pytorch API or Tensorflow API
|
||||
# TransposeDimensions: change the order of the dimensions, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import reshape, TransposeDimensions
|
||||
|
||||
# Common functions for creating new tensors
|
||||
# ConstructTensor: create a new tensor with an arbitrary shape
|
||||
# TruncatedNormalInit: Initialize the tensor with Truncate Normalization distribution
|
||||
# RangeTensor: create a new tensor like range(a, b)
|
||||
from Your_AI_Library import ConstructTensor, TruncatedNormalInit, RangeTensor
|
||||
|
||||
# Common operations for the data, you may design it or simply use deep learning APIs default operations
|
||||
# LinearSpace: a tensor version of numpy.linspace
|
||||
# MeshGrid: a tensor version of numpy.meshgrid
|
||||
# Stack: a tensor version of numpy.stack
|
||||
# Flatten: a tensor version of numpy.ndarray.flatten
|
||||
# TensorSum: a tensor version of numpy.sum
|
||||
# TensorAbs: a tensor version of numpy.abs
|
||||
# Concatenate: a tensor version of numpy.concatenate
|
||||
from Your_AI_Library import LinearSpace, MeshGrid, Stack, Flatten, TensorSum, TensorAbs, Concatenate
|
||||
|
||||
# Common functions for training models
|
||||
# LoadModel and SaveModel: Load and save the model, some APIs may require further adaptation to hardwares
|
||||
# Backward: Gradient backward to calculate the gratitude of each parameters
|
||||
# UpdateModelParametersWithAdam: Use Adam to update parameters, e.g., torch.optim.Adam
|
||||
from Your_AI_Library import LoadModel, Backward, UpdateModelParametersWithAdam, SaveModel
|
||||
|
||||
# Custom functions to read your data from the disc
|
||||
# LoadData: Load the ERA5 data
|
||||
# LoadConstantMask: Load constant masks, e.g., soil type
|
||||
# LoadStatic: Load mean and std of the ERA5 training data, every fields such as T850 is treated as an image and calculate the mean and std
|
||||
from Your_Data_Code import LoadData, LoadConstantMask, LoadStatic
|
||||
|
||||
|
||||
def Inference(input, input_surface, forecast_range):
|
||||
'''Inference code, describing the algorithm of inference using models with different lead times.
|
||||
PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in lead times.
|
||||
Args:
|
||||
input: input tensor, need to be normalized to N(0, 1) in practice
|
||||
input_surface: target tensor, need to be normalized to N(0, 1) in practice
|
||||
forecast_range: iteration numbers when roll out the forecast model
|
||||
'''
|
||||
|
||||
# Load 4 pre-trained models with different lead times
|
||||
PanguModel24 = LoadModel(ModelPath24)
|
||||
PanguModel6 = LoadModel(ModelPath6)
|
||||
PanguModel3 = LoadModel(ModelPath3)
|
||||
PanguModel1 = LoadModel(ModelPath1)
|
||||
|
||||
# Load mean and std of the weather data
|
||||
weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic()
|
||||
|
||||
# Store initial input for different models
|
||||
input_24, input_surface_24 = input, input_surface
|
||||
input_6, input_surface_6 = input, input_surface
|
||||
input_3, input_surface_3 = input, input_surface
|
||||
|
||||
# Using a list to store output
|
||||
output_list = []
|
||||
|
||||
# Note: the following code is implemented for fast inference of [1,forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster.
|
||||
for i in range(forecast_range):
|
||||
# switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours
|
||||
if (i+1) % 24 == 0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_24, input_surface_24
|
||||
|
||||
# Call the model pretrained for 24 hours forecast
|
||||
output, output_surface = PanguModel24(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_24, input_surface_24 = output, output_surface
|
||||
input_6, input_surface_6 = output, output_surface
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 6-hour model if the forecast time is 30 hours, 36 hours, ..., 24*N + 6/12/18 hours
|
||||
elif (i+1) % 6 == 0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_6, input_surface_6
|
||||
|
||||
# Call the model pretrained for 6 hours forecast
|
||||
output, output_surface = PanguModel6(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_6, input_surface_6 = output, output_surface
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours
|
||||
elif (i+1) % 3 ==0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_3, input_surface_3
|
||||
|
||||
# Call the model pretrained for 3 hours forecast
|
||||
output, output_surface = PanguModel3(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 1-hour model
|
||||
else:
|
||||
# Call the model pretrained for 1 hours forecast
|
||||
output, output_surface = PanguModel1(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input, input_surface = output, output_surface
|
||||
|
||||
# Save the output
|
||||
output_list.append((output, output_surface))
|
||||
return output_list
|
||||
|
||||
|
||||
def Train():
|
||||
'''Training code'''
|
||||
# Initialize the model, for some APIs some adaptation is needed to fit hardwares
|
||||
model = PanguModel()
|
||||
|
||||
# Train single Pangu-Weather model
|
||||
epochs = 100
|
||||
for i in range(epochs):
|
||||
# For each epoch, we iterate from 1979 to 2017
|
||||
# dataset_length is the length of your training data, e.g., the sample between 1979 and 2017
|
||||
for step in range(dataset_length):
|
||||
# Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the output
|
||||
# Note the data need to be randomly shuffled
|
||||
# Note the input and target need to be normalized, see Inference() for details
|
||||
input, input_surface, target, target_surface = LoadData(step)
|
||||
|
||||
# Call the model and get the output
|
||||
output, output_surface = model(input, input_surface)
|
||||
|
||||
# We use the MAE loss to train the model
|
||||
# The weight of surface loss is 0.25
|
||||
# Different weight can be applied for differen fields if needed
|
||||
loss = TensorAbs(output-target) + TensorAbs(output_surface-target_surface) * 0.25
|
||||
|
||||
# Call the backward algorithm and calculate the gratitude of parameters
|
||||
Backward(loss)
|
||||
|
||||
# Update model parameters with Adam optimizer
|
||||
# The learning rate is 5e-4 as in the paper, while the weight decay is 3e-6
|
||||
# A example solution is using torch.optim.adam
|
||||
UpdateModelParametersWithAdam()
|
||||
|
||||
# Save the model at the end of the training stage
|
||||
SaveModel()
|
||||
|
||||
class PanguModel:
|
||||
def __init__(self):
|
||||
# Drop path rate is linearly increased as the depth increases
|
||||
drop_path_list = LinearSpace(0, 0.2, 8)
|
||||
|
||||
# Patch embedding
|
||||
self._input_layer = PatchEmbedding((2, 4, 4), 192)
|
||||
|
||||
# Four basic layers
|
||||
self.layer1 = EarthSpecificLayer(2, 192, drop_list[:2], 6)
|
||||
self.layer2 = EarthSpecificLayer(6, 384, drop_list[6:], 12)
|
||||
self.layer3 = EarthSpecificLayer(6, 384, drop_list[6:], 12)
|
||||
self.layer4 = EarthSpecificLayer(2, 192, drop_list[:2], 6)
|
||||
|
||||
# Upsample and downsample
|
||||
self.upsample = UpSample(384, 192)
|
||||
self.downsample = DownSample(192)
|
||||
|
||||
# Patch Recovery
|
||||
self._output_layer = PatchRecovery(384)
|
||||
|
||||
def forward(self, input, input_surface):
|
||||
'''Backbone architecture'''
|
||||
# Embed the input fields into patches
|
||||
x = self._input_layer(input, input_surface)
|
||||
|
||||
# Encoder, composed of two layers
|
||||
# Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper
|
||||
x = self.layer1(x, 8, 360, 181)
|
||||
|
||||
# Store the tensor for skip-connection
|
||||
skip = x
|
||||
|
||||
# Downsample from (8, 360, 181) to (8, 180, 91)
|
||||
x = self.downsample(x, 8, 360, 181)
|
||||
|
||||
# Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper
|
||||
x = self.layer2(x, 8, 180, 91)
|
||||
|
||||
# Decoder, composed of two layers
|
||||
# Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper
|
||||
x = self.layer3(x, 8, 180, 91)
|
||||
|
||||
# Upsample from (8, 180, 91) to (8, 360, 181)
|
||||
x = self.upsample(x)
|
||||
|
||||
# Layer 4, shape (8, 360, 181, 2C), C = 192 as in the original paper
|
||||
x = self.layer4(x, 8, 360, 181)
|
||||
|
||||
# Skip connect, in last dimension(C from 192 to 384)
|
||||
x = Concatenate(skip, x)
|
||||
|
||||
# Recover the output fields from patches
|
||||
output, output_surface = self._output_layer(x)
|
||||
return output, output_surface
|
||||
|
||||
class PatchEmbedding:
|
||||
def __init__(self, patch_size, dim):
|
||||
'''Patch embedding operation'''
|
||||
# Here we use convolution to partition data into cubes
|
||||
self.conv = Conv3d(input_dims=5, output_dims=dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.conv_surface = Conv2d(input_dims=7, output_dims=dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
|
||||
# Load constant masks from the disc
|
||||
self.land_mask, self.soil_type, self.topography = LoadConstantMask()
|
||||
|
||||
def forward(self, input, input_surface):
|
||||
# Zero-pad the input
|
||||
input = Pad3D(input)
|
||||
input_surface = Pad2D(input_surface)
|
||||
|
||||
# Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper
|
||||
input = self.conv(input)
|
||||
|
||||
# Add three constant fields to the surface fields
|
||||
input_surface = Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)
|
||||
|
||||
# Apply a linear projection for patch_size[1]*patch_size[2] patches
|
||||
input_surface = self.conv_surface(input_surface)
|
||||
|
||||
# Concatenate the input in the pressure level, i.e., in Z dimension
|
||||
x = Concatenate(input, input_surface)
|
||||
|
||||
# Reshape x for calculation of linear projections
|
||||
x = TransposeDimensions(x, (0, 2, 3, 4, 1))
|
||||
x = reshape(x, target_shape=(x.shape[0], 8*360*181, x.shape[-1]))
|
||||
return x
|
||||
|
||||
class PatchRecovery:
|
||||
def __init__(self, dim):
|
||||
'''Patch recovery operation'''
|
||||
# Hear we use two transposed convolutions to recover data
|
||||
self.conv = ConvTranspose3d(input_dims=dim, output_dims=5, kernel_size=patch_size, stride=patch_size)
|
||||
self.conv_surface = ConvTranspose2d(input_dims=dim, output_dims=4, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
# The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper
|
||||
# Reshape x back to three dimensions
|
||||
x = TransposeDimensions(x, (0, 2, 1))
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))
|
||||
|
||||
# Call the transposed convolution
|
||||
output = self.conv(x[:, :, 1:, :, :])
|
||||
output_surface = self.conv_surface(x[:, :, 0, :, :])
|
||||
|
||||
# Crop the output to remove zero-paddings
|
||||
output = Crop3D(output)
|
||||
output_surface = Crop2D(output_surface)
|
||||
return output, output_surface
|
||||
|
||||
class DownSample:
|
||||
def __init__(self, dim):
|
||||
'''Down-sampling operation'''
|
||||
# A linear function and a layer normalization
|
||||
self.linear = Linear(4*dim, 2*dim, bias=Fasle)
|
||||
self.norm = LayerNorm(4*dim)
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
# Reshape x to three dimensions for downsampling
|
||||
x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))
|
||||
|
||||
# Padding the input to facilitate downsampling
|
||||
x = Pad3D(x)
|
||||
|
||||
# Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
|
||||
Z, H, W = x.shape
|
||||
# Reshape x to facilitate downsampling
|
||||
x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
|
||||
# Change the order of x
|
||||
x = TransposeDimensions(x, (0,1,2,4,3,5,6))
|
||||
# Reshape to get a tensor of resolution (8, 180, 91)
|
||||
x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))
|
||||
|
||||
# Call the layer normalization
|
||||
x = self.norm(x)
|
||||
|
||||
# Decrease the channels of the data to reduce computation cost
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class UpSample:
|
||||
def __init__(self, input_dim, output_dim):
|
||||
'''Up-sampling operation'''
|
||||
# Linear layers without bias to increase channels of the data
|
||||
self.linear1 = Linear(input_dim, output_dim*4, bias=False)
|
||||
|
||||
# Linear layers without bias to mix the data up
|
||||
self.linear2 = Linear(output_dim, output_dim, bias=False)
|
||||
|
||||
# Normalization
|
||||
self.norm = LayerNorm(output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
# Call the linear functions to increase channels of the data
|
||||
x = self.linear1(x)
|
||||
|
||||
# Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182)
|
||||
# Reshape x to facilitate upsampling.
|
||||
x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4))
|
||||
# Change the order of x
|
||||
x = TransposeDimensions(x, (0,1,2,4,3,5,6))
|
||||
# Reshape to get Tensor with a resolution of (8, 360, 182)
|
||||
x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1]))
|
||||
|
||||
# Crop the output to the input shape of the network
|
||||
x = Crop3D(x)
|
||||
|
||||
# Reshape x back
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
|
||||
|
||||
# Call the layer normalization
|
||||
x = self.norm(x)
|
||||
|
||||
# Mixup normalized tensors
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
class EarthSpecificLayer:
|
||||
def __init__(self, depth, dim, drop_path_ratio_list, heads):
|
||||
'''Basic layer of our network, contains 2 or 6 blocks'''
|
||||
self.depth = depth
|
||||
self.blocks = []
|
||||
|
||||
# Construct basic blocks
|
||||
for i in range(depth):
|
||||
self.blocks.append(EarthSpecificBlock(dim, drop_path_ratio_list[i], heads))
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
for i in range(self.depth):
|
||||
# Roll the input every two blocks
|
||||
if i % 2 == 0:
|
||||
self.blocks[i](x, Z, H, W, roll=False)
|
||||
else:
|
||||
self.blocks[i](x, Z, H, W, roll=True)
|
||||
return x
|
||||
|
||||
class EarthSpecificBlock:
|
||||
def __init__(self, dim, drop_path_ratio, heads):
|
||||
'''
|
||||
3D transformer block with Earth-Specific bias and window attention,
|
||||
see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
|
||||
The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.
|
||||
'''
|
||||
# Define the window size of the neural network
|
||||
self.window_size = (2, 6, 12)
|
||||
|
||||
# Initialize serveral operations
|
||||
self.drop_path = DropPath(drop_rate=drop_path_ratio)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.linear = MLP(dim, 0)
|
||||
self.attention = EarthAttention3D(dim, heads, 0, self.window_size)
|
||||
|
||||
def forward(self, x, Z, H, W, roll):
|
||||
# Save the shortcut for skip-connection
|
||||
shortcut = x
|
||||
|
||||
# Reshape input to three dimensions to calculate window attention
|
||||
reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2]))
|
||||
|
||||
# Zero-pad input if needed
|
||||
x = pad3D(x)
|
||||
|
||||
# Store the shape of the input for restoration
|
||||
ori_shape = x.shape
|
||||
|
||||
if roll:
|
||||
# Roll x for half of the window for 3 dimensions
|
||||
x = roll3D(x, shift=[self.window_size[0]//2, self.window_size[1]//2, self.window_size[2]//2])
|
||||
# Generate mask of attention masks
|
||||
# If two pixels are not adjacent, then mask the attention between them
|
||||
# Your can set the matrix element to -1000 when it is not adjacent, then add it to the attention
|
||||
mask = gen_mask(x)
|
||||
else:
|
||||
# e.g., zero matrix when you add mask to attention
|
||||
mask = no_mask
|
||||
|
||||
# Reorganize data to calculate window attention
|
||||
x_window = reshape(x, target_shape=(x.shape[0], Z//window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], x.shape[-1]))
|
||||
x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7))
|
||||
|
||||
# Get data stacked in 3D cubes, which will further be used to calculated attention among each cube
|
||||
x_window = reshape(x_window, target_shape=(-1, window_size[0]* window_size[1]*window_size[2], x.shape[-1]))
|
||||
|
||||
# Apply 3D window attention with Earth-Specific bias
|
||||
x_window = self.attention(x, mask)
|
||||
|
||||
# Reorganize data to original shapes
|
||||
x = reshape(x_window, target_shape=((-1, Z // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], x_window.shape[-1])))
|
||||
x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7))
|
||||
|
||||
# Reshape the tensor back to its original shape
|
||||
x = reshape(x_window, target_shape=ori_shape)
|
||||
|
||||
if roll:
|
||||
# Roll x back for half of the window
|
||||
x = roll3D(x, shift=[-self.window_size[0]//2, -self.window_size[1]//2, -self.window_size[2]//2])
|
||||
|
||||
# Crop the zero-padding
|
||||
x = Crop3D(x)
|
||||
|
||||
# Reshape the tensor back to the input shape
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4]))
|
||||
|
||||
# Main calculation stages
|
||||
x = shortcut + self.drop_path(self.norm1(x))
|
||||
x = x + self.drop_path(self.norm2(self.linear(x)))
|
||||
return x
|
||||
|
||||
class EarthAttention3D:
|
||||
def __init__(self, dim, heads, dropout_rate, window_size):
|
||||
'''
|
||||
3D window attention with the Earth-Specific bias,
|
||||
see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
|
||||
'''
|
||||
# Initialize several operations
|
||||
self.linear1 = Linear(dim, dim=3, bias=True)
|
||||
self.linear2 = Linear(dim, dim)
|
||||
self.softmax = SoftMax(dim=-1)
|
||||
self.dropout = DropOut(dropout_rate)
|
||||
|
||||
# Store several attributes
|
||||
self.head_number = heads
|
||||
self.dim = dim
|
||||
self.scale = (dim//heads)**-0.5
|
||||
self.window_size = window_size
|
||||
|
||||
# input_shape is current shape of the self.forward function
|
||||
# You can run your code to record it, modify the code and rerun it
|
||||
# Record the number of different window types
|
||||
self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1])
|
||||
|
||||
# For each type of window, we will construct a set of parameters according to the paper
|
||||
self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads))
|
||||
|
||||
# Making these tensors to be learnable parameters
|
||||
self.earth_specific_bias = Parameters(self.earth_specific_bias)
|
||||
|
||||
# Initialize the tensors using Truncated normal distribution
|
||||
TruncatedNormalInit(self.earth_specific_bias, std=0.02)
|
||||
|
||||
# Construct position index to reuse self.earth_specific_bias
|
||||
self.position_index = self._construct_index()
|
||||
|
||||
def _construct_index(self):
|
||||
''' This function construct the position index to reuse symmetrical parameters of the position bias'''
|
||||
# Index in the pressure level of query matrix
|
||||
coords_zi = RangeTensor(self.window_size[0])
|
||||
# Index in the pressure level of key matrix
|
||||
coords_zj = -RangeTensor(self.window_size[0])*self.window_size[0]
|
||||
|
||||
# Index in the latitude of query matrix
|
||||
coords_hi = RangeTensor(self.window_size[1])
|
||||
# Index in the latitude of key matrix
|
||||
coords_hj = -RangeTensor(self.window_size[1])*self.window_size[1]
|
||||
|
||||
# Index in the longitude of the key-value pair
|
||||
coords_w = RangeTensor(self.window_size[2])
|
||||
|
||||
# Change the order of the index to calculate the index in total
|
||||
coords_1 = Stack(MeshGrid([coords_zi, coords_hi, coords_w]))
|
||||
coords_2 = Stack(MeshGrid([coords_zj, coords_hj, coords_w]))
|
||||
coords_flatten_1 = Flatten(coords_1, start_dimension=1)
|
||||
coords_flatten_2 = Flatten(coords_2, start_dimension=1)
|
||||
coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
|
||||
coords = TransposeDimensions(coords, (1, 2, 0))
|
||||
|
||||
# Shift the index for each dimension to start from 0
|
||||
coords[:, :, 2] += self.window_size[2] - 1
|
||||
coords[:, :, 1] *= 2 * self.window_size[2] - 1
|
||||
coords[:, :, 0] *= (2 * self.window_size[2] - 1)*self.window_size[1]*self.window_size[1]
|
||||
|
||||
# Sum up the indexes in three dimensions
|
||||
self.position_index = TensorSum(coords, dim=-1)
|
||||
|
||||
# Flatten the position index to facilitate further indexing
|
||||
self.position_index = Flatten(self.position_index)
|
||||
|
||||
def forward(self, x, mask):
|
||||
# Linear layer to create query, key and value
|
||||
x = self.linear1(x)
|
||||
|
||||
# Record the original shape of the input
|
||||
original_shape = x.shape
|
||||
|
||||
# reshape the data to calculate multi-head attention
|
||||
qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number))
|
||||
query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4))
|
||||
|
||||
# Scale the attention
|
||||
query = query * self.scale
|
||||
|
||||
# Calculated the attention, a learnable bias is added to fix the nonuniformity of the grid.
|
||||
attention = query @ key.T # @ denotes matrix multiplication
|
||||
|
||||
# self.earth_specific_bias is a set of neural network parameters to optimize.
|
||||
EarthSpecificBias = self.earth_specific_bias[self.position_index]
|
||||
|
||||
# Reshape the learnable bias to the same shape as the attention matrix
|
||||
EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number))
|
||||
EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1))
|
||||
EarthSpecificBias = reshape(EarthSpecificBias, target_shape = [1]+EarthSpecificBias.shape)
|
||||
|
||||
# Add the Earth-Specific bias to the attention matrix
|
||||
attention = attention + EarthSpecificBias
|
||||
|
||||
# Mask the attention between non-adjacent pixels, e.g., simply add -100 to the masked element.
|
||||
attention = self.mask_attention(attention, mask)
|
||||
attention = self.softmax(attention)
|
||||
attention = self.dropout(attention)
|
||||
|
||||
# Calculated the tensor after spatial mixing.
|
||||
x = attention @ value.T # @ denote matrix multiplication
|
||||
|
||||
# Reshape tensor to the original shape
|
||||
x = TransposeDimensions(x, (0, 2, 1))
|
||||
x = reshape(x, target_shape = original_shape)
|
||||
|
||||
# Linear layer to post-process operated tensor
|
||||
x = self.linear2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
class Mlp:
|
||||
def __init__(self, dim, dropout_rate):
|
||||
'''MLP layers, same as most vision transformer architectures.'''
|
||||
self.linear1 = Linear(dim, dim * 4)
|
||||
self.linear2 = Linear(dim * 4, dim)
|
||||
self.activation = GeLU()
|
||||
self.drop = DropOut(drop_rate=dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
x = self.drop(x)
|
||||
x = self.linear(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def PerlinNoise():
|
||||
'''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the perlin noise.'''
|
||||
# Define number of noise
|
||||
octaves = 3
|
||||
# Define the scaling factor of noise
|
||||
noise_scale = 0.2
|
||||
# Define the number of periods of noise along the axis
|
||||
period_number = 12
|
||||
# The size of an input slice
|
||||
H, W = 721, 1440
|
||||
# Scaling factor between two octaves
|
||||
persistence = 0.5
|
||||
# see https://github.com/pvigier/perlin-numpy/ for the implementation of GenerateFractalNoise (e.g., from perlin_numpy import generate_fractal_noise_3d)
|
||||
perlin_noise = noise_scale*GenerateFractalNoise((H, W), (period_number, period_number), octaves, persistence)
|
||||
return perlin_noise
|
||||
@@ -0,0 +1,3 @@
|
||||
numpy
|
||||
onnx==1.13.1
|
||||
onnxruntime==1.14.0
|
||||
@@ -0,0 +1,3 @@
|
||||
numpy
|
||||
onnx==1.12.0
|
||||
onnxruntime-gpu==1.14.0
|
||||
Reference in New Issue
Block a user