Compare commits
20 Commits
add-license-1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8642e9806b | |||
| fa4cec5805 | |||
| 4d0d1eaed7 | |||
| 1130d80c2d | |||
| a359fa116e | |||
| b5850da258 | |||
| d2b8ae962d | |||
| d9f2964faf | |||
| 62acbda2c4 | |||
| cf3547114d | |||
| 0387b09086 | |||
| 84a1e2448d | |||
| 597f90e3d8 | |||
| 27d322a56a | |||
| b7a32defaf | |||
| d289a9a310 | |||
| d2df335fb0 | |||
| bf7834f827 | |||
| 95bc883352 | |||
| 0be504e554 |
-21
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Lingxi Xie
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,12 +1,30 @@
|
||||
## Pangu-Weather
|
||||
|
||||
This is the official repository for the Pangu-Weather paper.
|
||||
This is the official repository for the Pangu-Weather papers.
|
||||
|
||||
[Accurate medium-range global weather forecasting with 3D neural networks](https://www.nature.com/articles/s41586-023-06185-3), Nature, Volume 619, Pages 533–538, 2023.
|
||||
|
||||
[Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast](https://arxiv.org/abs/2211.02556), arXiv preprint: 2211.02556, 2022.
|
||||
|
||||
by Kaifeng Bi, Lingxi Xie, Hengheng Zhang, Xin Chen, Xiaotao Gu and Qi Tian
|
||||
*by Kaifeng Bi, Lingxi Xie, Hengheng Zhang, Xin Chen, Xiaotao Gu and Qi Tian*
|
||||
|
||||
Resources including pseudocode, pre-trained models, and inference code are released.
|
||||
**Note: the arXiv version offers more technical details, and the Nature paper contains some new figures.**
|
||||
|
||||
Resources including pseudocode, pre-trained models, and inference code are released here.
|
||||
|
||||
The slides used in a series of recent talks are attached here. [Baidu Netdisk](https://pan.baidu.com/s/14ZGywcr4XAK5dk75-8PUqA?pwd=9sco), extraction code: 9sco
|
||||
|
||||
## News and Updates
|
||||
|
||||
* [Jul 31 2023] We released the details of training the lite version of Pangu-Weather.
|
||||
* [Jul 19 2023] ECMWF released an official [technical report](https://arxiv.org/abs/2307.10128) for "the rise of data-driven weather forecasting". Pangu-Weather was mentioned and tested thoroughly in the paper. We thank ECMWF for testing our models in real-world scenarios.
|
||||
* [Jul 17 2023] Pangu-Weather was online as part of ECMWF's operational suite! Everyone can see 10-day global weather forecasting **without running code**. ECMWF has made use of the models released at this repository! [Please search the ECMWF charts website with the query of "PANGU".](https://charts.ecmwf.int/?query=pangu)
|
||||
* [Jul 05 2023] Pangu-Weather was published on [Nature](https://www.nature.com/articles/s41586-023-06185-3). It was made **Open Access**! We recommend the researchers to cite our Nature paper in the future.
|
||||
* [Jun 27 2023] Pangu-Weather was presented at [PASC 2023](https://pasc23.pasc-conference.org/program/schedule/).
|
||||
* [Jun 12 2023] Pangu-Weather was presented at [VALSE 2023](http://valser.org/2023/#/workshopde?id=15).
|
||||
* [May 27 2023] Pangu-Weather was presented at [the WMO Early Warning for All (EW4ALL) conference](https://community.wmo.int/en/news/exploring-possibilities-artificial-intelligence-areas-water-weather-and-climate).
|
||||
* [May 12 2023] ECMWF released a [repository](https://github.com/ecmwf-lab/ai-models-panguweather), offering a toolkit for running Pangu-Weather. We thank ECMWF for the efforts in easing everyone to test Pangu-Weather.
|
||||
* [May 09 2023] Pangu-Weather was accepted by Nature!
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -18,10 +36,10 @@ The downloaded files shall be organized as the following hierarchy:
|
||||
│ │ ├── input_surface.npy
|
||||
│ │ ├── input_upper.npy
|
||||
│ ├── output_data
|
||||
│ ├── model_jit_cpu_1.onnx
|
||||
│ ├── model_jit_cpu_3.onnx
|
||||
│ ├── model_jit_cpu_6.onnx
|
||||
│ ├── model_jit_cpu_24.onnx
|
||||
│ ├── pangu_weather_1.onnx
|
||||
│ ├── pangu_weather_3.onnx
|
||||
│ ├── pangu_weather_6.onnx
|
||||
│ ├── pangu_weather_24.onnx
|
||||
│ ├── inference_cpu.py
|
||||
│ ├── inference_gpu.py
|
||||
│ ├── inference_iterative.py
|
||||
@@ -29,27 +47,27 @@ The downloaded files shall be organized as the following hierarchy:
|
||||
|
||||
If you use a CPU environment, please run:
|
||||
```
|
||||
pip install -r requirement_cpu.txt
|
||||
pip install -r requirements_cpu.txt
|
||||
```
|
||||
|
||||
If you use a GPU environment, please first confirm that the cuda version is 11.6 and the cudnn version is the 8.2.4 for Linux and 8.5.0.96 for Windows (please see [this page](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) for details). Then, please run:
|
||||
```
|
||||
pip install -r requirement_gpu.txt
|
||||
pip install -r requirements_gpu.txt
|
||||
```
|
||||
|
||||
## Global weather forecasting (inference) using the trained models
|
||||
|
||||
#### Downloading trained models
|
||||
|
||||
Please download the four pre-trained models (~1.1GB each) from the Google drive:
|
||||
Please download the four pre-trained models (~1.1GB each) from Google drive or Baidu netdisk:
|
||||
|
||||
The 1-hour model: [model_jit_cpu_1.onnx](https://drive.google.com/file/d/1fg5jkiN_5dHzKb-5H9Aw4MOmfILmeY-S/view?usp=share_link)
|
||||
The 1-hour model (pangu_weather_1.onnx): [Google drive](https://drive.google.com/file/d/1fg5jkiN_5dHzKb-5H9Aw4MOmfILmeY-S/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1M7SAigVsCSH8hpw6DE8TDQ?pwd=ie0h)
|
||||
|
||||
The 3-hour model: [model_jit_cpu_3.onnx](https://drive.google.com/file/d/1EdoLlAXqE9iZLt9Ej9i-JW9LTJ9Jtewt/view?usp=share_link)
|
||||
The 3-hour model (pangu_weather_3.onnx): [Google drive](https://drive.google.com/file/d/1EdoLlAXqE9iZLt9Ej9i-JW9LTJ9Jtewt/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/197fZsoiCqZYzKwM7tyRrfg?pwd=gmcl)
|
||||
|
||||
The 6-hour model: [model_jit_cpu_6.onnx](https://drive.google.com/file/d/1a4XTktkZa5GCtjQxDJb_fNaqTAUiEJu4/view?usp=share_link)
|
||||
The 6-hour model (pangu_weather_6.onnx): [Google drive](https://drive.google.com/file/d/1a4XTktkZa5GCtjQxDJb_fNaqTAUiEJu4/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1q7IB7tNjqIwoGC7KVMPn4w?pwd=vxq3)
|
||||
|
||||
The 24-hour model: [model_jit_cpu_24.onnx](https://drive.google.com/file/d/1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP/view?usp=share_link)
|
||||
The 24-hour model (pangu_weather_24.onnx): [Google drive](https://drive.google.com/file/d/1lweQlxcn9fG0zKNW8ne1Khr9ehRTI6HP/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/179q2gkz2BrsOR6g3yfTVQg?pwd=eajy)
|
||||
|
||||
These models are stored using the ONNX format, and thus can be used via different languages such as Python, C++, C#, Java, etc.
|
||||
|
||||
@@ -57,9 +75,9 @@ These models are stored using the ONNX format, and thus can be used via differen
|
||||
|
||||
Please prepare the input data using [numpy](https://numpy.org/). There are two files that shall be put under the `input_data` folder, namely, `input_surface.npy` and `input_upper.npy`.
|
||||
|
||||
`input_surface.npy` stores the input surface variables. It is a numpy array shaped (4,721,1440) where the first dimension represents the 4 surface variables (MSLP, U10, V10, T2M *in the exact order*).
|
||||
`input_surface.npy` stores the input surface variables. It is a numpy array shaped (4,721,1440) where the first dimension represents the 4 surface variables (MSLP, U10, V10, T2M **in the exact order**).
|
||||
|
||||
`input_upper.npy` stores the upper-air variables. It is a numpy array shaped (5,13,721,1440) where the first dimension represents the 5 surface variables (Z, Q, T, U and V *in the exact order*), and the second dimension represents the 13 pressure levels (1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa and 50hPa *in the exact order*).
|
||||
`input_upper.npy` stores the upper-air variables. It is a numpy array shaped (5,13,721,1440) where the first dimension represents the 5 surface variables (Z, Q, T, U and V **in the exact order**), and the second dimension represents the 13 pressure levels (1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa and 50hPa **in the exact order**).
|
||||
|
||||
In both cases, the dimensions of 721 and 1440 represent the size along the latitude and longitude, where the numerical range is [90,-90] degree and [0,359.75] degree, respectively, and the spacing is 0.25 degrees. For each 721x1440 slice, the data format is exactly the same as the `.nc` file download from the ERA5 official website.
|
||||
|
||||
@@ -69,11 +87,11 @@ We support ERA5 initial fields and ECMWF initial fields (e.g., the initial field
|
||||
|
||||
We temporarily do not support other kinds of initial fields due to the possibly dramatic differences in the fields when Z<0.
|
||||
|
||||
We provide an example of transferred input files, `input_surface.npy` and `input_upper.npy`, which correspond to the ERA5 initial fields of at 12:00UTC, 2018/09/27. Please download them using Google drive:
|
||||
We provide an example of transferred input files, `input_surface.npy` and `input_upper.npy`, which correspond to the ERA5 initial fields of at 12:00UTC, 2018/09/27. Please download them from Google drive or Baidu netdisk:
|
||||
|
||||
[`input_surface.npy`](https://drive.google.com/file/d/1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD/view?usp=share_link)
|
||||
`input_surface.npy`: [Google drive](https://drive.google.com/file/d/1pj8QEVNpC1FyJfUabDpV4oU3NpSe0BkD/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1i4o5i8guAqmOus6PWncAlA?pwd=4z9s)
|
||||
|
||||
[`input_upper.npy`](https://drive.google.com/file/d/1--7xEBJt79E3oixizr8oFmK_haDE77SS/view?usp=share_link)
|
||||
`input_upper.npy`: [Google drive](https://drive.google.com/file/d/1--7xEBJt79E3oixizr8oFmK_haDE77SS/view?usp=share_link)/[Baidu netdisk](https://pan.baidu.com/s/1mS8X5MqEdbVfF2u2Us62FQ?pwd=sgx6)
|
||||
|
||||
#### Inference
|
||||
|
||||
@@ -92,16 +110,59 @@ Also, `inference_iterative.py` shows an example to generate per-6-hour forecast
|
||||
|
||||
Note that one needs to download about 60TB of ERA5 data and prepare for computational resource of 3000 GPU-days (in V100) to train each model.
|
||||
|
||||
## Training a lite version
|
||||
|
||||
Recently, we found that Pangu-Weather can be trained efficiently using only 1% of data and GPU computation. We call the version Pangu-Weather-lite. Note that the lite models cannot rival the full models, but the lite version offers opportunities for researchers with limited resource to explore the AI methods for weather forecasting.
|
||||
|
||||
Here are the key implementation details.
|
||||
|
||||
* Data. We reduced the training data into 11 years (2007-2017) and only used the 00UTC time point (the full version used all 24 time points throughout the day). Also, only 00UTC data is used in the testing phase. The total amount of downloaded data shall be less than 1TB.
|
||||
* Model. We adjusted the down-sampling rate in the first stage from 2x4x4 to 2x8x8.
|
||||
* Training epochs. One can remain using 100 epochs or reduce the number to 50 (half); note that the cosine annealing schedule is adjusted accordingly.
|
||||
* Model set. We only trained one model (lead time is 24 hours), which means that the lite version can only perform daily weather forecasting.
|
||||
|
||||
Here are the results.
|
||||
|
||||
| Method | RMSE, Z500 | RMSE, T850 | RMSE, T2M | RMSE, U10 | Years | Down-sampling | Epochs | GPU x days |
|
||||
| ------------------- | ---------------------- | -------------------- | -------------------- | -------------------- | ----- | ------------- | -- | ---------- |
|
||||
| Operational IFS | 152.8 (3d), 333.7 (5d) | 1.37 (3d), 2.06 (5d) | 1.34 (3d), 1.75 (5d) | 1.94 (3d), 2.90 (5d) | -- | -- | -- | -- |
|
||||
| Pangu-Weather | 134.5 (3d), 296.7 (5d) | 1.14 (3d), 1.79 (5d) | 1.05 (3d), 1.53 (5d) | 1.61 (3d), 2.53 (5d) | 39 | 2 x 4 x 4 | 100 | 192 x 16 |
|
||||
| Pangu-Weather-Lite1 | 163.1 (3d), 338.2 (5d) | 1.29 (3d), 1.96 (5d) | 1.16 (3d), 1.64 (5d) | 1.80 (3d), 2.74 (5d) | 11 | 2 x 8 x 8 | 100 | 8 x 6 |
|
||||
| Pangu-Weather-Lite2 | 177.9 (3d), 357.5 (5d) | 1.36 (3d), 2.05 (5d) | 1.24 (3d), 1.71 (5d) | 1.90 (3d), 2.84 (5d) | 11 | 2 x 8 x 8 | 50 | 8 x 3 |
|
||||
|
||||
One can observe that the lite version can surpass operational IFS (*when tested only at 00UTC time points*) in T850 (850hPa temperature), T2M (2m temperature) and U10 (u-component of 10m wind speed), while requiring less than 1% of computational costs compared to the full version.
|
||||
|
||||
Please note that the lite version was only trained and tested in 00UTC data. This means that its performance on other time points is not guaranteed. Since whether variables are closely correlated to time-in-day, it is difficult to directly use the lite version for daily whether forecasting. Again, the lite version is to ease the researchers to explore the property of AI models.
|
||||
|
||||
## License
|
||||
|
||||
Pangu-Weather is released under the MIT license.
|
||||
Pangu-Weather was released by Huawei Cloud.
|
||||
|
||||
Also, please note that all models were trained using the ERA5 dataset provided by ECMWF. Please do follow their policy and note that commercial use of these models is forbidden.
|
||||
The trained parameters of Pangu-Weather were made available under the terms of the BY-NC-SA 4.0 license. You can find details [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
|
||||
|
||||
**The commercial use of these models is forbidden.**
|
||||
|
||||
Also, please note that all models were trained using the ERA5 dataset provided by ECMWF. Please do follow [their policy](https://apps.ecmwf.int/datasets/licences/copernicus/).
|
||||
|
||||
## References
|
||||
|
||||
If you use the resource in your research, please cite our paper:
|
||||
|
||||
```
|
||||
@article{bi2023accurate,
|
||||
title={Accurate medium-range global weather forecasting with 3D neural networks},
|
||||
author={Bi, Kaifeng and Xie, Lingxi and Zhang, Hengheng and Chen, Xin and Gu, Xiaotao and Tian, Qi},
|
||||
journal={Nature},
|
||||
volume={619},
|
||||
number={7970},
|
||||
pages={533--538},
|
||||
year={2023},
|
||||
publisher={Nature Publishing Group}
|
||||
}
|
||||
```
|
||||
|
||||
We also offer the bibliography of the arXiv preprint version for your information.
|
||||
|
||||
```
|
||||
@article{bi2022pangu,
|
||||
title={Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast},
|
||||
|
||||
+1
-1
@@ -21,7 +21,7 @@ options.intra_op_num_threads = 1
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=['CPUExecutionProvider'])
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=['CPUExecutionProvider'])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
|
||||
+1
-1
@@ -21,7 +21,7 @@ options.intra_op_num_threads = 1
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
|
||||
@@ -22,8 +22,8 @@ options.intra_op_num_threads = 1
|
||||
cuda_provider_options = {'arena_extend_strategy':'kSameAsRequested',}
|
||||
|
||||
# Initialize onnxruntime session for Pangu-Weather Models
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, provider=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
ort_session_24 = ort.InferenceSession('pangu_weather_24.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
ort_session_6 = ort.InferenceSession('pangu_weather_6.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
|
||||
|
||||
# Load the upper-air numpy arrays
|
||||
input = np.load(os.path.join(input_data_dir, 'input_upper.npy')).astype(np.float32)
|
||||
|
||||
+599
@@ -0,0 +1,599 @@
|
||||
'''
|
||||
Pseudocode of Pangu-Weather
|
||||
'''
|
||||
# The pseudocode can be implemented using deep learning libraries, e.g., Pytorch and Tensorflow or other high-level APIs
|
||||
|
||||
# Basic operations used in our model, namely Linear, Conv3d, Conv2d, ConvTranspose3d and ConvTranspose2d
|
||||
# Linear: Linear transformation, available in all deep learning libraries
|
||||
# Conv3d and Con2d: Convolution with 2 or 3 dimensions, available in all deep learning libraries
|
||||
# ConvTranspose3d, ConvTranspose2d: transposed convolution with 2 or 3 dimensions, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d
|
||||
|
||||
# Functions in the networks, namely GeLU, DropOut, DropPath, LayerNorm, and SoftMax
|
||||
# GeLU: the GeLU activation function, see Pytorch API or Tensorflow API
|
||||
# DropOut: the dropout function, available in all deep learning libraries
|
||||
# DropPath: the DropPath function, see the implementation of vision-transformer, see timm pakage of Pytorch
|
||||
# A possible implementation of DropPath: from timm.models.layers import DropPath
|
||||
# LayerNorm: the layer normalization function, see Pytorch API or Tensorflow API
|
||||
# Softmax: softmax function, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import GeLU, DropOut, DropPath, LayerNorm, SoftMax
|
||||
|
||||
# Common functions for roll, pad, and crop, depends on the data structure of your software environment
|
||||
from Your_AI_Library import roll3D, pad3D, pad2D, Crop3D, Crop2D
|
||||
|
||||
# Common functions for reshaping and changing the order of dimensions
|
||||
# reshape: change the shape of the data with the order unchanged, see Pytorch API or Tensorflow API
|
||||
# TransposeDimensions: change the order of the dimensions, see Pytorch API or Tensorflow API
|
||||
from Your_AI_Library import reshape, TransposeDimensions
|
||||
|
||||
# Common functions for creating new tensors
|
||||
# ConstructTensor: create a new tensor with an arbitrary shape
|
||||
# TruncatedNormalInit: Initialize the tensor with Truncate Normalization distribution
|
||||
# RangeTensor: create a new tensor like range(a, b)
|
||||
from Your_AI_Library import ConstructTensor, TruncatedNormalInit, RangeTensor
|
||||
|
||||
# Common operations for the data, you may design it or simply use deep learning APIs default operations
|
||||
# LinearSpace: a tensor version of numpy.linspace
|
||||
# MeshGrid: a tensor version of numpy.meshgrid
|
||||
# Stack: a tensor version of numpy.stack
|
||||
# Flatten: a tensor version of numpy.ndarray.flatten
|
||||
# TensorSum: a tensor version of numpy.sum
|
||||
# TensorAbs: a tensor version of numpy.abs
|
||||
# Concatenate: a tensor version of numpy.concatenate
|
||||
from Your_AI_Library import LinearSpace, MeshGrid, Stack, Flatten, TensorSum, TensorAbs, Concatenate
|
||||
|
||||
# Common functions for training models
|
||||
# LoadModel and SaveModel: Load and save the model, some APIs may require further adaptation to hardwares
|
||||
# Backward: Gradient backward to calculate the gratitude of each parameters
|
||||
# UpdateModelParametersWithAdam: Use Adam to update parameters, e.g., torch.optim.Adam
|
||||
from Your_AI_Library import LoadModel, Backward, UpdateModelParametersWithAdam, SaveModel
|
||||
|
||||
# Custom functions to read your data from the disc
|
||||
# LoadData: Load the ERA5 data
|
||||
# LoadConstantMask: Load constant masks, e.g., soil type
|
||||
# LoadStatic: Load mean and std of the ERA5 training data, every fields such as T850 is treated as an image and calculate the mean and std
|
||||
from Your_Data_Code import LoadData, LoadConstantMask, LoadStatic
|
||||
|
||||
|
||||
def Inference(input, input_surface, forecast_range):
|
||||
'''Inference code, describing the algorithm of inference using models with different lead times.
|
||||
PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in lead times.
|
||||
Args:
|
||||
input: input tensor, need to be normalized to N(0, 1) in practice
|
||||
input_surface: target tensor, need to be normalized to N(0, 1) in practice
|
||||
forecast_range: iteration numbers when roll out the forecast model
|
||||
'''
|
||||
|
||||
# Load 4 pre-trained models with different lead times
|
||||
PanguModel24 = LoadModel(ModelPath24)
|
||||
PanguModel6 = LoadModel(ModelPath6)
|
||||
PanguModel3 = LoadModel(ModelPath3)
|
||||
PanguModel1 = LoadModel(ModelPath1)
|
||||
|
||||
# Load mean and std of the weather data
|
||||
weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic()
|
||||
|
||||
# Store initial input for different models
|
||||
input_24, input_surface_24 = input, input_surface
|
||||
input_6, input_surface_6 = input, input_surface
|
||||
input_3, input_surface_3 = input, input_surface
|
||||
|
||||
# Using a list to store output
|
||||
output_list = []
|
||||
|
||||
# Note: the following code is implemented for fast inference of [1,forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster.
|
||||
for i in range(forecast_range):
|
||||
# switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours
|
||||
if (i+1) % 24 == 0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_24, input_surface_24
|
||||
|
||||
# Call the model pretrained for 24 hours forecast
|
||||
output, output_surface = PanguModel24(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_24, input_surface_24 = output, output_surface
|
||||
input_6, input_surface_6 = output, output_surface
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 6-hour model if the forecast time is 30 hours, 36 hours, ..., 24*N + 6/12/18 hours
|
||||
elif (i+1) % 6 == 0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_6, input_surface_6
|
||||
|
||||
# Call the model pretrained for 6 hours forecast
|
||||
output, output_surface = PanguModel6(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_6, input_surface_6 = output, output_surface
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours
|
||||
elif (i+1) % 3 ==0:
|
||||
# Switch the input back to the stored input
|
||||
input, input_surface = input_3, input_surface_3
|
||||
|
||||
# Call the model pretrained for 3 hours forecast
|
||||
output, output_surface = PanguModel3(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input_3, input_surface_3 = output, output_surface
|
||||
|
||||
# switch to the 1-hour model
|
||||
else:
|
||||
# Call the model pretrained for 1 hours forecast
|
||||
output, output_surface = PanguModel1(input, input_surface)
|
||||
|
||||
# Restore from uniformed output
|
||||
output = output * weather_std + weather_mean
|
||||
output_surface = output_surface * weather_surface_std + weather_surface_mean
|
||||
|
||||
# Stored the output for next round forecast
|
||||
input, input_surface = output, output_surface
|
||||
|
||||
# Save the output
|
||||
output_list.append((output, output_surface))
|
||||
return output_list
|
||||
|
||||
|
||||
def Train():
|
||||
'''Training code'''
|
||||
# Initialize the model, for some APIs some adaptation is needed to fit hardwares
|
||||
model = PanguModel()
|
||||
|
||||
# Train single Pangu-Weather model
|
||||
epochs = 100
|
||||
for i in range(epochs):
|
||||
# For each epoch, we iterate from 1979 to 2017
|
||||
# dataset_length is the length of your training data, e.g., the sample between 1979 and 2017
|
||||
for step in range(dataset_length):
|
||||
# Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the output
|
||||
# Note the data need to be randomly shuffled
|
||||
# Note the input and target need to be normalized, see Inference() for details
|
||||
input, input_surface, target, target_surface = LoadData(step)
|
||||
|
||||
# Call the model and get the output
|
||||
output, output_surface = model(input, input_surface)
|
||||
|
||||
# We use the MAE loss to train the model
|
||||
# The weight of surface loss is 0.25
|
||||
# Different weight can be applied for differen fields if needed
|
||||
loss = TensorAbs(output-target) + TensorAbs(output_surface-target_surface) * 0.25
|
||||
|
||||
# Call the backward algorithm and calculate the gratitude of parameters
|
||||
Backward(loss)
|
||||
|
||||
# Update model parameters with Adam optimizer
|
||||
# The learning rate is 5e-4 as in the paper, while the weight decay is 3e-6
|
||||
# A example solution is using torch.optim.adam
|
||||
UpdateModelParametersWithAdam()
|
||||
|
||||
# Save the model at the end of the training stage
|
||||
SaveModel()
|
||||
|
||||
class PanguModel:
|
||||
def __init__(self):
|
||||
# Drop path rate is linearly increased as the depth increases
|
||||
drop_path_list = LinearSpace(0, 0.2, 8)
|
||||
|
||||
# Patch embedding
|
||||
self._input_layer = PatchEmbedding((2, 4, 4), 192)
|
||||
|
||||
# Four basic layers
|
||||
self.layer1 = EarthSpecificLayer(2, 192, drop_list[:2], 6)
|
||||
self.layer2 = EarthSpecificLayer(6, 384, drop_list[6:], 12)
|
||||
self.layer3 = EarthSpecificLayer(6, 384, drop_list[6:], 12)
|
||||
self.layer4 = EarthSpecificLayer(2, 192, drop_list[:2], 6)
|
||||
|
||||
# Upsample and downsample
|
||||
self.upsample = UpSample(384, 192)
|
||||
self.downsample = DownSample(192)
|
||||
|
||||
# Patch Recovery
|
||||
self._output_layer = PatchRecovery(384)
|
||||
|
||||
def forward(self, input, input_surface):
|
||||
'''Backbone architecture'''
|
||||
# Embed the input fields into patches
|
||||
x = self._input_layer(input, input_surface)
|
||||
|
||||
# Encoder, composed of two layers
|
||||
# Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper
|
||||
x = self.layer1(x, 8, 360, 181)
|
||||
|
||||
# Store the tensor for skip-connection
|
||||
skip = x
|
||||
|
||||
# Downsample from (8, 360, 181) to (8, 180, 91)
|
||||
x = self.downsample(x, 8, 360, 181)
|
||||
|
||||
# Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper
|
||||
x = self.layer2(x, 8, 180, 91)
|
||||
|
||||
# Decoder, composed of two layers
|
||||
# Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper
|
||||
x = self.layer3(x, 8, 180, 91)
|
||||
|
||||
# Upsample from (8, 180, 91) to (8, 360, 181)
|
||||
x = self.upsample(x)
|
||||
|
||||
# Layer 4, shape (8, 360, 181, 2C), C = 192 as in the original paper
|
||||
x = self.layer4(x, 8, 360, 181)
|
||||
|
||||
# Skip connect, in last dimension(C from 192 to 384)
|
||||
x = Concatenate(skip, x)
|
||||
|
||||
# Recover the output fields from patches
|
||||
output, output_surface = self._output_layer(x)
|
||||
return output, output_surface
|
||||
|
||||
class PatchEmbedding:
|
||||
def __init__(self, patch_size, dim):
|
||||
'''Patch embedding operation'''
|
||||
# Here we use convolution to partition data into cubes
|
||||
self.conv = Conv3d(input_dims=5, output_dims=dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.conv_surface = Conv2d(input_dims=7, output_dims=dim, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
|
||||
# Load constant masks from the disc
|
||||
self.land_mask, self.soil_type, self.topography = LoadConstantMask()
|
||||
|
||||
def forward(self, input, input_surface):
|
||||
# Zero-pad the input
|
||||
input = Pad3D(input)
|
||||
input_surface = Pad2D(input_surface)
|
||||
|
||||
# Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper
|
||||
input = self.conv(input)
|
||||
|
||||
# Add three constant fields to the surface fields
|
||||
input_surface = Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)
|
||||
|
||||
# Apply a linear projection for patch_size[1]*patch_size[2] patches
|
||||
input_surface = self.conv_surface(input_surface)
|
||||
|
||||
# Concatenate the input in the pressure level, i.e., in Z dimension
|
||||
x = Concatenate(input, input_surface)
|
||||
|
||||
# Reshape x for calculation of linear projections
|
||||
x = TransposeDimensions(x, (0, 2, 3, 4, 1))
|
||||
x = reshape(x, target_shape=(x.shape[0], 8*360*181, x.shape[-1]))
|
||||
return x
|
||||
|
||||
class PatchRecovery:
|
||||
def __init__(self, dim):
|
||||
'''Patch recovery operation'''
|
||||
# Hear we use two transposed convolutions to recover data
|
||||
self.conv = ConvTranspose3d(input_dims=dim, output_dims=5, kernel_size=patch_size, stride=patch_size)
|
||||
self.conv_surface = ConvTranspose2d(input_dims=dim, output_dims=4, kernel_size=patch_size[1:], stride=patch_size[1:])
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
# The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper
|
||||
# Reshape x back to three dimensions
|
||||
x = TransposeDimensions(x, (0, 2, 1))
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))
|
||||
|
||||
# Call the transposed convolution
|
||||
output = self.conv(x[:, :, 1:, :, :])
|
||||
output_surface = self.conv_surface(x[:, :, 0, :, :])
|
||||
|
||||
# Crop the output to remove zero-paddings
|
||||
output = Crop3D(output)
|
||||
output_surface = Crop2D(output_surface)
|
||||
return output, output_surface
|
||||
|
||||
class DownSample:
|
||||
def __init__(self, dim):
|
||||
'''Down-sampling operation'''
|
||||
# A linear function and a layer normalization
|
||||
self.linear = Linear(4*dim, 2*dim, bias=Fasle)
|
||||
self.norm = LayerNorm(4*dim)
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
# Reshape x to three dimensions for downsampling
|
||||
x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))
|
||||
|
||||
# Padding the input to facilitate downsampling
|
||||
x = Pad3D(x)
|
||||
|
||||
# Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
|
||||
Z, H, W = x.shape
|
||||
# Reshape x to facilitate downsampling
|
||||
x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
|
||||
# Change the order of x
|
||||
x = TransposeDimensions(x, (0,1,2,4,3,5,6))
|
||||
# Reshape to get a tensor of resolution (8, 180, 91)
|
||||
x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))
|
||||
|
||||
# Call the layer normalization
|
||||
x = self.norm(x)
|
||||
|
||||
# Decrease the channels of the data to reduce computation cost
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class UpSample:
|
||||
def __init__(self, input_dim, output_dim):
|
||||
'''Up-sampling operation'''
|
||||
# Linear layers without bias to increase channels of the data
|
||||
self.linear1 = Linear(input_dim, output_dim*4, bias=False)
|
||||
|
||||
# Linear layers without bias to mix the data up
|
||||
self.linear2 = Linear(output_dim, output_dim, bias=False)
|
||||
|
||||
# Normalization
|
||||
self.norm = LayerNorm(output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
# Call the linear functions to increase channels of the data
|
||||
x = self.linear1(x)
|
||||
|
||||
# Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182)
|
||||
# Reshape x to facilitate upsampling.
|
||||
x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4))
|
||||
# Change the order of x
|
||||
x = TransposeDimensions(x, (0,1,2,4,3,5,6))
|
||||
# Reshape to get Tensor with a resolution of (8, 360, 182)
|
||||
x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1]))
|
||||
|
||||
# Crop the output to the input shape of the network
|
||||
x = Crop3D(x)
|
||||
|
||||
# Reshape x back
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
|
||||
|
||||
# Call the layer normalization
|
||||
x = self.norm(x)
|
||||
|
||||
# Mixup normalized tensors
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
class EarthSpecificLayer:
|
||||
def __init__(self, depth, dim, drop_path_ratio_list, heads):
|
||||
'''Basic layer of our network, contains 2 or 6 blocks'''
|
||||
self.depth = depth
|
||||
self.blocks = []
|
||||
|
||||
# Construct basic blocks
|
||||
for i in range(depth):
|
||||
self.blocks.append(EarthSpecificBlock(dim, drop_path_ratio_list[i], heads))
|
||||
|
||||
def forward(self, x, Z, H, W):
|
||||
for i in range(self.depth):
|
||||
# Roll the input every two blocks
|
||||
if i % 2 == 0:
|
||||
self.blocks[i](x, Z, H, W, roll=False)
|
||||
else:
|
||||
self.blocks[i](x, Z, H, W, roll=True)
|
||||
return x
|
||||
|
||||
class EarthSpecificBlock:
|
||||
def __init__(self, dim, drop_path_ratio, heads):
|
||||
'''
|
||||
3D transformer block with Earth-Specific bias and window attention,
|
||||
see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
|
||||
The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.
|
||||
'''
|
||||
# Define the window size of the neural network
|
||||
self.window_size = (2, 6, 12)
|
||||
|
||||
# Initialize serveral operations
|
||||
self.drop_path = DropPath(drop_rate=drop_path_ratio)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.linear = MLP(dim, 0)
|
||||
self.attention = EarthAttention3D(dim, heads, 0, self.window_size)
|
||||
|
||||
def forward(self, x, Z, H, W, roll):
|
||||
# Save the shortcut for skip-connection
|
||||
shortcut = x
|
||||
|
||||
# Reshape input to three dimensions to calculate window attention
|
||||
reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2]))
|
||||
|
||||
# Zero-pad input if needed
|
||||
x = pad3D(x)
|
||||
|
||||
# Store the shape of the input for restoration
|
||||
ori_shape = x.shape
|
||||
|
||||
if roll:
|
||||
# Roll x for half of the window for 3 dimensions
|
||||
x = roll3D(x, shift=[self.window_size[0]//2, self.window_size[1]//2, self.window_size[2]//2])
|
||||
# Generate mask of attention masks
|
||||
# If two pixels are not adjacent, then mask the attention between them
|
||||
# Your can set the matrix element to -1000 when it is not adjacent, then add it to the attention
|
||||
mask = gen_mask(x)
|
||||
else:
|
||||
# e.g., zero matrix when you add mask to attention
|
||||
mask = no_mask
|
||||
|
||||
# Reorganize data to calculate window attention
|
||||
x_window = reshape(x, target_shape=(x.shape[0], Z//window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], x.shape[-1]))
|
||||
x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7))
|
||||
|
||||
# Get data stacked in 3D cubes, which will further be used to calculated attention among each cube
|
||||
x_window = reshape(x_window, target_shape=(-1, window_size[0]* window_size[1]*window_size[2], x.shape[-1]))
|
||||
|
||||
# Apply 3D window attention with Earth-Specific bias
|
||||
x_window = self.attention(x, mask)
|
||||
|
||||
# Reorganize data to original shapes
|
||||
x = reshape(x_window, target_shape=((-1, Z // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], x_window.shape[-1])))
|
||||
x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7))
|
||||
|
||||
# Reshape the tensor back to its original shape
|
||||
x = reshape(x_window, target_shape=ori_shape)
|
||||
|
||||
if roll:
|
||||
# Roll x back for half of the window
|
||||
x = roll3D(x, shift=[-self.window_size[0]//2, -self.window_size[1]//2, -self.window_size[2]//2])
|
||||
|
||||
# Crop the zero-padding
|
||||
x = Crop3D(x)
|
||||
|
||||
# Reshape the tensor back to the input shape
|
||||
x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4]))
|
||||
|
||||
# Main calculation stages
|
||||
x = shortcut + self.drop_path(self.norm1(x))
|
||||
x = x + self.drop_path(self.norm2(self.linear(x)))
|
||||
return x
|
||||
|
||||
class EarthAttention3D:
|
||||
def __init__(self, dim, heads, dropout_rate, window_size):
|
||||
'''
|
||||
3D window attention with the Earth-Specific bias,
|
||||
see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
|
||||
'''
|
||||
# Initialize several operations
|
||||
self.linear1 = Linear(dim, dim=3, bias=True)
|
||||
self.linear2 = Linear(dim, dim)
|
||||
self.softmax = SoftMax(dim=-1)
|
||||
self.dropout = DropOut(dropout_rate)
|
||||
|
||||
# Store several attributes
|
||||
self.head_number = heads
|
||||
self.dim = dim
|
||||
self.scale = (dim//heads)**-0.5
|
||||
self.window_size = window_size
|
||||
|
||||
# input_shape is current shape of the self.forward function
|
||||
# You can run your code to record it, modify the code and rerun it
|
||||
# Record the number of different window types
|
||||
self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1])
|
||||
|
||||
# For each type of window, we will construct a set of parameters according to the paper
|
||||
self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads))
|
||||
|
||||
# Making these tensors to be learnable parameters
|
||||
self.earth_specific_bias = Parameters(self.earth_specific_bias)
|
||||
|
||||
# Initialize the tensors using Truncated normal distribution
|
||||
TruncatedNormalInit(self.earth_specific_bias, std=0.02)
|
||||
|
||||
# Construct position index to reuse self.earth_specific_bias
|
||||
self.position_index = self._construct_index()
|
||||
|
||||
def _construct_index(self):
|
||||
''' This function construct the position index to reuse symmetrical parameters of the position bias'''
|
||||
# Index in the pressure level of query matrix
|
||||
coords_zi = RangeTensor(self.window_size[0])
|
||||
# Index in the pressure level of key matrix
|
||||
coords_zj = -RangeTensor(self.window_size[0])*self.window_size[0]
|
||||
|
||||
# Index in the latitude of query matrix
|
||||
coords_hi = RangeTensor(self.window_size[1])
|
||||
# Index in the latitude of key matrix
|
||||
coords_hj = -RangeTensor(self.window_size[1])*self.window_size[1]
|
||||
|
||||
# Index in the longitude of the key-value pair
|
||||
coords_w = RangeTensor(self.window_size[2])
|
||||
|
||||
# Change the order of the index to calculate the index in total
|
||||
coords_1 = Stack(MeshGrid([coords_zi, coords_hi, coords_w]))
|
||||
coords_2 = Stack(MeshGrid([coords_zj, coords_hj, coords_w]))
|
||||
coords_flatten_1 = Flatten(coords_1, start_dimension=1)
|
||||
coords_flatten_2 = Flatten(coords_2, start_dimension=1)
|
||||
coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
|
||||
coords = TransposeDimensions(coords, (1, 2, 0))
|
||||
|
||||
# Shift the index for each dimension to start from 0
|
||||
coords[:, :, 2] += self.window_size[2] - 1
|
||||
coords[:, :, 1] *= 2 * self.window_size[2] - 1
|
||||
coords[:, :, 0] *= (2 * self.window_size[2] - 1)*self.window_size[1]*self.window_size[1]
|
||||
|
||||
# Sum up the indexes in three dimensions
|
||||
self.position_index = TensorSum(coords, dim=-1)
|
||||
|
||||
# Flatten the position index to facilitate further indexing
|
||||
self.position_index = Flatten(self.position_index)
|
||||
|
||||
def forward(self, x, mask):
|
||||
# Linear layer to create query, key and value
|
||||
x = self.linear1(x)
|
||||
|
||||
# Record the original shape of the input
|
||||
original_shape = x.shape
|
||||
|
||||
# reshape the data to calculate multi-head attention
|
||||
qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number))
|
||||
query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4))
|
||||
|
||||
# Scale the attention
|
||||
query = query * self.scale
|
||||
|
||||
# Calculated the attention, a learnable bias is added to fix the nonuniformity of the grid.
|
||||
attention = query @ key.T # @ denotes matrix multiplication
|
||||
|
||||
# self.earth_specific_bias is a set of neural network parameters to optimize.
|
||||
EarthSpecificBias = self.earth_specific_bias[self.position_index]
|
||||
|
||||
# Reshape the learnable bias to the same shape as the attention matrix
|
||||
EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number))
|
||||
EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1))
|
||||
EarthSpecificBias = reshape(EarthSpecificBias, target_shape = [1]+EarthSpecificBias.shape)
|
||||
|
||||
# Add the Earth-Specific bias to the attention matrix
|
||||
attention = attention + EarthSpecificBias
|
||||
|
||||
# Mask the attention between non-adjacent pixels, e.g., simply add -100 to the masked element.
|
||||
attention = self.mask_attention(attention, mask)
|
||||
attention = self.softmax(attention)
|
||||
attention = self.dropout(attention)
|
||||
|
||||
# Calculated the tensor after spatial mixing.
|
||||
x = attention @ value.T # @ denote matrix multiplication
|
||||
|
||||
# Reshape tensor to the original shape
|
||||
x = TransposeDimensions(x, (0, 2, 1))
|
||||
x = reshape(x, target_shape = original_shape)
|
||||
|
||||
# Linear layer to post-process operated tensor
|
||||
x = self.linear2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
class Mlp:
|
||||
def __init__(self, dim, dropout_rate):
|
||||
'''MLP layers, same as most vision transformer architectures.'''
|
||||
self.linear1 = Linear(dim, dim * 4)
|
||||
self.linear2 = Linear(dim * 4, dim)
|
||||
self.activation = GeLU()
|
||||
self.drop = DropOut(drop_rate=dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
x = self.drop(x)
|
||||
x = self.linear(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def PerlinNoise():
|
||||
'''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the perlin noise.'''
|
||||
# Define number of noise
|
||||
octaves = 3
|
||||
# Define the scaling factor of noise
|
||||
noise_scale = 0.2
|
||||
# Define the number of periods of noise along the axis
|
||||
period_number = 12
|
||||
# The size of an input slice
|
||||
H, W = 721, 1440
|
||||
# Scaling factor between two octaves
|
||||
persistence = 0.5
|
||||
# see https://github.com/pvigier/perlin-numpy/ for the implementation of GenerateFractalNoise (e.g., from perlin_numpy import generate_fractal_noise_3d)
|
||||
perlin_noise = noise_scale*GenerateFractalNoise((H, W), (period_number, period_number), octaves, persistence)
|
||||
return perlin_noise
|
||||
Reference in New Issue
Block a user