[add]上传训练benchmark by z00560161
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
# ResNet101_tensorflow训练说明
|
||||
|
||||
### 1. 模型训练参数配置
|
||||
|
||||
在train/yaml/ResNet101.yaml中修改相应配置, 配置项含义:
|
||||
|
||||
```
|
||||
tensorflow_config:
|
||||
# 基本参数
|
||||
data_url: /home/imagenet_TF/
|
||||
# 1p/8p,epoches设为150
|
||||
epoches: 1
|
||||
epochs_between_evals: 1
|
||||
max_train_steps: 1000
|
||||
batch_size: 128
|
||||
|
||||
# 仅多机执行需要配置: ip1:卡数量1,ip2:卡数量2
|
||||
mpirun_ip: 90.90.176.152:8,90.90.176.154:8
|
||||
|
||||
# docker 镜像名称:版本号
|
||||
docker_image: c73:b02
|
||||
|
||||
# 指定 device id, 多个 id 使用空格分隔, 数量需与 rank_size 相同
|
||||
device_group_1p: 0
|
||||
device_group_2p: 0 1
|
||||
device_group_4p: 0 1 2 3
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+203
@@ -0,0 +1,203 @@
|
||||
Copyright 2015 The TensorFlow Authors. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2015, The TensorFlow Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
+151
@@ -0,0 +1,151 @@
|
||||

|
||||
|
||||
# TensorFlow Official Models
|
||||
|
||||
The TensorFlow official models are a collection of models
|
||||
that use TensorFlow’s high-level APIs.
|
||||
They are intended to be well-maintained, tested, and kept up to date
|
||||
with the latest TensorFlow API.
|
||||
They should also be reasonably optimized for fast performance while still
|
||||
being easy to read.
|
||||
These models are used as end-to-end tests, ensuring that the models run
|
||||
with the same or improved speed and performance with each new TensorFlow build.
|
||||
|
||||
## Model Implementations
|
||||
|
||||
### Natural Language Processing
|
||||
|
||||
| Model | Description | Reference |
|
||||
| ----- | ----------- | --------- |
|
||||
| [ALBERT](nlp/albert) | A Lite BERT for Self-supervised Learning of Language Representations | [arXiv:1909.11942](https://arxiv.org/abs/1909.11942) |
|
||||
| [BERT](nlp/bert) | A powerful pre-trained language representation model: BERT (Bidirectional Encoder Representations from Transformers) | [arXiv:1810.04805](https://arxiv.org/abs/1810.04805) |
|
||||
| [NHNet](nlp/nhnet) | A transformer-based multi-sequence to sequence model: Generating Representative Headlines for News Stories | [arXiv:2001.09386](https://arxiv.org/abs/2001.09386) |
|
||||
| [Transformer](nlp/transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
|
||||
| [XLNet](nlp/xlnet) | XLNet: Generalized Autoregressive Pretraining for Language Understanding | [arXiv:1906.08237](https://arxiv.org/abs/1906.08237) |
|
||||
|
||||
### Computer Vision
|
||||
|
||||
| Model | Description | Reference |
|
||||
| ----- | ----------- | --------- |
|
||||
| [MNIST](vision/image_classification) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
|
||||
| [ResNet](vision/image_classification) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
|
||||
| [RetinaNet](vision/detection) | A fast and powerful object detector | [arXiv:1708.02002](https://arxiv.org/abs/1708.02002) |
|
||||
| [Mask R-CNN](vision/detection) | An object detection and instance segmentation model | [arXiv:1703.06870](https://arxiv.org/abs/1703.06870) |
|
||||
|
||||
### Other models
|
||||
|
||||
| Model | Description | Reference |
|
||||
| ----- | ----------- | --------- |
|
||||
| [NCF](recommendation) | Neural Collaborative Filtering model for recommendation tasks | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
|
||||
|
||||
---
|
||||
|
||||
## How to get started with the Model Garden official models
|
||||
|
||||
* The models in the master branch are developed using TensorFlow 2,
|
||||
and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
|
||||
built from the
|
||||
[master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
|
||||
* The stable versions targeting releases of TensorFlow are available
|
||||
as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases).
|
||||
* Model repository version numbers match the target TensorFlow release,
|
||||
such that
|
||||
[release v2.1.0](https://github.com/tensorflow/models/releases/tag/v2.1.0)
|
||||
are compatible with
|
||||
[TensorFlow v2.1.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0).
|
||||
|
||||
Please follow the below steps before running models in this repository.
|
||||
|
||||
### Requirements
|
||||
|
||||
* The latest TensorFlow Model Garden release and TensorFlow 2
|
||||
* If you are on a version of TensorFlow earlier than 2.1, please
|
||||
upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
|
||||
|
||||
```shell
|
||||
pip3 install tf-nightly
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
#### Method 1: Install the TensorFlow Model Garden pip package
|
||||
|
||||
**tf-models-nightly** is the nightly Model Garden package
|
||||
created daily automatically. pip will install all models
|
||||
and dependencies automatically.
|
||||
|
||||
```shell
|
||||
pip install tf-models-nightly
|
||||
```
|
||||
|
||||
Please check out our [example](colab/bert.ipynb)
|
||||
to learn how to use a PIP package.
|
||||
|
||||
#### Method 2: Clone the source
|
||||
|
||||
1. Clone the GitHub repository:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/tensorflow/models.git
|
||||
```
|
||||
|
||||
2. Add the top-level ***/models*** folder to the Python path.
|
||||
|
||||
```shell
|
||||
export PYTHONPATH=$PYTHONPATH:/path/to/models
|
||||
```
|
||||
|
||||
If you are using a Colab notebook, please set the Python path with os.environ.
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ['PYTHONPATH'] += ":/path/to/models"
|
||||
```
|
||||
|
||||
3. Install other dependencies
|
||||
|
||||
```shell
|
||||
pip3 install --user -r official/requirements.txt
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## More models to come!
|
||||
|
||||
The team is actively developing new models.
|
||||
In the near future, we will add:
|
||||
|
||||
- State-of-the-art language understanding models:
|
||||
More members in Transformer family
|
||||
- Start-of-the-art image classification models:
|
||||
EfficientNet, MnasNet and variants.
|
||||
- A set of excellent objection detection models.
|
||||
|
||||
If you would like to make any fixes or improvements to the models, please
|
||||
[submit a pull request](https://github.com/tensorflow/models/compare).
|
||||
|
||||
---
|
||||
|
||||
## Contributions
|
||||
|
||||
Every model should follow our guidelines to uphold our objectives of readable,
|
||||
usable, and maintainable code.
|
||||
|
||||
### General Guidelines
|
||||
|
||||
- Code should be well documented and tested.
|
||||
- Runnable from a blank environment with ease.
|
||||
- Trainable on: single GPU/CPU (baseline), multiple GPUs & TPUs
|
||||
- Compatible with Python 3 (using [six](https://pythonhosted.org/six/)
|
||||
when being compatible with Python 2 is necessary)
|
||||
- Conform to
|
||||
[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)
|
||||
|
||||
### Implementation Guidelines
|
||||
|
||||
These guidelines are to ensure consistent model implementations for
|
||||
better readability and maintainability.
|
||||
|
||||
- Use [common utility functions](utils)
|
||||
- Export SavedModel at the end of the training.
|
||||
- Consistent flags and flag-parsing library ([read more here](utils/flags/guidelines.md))
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
# Offically Supported TensorFlow 2.1+ Models on Cloud TPU
|
||||
|
||||
## Natural Language Processing
|
||||
|
||||
* [bert](nlp/bert): A powerful pre-trained language representation model:
|
||||
BERT, which stands for Bidirectional Encoder Representations from
|
||||
Transformers.
|
||||
[BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
|
||||
* [transformer](nlp/transformer): A transformer model to translate the WMT
|
||||
English to German dataset.
|
||||
[Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
|
||||
|
||||
## Computer Vision
|
||||
|
||||
* [efficientnet](vision/image_classification): A family of convolutional
|
||||
neural networks that scale by balancing network depth, width, and
|
||||
resolution and can be used to classify ImageNet's dataset of 1000 classes.
|
||||
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
|
||||
* [mnist](vision/image_classification): A basic model to classify digits
|
||||
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
|
||||
* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
|
||||
* [resnet](vision/image_classification): A deep residual network that can
|
||||
be used to classify ImageNet's dataset of 1000 classes.
|
||||
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
|
||||
* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
|
||||
+79
@@ -0,0 +1,79 @@
|
||||
# ResNet in TensorFlow On NPU
|
||||
---
|
||||
|
||||
# Classification Model
|
||||
## Overview
|
||||
1. This is an implementation of the ResNet101 model as described in the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) paper.
|
||||
2. Current implementation is based on the code from the TensorFlow Official implementation in the [tensorflow/models Repo](https://github.com/tensorflow/models).
|
||||
|
||||
## Introduction
|
||||
1. ResNet is a relatively good network for classification problems in the ImageNet competition. It introduces the concept of residual learning. It protects the integrity of information by adding direct connection channels, solves problems such as information loss, gradient disappearance, and gradient explosion. The network can also be trained. ResNet has different network layers, commonly used are 18-layer, 34-layer, 50-layer, 101-layer, 152-layer.
|
||||
2. Ascend provides the V1.5 version of the 50-layer ResNet network this time. The difference between the V1.5 version of the ResNet network and the V1 version is that in the bottleneck module, the V1 version is set stride=2 in the first 1x1 convolutional layer, and V1.5 sets stride=2 in the 3x3 convolutional layer.
|
||||
|
||||
## Dataset
|
||||
We have used the [ImageNet](http://www.image-net.org/)dataset as an example here, you can use mnist or your own dataset to modify and adapt.
|
||||
We use [build_imagenet_data](https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/research/slim/datasets/build_imagenet_data.py) to build record for training.
|
||||
|
||||
## Running Code
|
||||
### Config the env paramater
|
||||
check if path '/usr/local/HiAI' or ''/usr/local/Ascend' is existed or not.
|
||||
modify '/usr/local/HiAI' to the actual path in scripts/run.sh
|
||||
|
||||
### Train and evaluate model
|
||||
[imagenet_main.py](official/r1/resnet/imagenet_main.py) is the Entry Python script.
|
||||
[resnet_run_loop.py](official/r1/resnet/resnet_run_loop.py) is the Main Python script.
|
||||
|
||||
### Check your rank_table
|
||||
default rank_table setting in [configs](official/r1/resnet/configs) is usrd for X86.
|
||||
if you use aach64, please modify board_id from "0x0000" -->
|
||||
|
||||
To train and evaluate the model, issue the following command:
|
||||
```
|
||||
# for single training
|
||||
bash ./scripts/train_1p.sh
|
||||
# for multi training
|
||||
bash ./scripts/train_8p.sh
|
||||
```
|
||||
|
||||
Default Args:
|
||||
- Batch size: 128
|
||||
- Momentum: 0.9
|
||||
- LR scheduler: cosine
|
||||
- Learning rate(LR): 0.064
|
||||
- loss scale: 512
|
||||
- Weight decay: 0.0001
|
||||
- Label smoothing: 0.1
|
||||
- train epoch: 90
|
||||
|
||||
There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
|
||||
|
||||
### Train and evaluate result
|
||||
- 1 NPU
|
||||
- Train performance:109ms/step,1170images/sec.
|
||||
- 8 NPU
|
||||
- Train performance:109ms/step,9390images/sec.
|
||||
- best result
|
||||
- Accuracy(Top1): 79.03
|
||||
- Accuracy(Top5): 94.53
|
||||
|
||||
### More
|
||||
|
||||
#### modify file
|
||||
- The npu modify file list as follows:
|
||||
- DaVinci npu platform adaptation code,including
|
||||
1.official/r1/resnet/imagenet_main.py
|
||||
2.official/r1/resnet/resnet_model.py
|
||||
3.official/r1/resnet/resnet_run_loop.py
|
||||
4.official/utils/flags/_base.py
|
||||
|
||||
#### FileTree Intro
|
||||
- Main Dir
|
||||
- ./official/r1/resnet
|
||||
- Single NPU Training Shell
|
||||
- npu_train_1p_test.sh
|
||||
- Multi NPU(8p) Training Shell
|
||||
- npu_train_8p_test.sh
|
||||
- Log Info
|
||||
- STDOUT nohup.out
|
||||
- Performance perf.log
|
||||
|
||||
+157
@@ -0,0 +1,157 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Library to upload benchmark generated by BenchmarkLogger to remote repo.
|
||||
|
||||
This library require google cloud bigquery lib as dependency, which can be
|
||||
installed with:
|
||||
> pip install --upgrade google-cloud-bigquery
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import exceptions
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class BigQueryUploader(object):
|
||||
"""Upload the benchmark and metric info from JSON input to BigQuery. """
|
||||
|
||||
def __init__(self, gcp_project=None, credentials=None):
|
||||
"""Initialized BigQueryUploader with proper setting.
|
||||
|
||||
Args:
|
||||
gcp_project: string, the name of the GCP project that the log will be
|
||||
uploaded to. The default project name will be detected from local
|
||||
environment if no value is provided.
|
||||
credentials: google.auth.credentials. The credential to access the
|
||||
BigQuery service. The default service account credential will be
|
||||
detected from local environment if no value is provided. Please use
|
||||
google.oauth2.service_account.Credentials to load credential from local
|
||||
file for the case that the test is run out side of GCP.
|
||||
"""
|
||||
self._bq_client = bigquery.Client(
|
||||
project=gcp_project, credentials=credentials)
|
||||
|
||||
def upload_benchmark_run_json(
|
||||
self, dataset_name, table_name, run_id, run_json):
|
||||
"""Upload benchmark run information to Bigquery.
|
||||
|
||||
Args:
|
||||
dataset_name: string, the name of bigquery dataset where the data will be
|
||||
uploaded.
|
||||
table_name: string, the name of bigquery table under the dataset where
|
||||
the data will be uploaded.
|
||||
run_id: string, a unique ID that will be attached to the data, usually
|
||||
this is a UUID4 format.
|
||||
run_json: dict, the JSON data that contains the benchmark run info.
|
||||
"""
|
||||
run_json["model_id"] = run_id
|
||||
self._upload_json(dataset_name, table_name, [run_json])
|
||||
|
||||
def upload_benchmark_metric_json(
|
||||
self, dataset_name, table_name, run_id, metric_json_list):
|
||||
"""Upload metric information to Bigquery.
|
||||
|
||||
Args:
|
||||
dataset_name: string, the name of bigquery dataset where the data will be
|
||||
uploaded.
|
||||
table_name: string, the name of bigquery table under the dataset where
|
||||
the metric data will be uploaded. This is different from the
|
||||
benchmark_run table.
|
||||
run_id: string, a unique ID that will be attached to the data, usually
|
||||
this is a UUID4 format. This should be the same as the benchmark run_id.
|
||||
metric_json_list: list, a list of JSON object that record the metric info.
|
||||
"""
|
||||
for m in metric_json_list:
|
||||
m["run_id"] = run_id
|
||||
self._upload_json(dataset_name, table_name, metric_json_list)
|
||||
|
||||
def upload_benchmark_run_file(
|
||||
self, dataset_name, table_name, run_id, run_json_file):
|
||||
"""Upload benchmark run information to Bigquery from input json file.
|
||||
|
||||
Args:
|
||||
dataset_name: string, the name of bigquery dataset where the data will be
|
||||
uploaded.
|
||||
table_name: string, the name of bigquery table under the dataset where
|
||||
the data will be uploaded.
|
||||
run_id: string, a unique ID that will be attached to the data, usually
|
||||
this is a UUID4 format.
|
||||
run_json_file: string, the file path that contains the run JSON data.
|
||||
"""
|
||||
with tf.io.gfile.GFile(run_json_file) as f:
|
||||
benchmark_json = json.load(f)
|
||||
self.upload_benchmark_run_json(
|
||||
dataset_name, table_name, run_id, benchmark_json)
|
||||
|
||||
def upload_metric_file(
|
||||
self, dataset_name, table_name, run_id, metric_json_file):
|
||||
"""Upload metric information to Bigquery from input json file.
|
||||
|
||||
Args:
|
||||
dataset_name: string, the name of bigquery dataset where the data will be
|
||||
uploaded.
|
||||
table_name: string, the name of bigquery table under the dataset where
|
||||
the metric data will be uploaded. This is different from the
|
||||
benchmark_run table.
|
||||
run_id: string, a unique ID that will be attached to the data, usually
|
||||
this is a UUID4 format. This should be the same as the benchmark run_id.
|
||||
metric_json_file: string, the file path that contains the metric JSON
|
||||
data.
|
||||
"""
|
||||
with tf.io.gfile.GFile(metric_json_file) as f:
|
||||
metrics = []
|
||||
for line in f:
|
||||
metrics.append(json.loads(line.strip()))
|
||||
self.upload_benchmark_metric_json(
|
||||
dataset_name, table_name, run_id, metrics)
|
||||
|
||||
def _upload_json(self, dataset_name, table_name, json_list):
|
||||
# Find the unique table reference based on dataset and table name, so that
|
||||
# the data can be inserted to it.
|
||||
table_ref = self._bq_client.dataset(dataset_name).table(table_name)
|
||||
errors = self._bq_client.insert_rows_json(table_ref, json_list)
|
||||
if errors:
|
||||
tf.logging.error(
|
||||
"Failed to upload benchmark info to bigquery: {}".format(errors))
|
||||
|
||||
def insert_run_status(self, dataset_name, table_name, run_id, run_status):
|
||||
"""Insert the run status in to Bigquery run status table."""
|
||||
query = ("INSERT {ds}.{tb} "
|
||||
"(run_id, status) "
|
||||
"VALUES('{rid}', '{status}')").format(
|
||||
ds=dataset_name, tb=table_name, rid=run_id, status=run_status)
|
||||
try:
|
||||
self._bq_client.query(query=query).result()
|
||||
except exceptions.GoogleCloudError as e:
|
||||
tf.logging.error("Failed to insert run status: %s", e)
|
||||
|
||||
def update_run_status(self, dataset_name, table_name, run_id, run_status):
|
||||
"""Update the run status in in Bigquery run status table."""
|
||||
query = ("UPDATE {ds}.{tb} "
|
||||
"SET status = '{status}' "
|
||||
"WHERE run_id = '{rid}'").format(
|
||||
ds=dataset_name, tb=table_name, status=run_status, rid=run_id)
|
||||
try:
|
||||
self._bq_client.query(query=query).result()
|
||||
except exceptions.GoogleCloudError as e:
|
||||
tf.logging.error("Failed to update run status: %s", e)
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Binary to upload benchmark generated by BenchmarkLogger to remote repo.
|
||||
|
||||
This library require google cloud bigquery lib as dependency, which can be
|
||||
installed with:
|
||||
> pip install --upgrade google-cloud-bigquery
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from absl import app as absl_app
|
||||
from absl import flags
|
||||
|
||||
from official.benchmark import benchmark_uploader
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.utils.logs import logger
|
||||
|
||||
def main(_):
|
||||
if not flags.FLAGS.benchmark_log_dir:
|
||||
print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
|
||||
sys.exit(1)
|
||||
|
||||
uploader = benchmark_uploader.BigQueryUploader(
|
||||
gcp_project=flags.FLAGS.gcp_project)
|
||||
run_id = str(uuid.uuid4())
|
||||
run_json_file = os.path.join(
|
||||
flags.FLAGS.benchmark_log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
|
||||
metric_json_file = os.path.join(
|
||||
flags.FLAGS.benchmark_log_dir, logger.METRIC_LOG_FILE_NAME)
|
||||
|
||||
uploader.upload_benchmark_run_file(
|
||||
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id,
|
||||
run_json_file)
|
||||
uploader.upload_metric_file(
|
||||
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
|
||||
metric_json_file)
|
||||
# Assume the run finished successfully before user invoke the upload script.
|
||||
uploader.insert_run_status(
|
||||
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_status_table,
|
||||
run_id, logger.RUN_STATUS_SUCCESS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags_core.define_benchmark()
|
||||
flags.adopt_module_key_flags(flags_core)
|
||||
absl_app.run(main=main)
|
||||
+123
@@ -0,0 +1,123 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for benchmark_uploader."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from mock import MagicMock
|
||||
from mock import patch
|
||||
|
||||
import tensorflow as tf # pylint: disable=g-bad-import-order
|
||||
|
||||
try:
|
||||
from google.cloud import bigquery
|
||||
from official.benchmark import benchmark_uploader
|
||||
except ImportError:
|
||||
bigquery = None
|
||||
benchmark_uploader = None
|
||||
|
||||
|
||||
@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
|
||||
class BigQueryUploaderTest(tf.test.TestCase):
|
||||
|
||||
@patch.object(bigquery, "Client")
|
||||
def setUp(self, mock_bigquery):
|
||||
self.mock_client = mock_bigquery.return_value
|
||||
self.mock_dataset = MagicMock(name="dataset")
|
||||
self.mock_table = MagicMock(name="table")
|
||||
self.mock_client.dataset.return_value = self.mock_dataset
|
||||
self.mock_dataset.table.return_value = self.mock_table
|
||||
self.mock_client.insert_rows_json.return_value = []
|
||||
|
||||
self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
|
||||
self.benchmark_uploader._bq_client = self.mock_client
|
||||
|
||||
self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
|
||||
json.dump({"name": "accuracy", "value": 1.0}, f)
|
||||
f.write("\n")
|
||||
json.dump({"name": "loss", "value": 0.5}, f)
|
||||
f.write("\n")
|
||||
with open(os.path.join(self.log_dir, "run.log"), "w") as f:
|
||||
json.dump({"model_name": "value"}, f)
|
||||
|
||||
def tearDown(self):
|
||||
tf.io.gfile.rmtree(self.get_temp_dir())
|
||||
|
||||
def test_upload_benchmark_run_json(self):
|
||||
self.benchmark_uploader.upload_benchmark_run_json(
|
||||
"dataset", "table", "run_id", {"model_name": "value"})
|
||||
|
||||
self.mock_client.insert_rows_json.assert_called_once_with(
|
||||
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
|
||||
|
||||
def test_upload_benchmark_metric_json(self):
|
||||
metric_json_list = [
|
||||
{"name": "accuracy", "value": 1.0},
|
||||
{"name": "loss", "value": 0.5}
|
||||
]
|
||||
expected_params = [
|
||||
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
|
||||
{"run_id": "run_id", "name": "loss", "value": 0.5}
|
||||
]
|
||||
self.benchmark_uploader.upload_benchmark_metric_json(
|
||||
"dataset", "table", "run_id", metric_json_list)
|
||||
self.mock_client.insert_rows_json.assert_called_once_with(
|
||||
self.mock_table, expected_params)
|
||||
|
||||
def test_upload_benchmark_run_file(self):
|
||||
self.benchmark_uploader.upload_benchmark_run_file(
|
||||
"dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))
|
||||
|
||||
self.mock_client.insert_rows_json.assert_called_once_with(
|
||||
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
|
||||
|
||||
def test_upload_metric_file(self):
|
||||
self.benchmark_uploader.upload_metric_file(
|
||||
"dataset", "table", "run_id",
|
||||
os.path.join(self.log_dir, "metric.log"))
|
||||
expected_params = [
|
||||
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
|
||||
{"run_id": "run_id", "name": "loss", "value": 0.5}
|
||||
]
|
||||
self.mock_client.insert_rows_json.assert_called_once_with(
|
||||
self.mock_table, expected_params)
|
||||
|
||||
def test_insert_run_status(self):
|
||||
self.benchmark_uploader.insert_run_status(
|
||||
"dataset", "table", "run_id", "status")
|
||||
expected_query = ("INSERT dataset.table "
|
||||
"(run_id, status) "
|
||||
"VALUES('run_id', 'status')")
|
||||
self.mock_client.query.assert_called_once_with(query=expected_query)
|
||||
|
||||
def test_update_run_status(self):
|
||||
self.benchmark_uploader.update_run_status(
|
||||
"dataset", "table", "run_id", "status")
|
||||
expected_query = ("UPDATE dataset.table "
|
||||
"SET status = 'status' "
|
||||
"WHERE run_id = 'run_id'")
|
||||
self.mock_client.query.assert_called_once_with(query=expected_query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+97
@@ -0,0 +1,97 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils to annotate and trace benchmarks."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from absl.testing import flagsaver
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_multi_string(
|
||||
'benchmark_method_flags', None,
|
||||
'Optional list of runtime flags of the form key=value. Specify '
|
||||
'multiple times to specify different flags. These will override the FLAGS '
|
||||
'object directly after hardcoded settings in individual benchmark methods '
|
||||
'before they call _run_and_report benchmark. Example if we set '
|
||||
'--benchmark_method_flags=train_steps=10 and a benchmark method hardcodes '
|
||||
'FLAGS.train_steps=10000 and later calls _run_and_report_benchmark, '
|
||||
'it\'ll only run for 10 steps. This is useful for '
|
||||
'debugging/profiling workflows.')
|
||||
|
||||
|
||||
def enable_runtime_flags(decorated_func):
|
||||
"""Sets attributes from --benchmark_method_flags for method execution.
|
||||
|
||||
@enable_runtime_flags decorator temporarily adds flags passed in via
|
||||
--benchmark_method_flags and runs the decorated function in that context.
|
||||
|
||||
A user can set --benchmark_method_flags=train_steps=5 to run the benchmark
|
||||
method in the snippet below with FLAGS.train_steps=5 for debugging (without
|
||||
modifying the benchmark code).
|
||||
|
||||
class ModelBenchmark():
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self):
|
||||
# run benchmark ...
|
||||
# report benchmark results ...
|
||||
|
||||
def benchmark_method(self):
|
||||
FLAGS.train_steps = 1000
|
||||
...
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
Args:
|
||||
decorated_func: The method that runs the benchmark after previous setup
|
||||
execution that set some flags.
|
||||
|
||||
Returns:
|
||||
new_func: The same method which executes in a temporary context where flag
|
||||
overrides from --benchmark_method_flags are active.
|
||||
"""
|
||||
|
||||
def runner(*args, **kwargs):
|
||||
"""Creates a temporary context to activate --benchmark_method_flags."""
|
||||
if FLAGS.benchmark_method_flags:
|
||||
saved_flag_values = flagsaver.save_flag_values()
|
||||
for key_value in FLAGS.benchmark_method_flags:
|
||||
key, value = key_value.split('=', 1)
|
||||
try:
|
||||
numeric_float = float(value)
|
||||
numeric_int = int(numeric_float)
|
||||
if abs(numeric_int) == abs(numeric_float):
|
||||
flag_value = numeric_int
|
||||
else:
|
||||
flag_value = numeric_float
|
||||
except ValueError:
|
||||
flag_value = value
|
||||
logging.info('Setting --%s=%s', key, flag_value)
|
||||
setattr(FLAGS, key, flag_value)
|
||||
else:
|
||||
saved_flag_values = None
|
||||
try:
|
||||
result = decorated_func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
if saved_flag_values:
|
||||
flagsaver.restore_flag_values(saved_flag_values)
|
||||
|
||||
return runner
|
||||
+354
@@ -0,0 +1,354 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes BERT benchmarks and accuracy tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from absl import flags
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
||||
from official.nlp.bert import configs
|
||||
from official.nlp.bert import run_classifier
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.benchmark import benchmark_wrappers
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
|
||||
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
|
||||
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
|
||||
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
|
||||
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
TMP_DIR = os.getenv('TMPDIR')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
|
||||
"""Base class to hold methods common to test classes in the module."""
|
||||
|
||||
def __init__(self, output_dir=None, tpu=None):
|
||||
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
|
||||
self.num_epochs = None
|
||||
self.num_steps_per_epoch = None
|
||||
self.tpu = tpu
|
||||
FLAGS.steps_per_loop = 50
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def _run_bert_classifier(self, callbacks=None, use_ds=True):
|
||||
"""Starts BERT classification task."""
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
|
||||
if self.num_steps_per_epoch:
|
||||
steps_per_epoch = self.num_steps_per_epoch
|
||||
else:
|
||||
train_data_size = input_meta_data['train_data_size']
|
||||
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
|
||||
warmup_steps = int(epochs * steps_per_epoch * 0.1)
|
||||
eval_steps = int(
|
||||
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
|
||||
if self.tpu:
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy='tpu', tpu_address=self.tpu)
|
||||
else:
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy='mirrored' if use_ds else 'off',
|
||||
num_gpus=self.num_gpus)
|
||||
|
||||
max_seq_length = input_meta_data['max_seq_length']
|
||||
train_input_fn = run_classifier.get_dataset_fn(
|
||||
FLAGS.train_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.train_batch_size,
|
||||
is_training=True)
|
||||
eval_input_fn = run_classifier.get_dataset_fn(
|
||||
FLAGS.eval_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.eval_batch_size,
|
||||
is_training=False)
|
||||
run_classifier.run_bert_classifier(
|
||||
strategy,
|
||||
bert_config,
|
||||
input_meta_data,
|
||||
FLAGS.model_dir,
|
||||
epochs,
|
||||
steps_per_epoch,
|
||||
FLAGS.steps_per_loop,
|
||||
eval_steps,
|
||||
warmup_steps,
|
||||
FLAGS.learning_rate,
|
||||
FLAGS.init_checkpoint,
|
||||
train_input_fn,
|
||||
eval_input_fn,
|
||||
custom_callbacks=callbacks)
|
||||
|
||||
|
||||
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
|
||||
"""Short benchmark performance tests for BERT model.
|
||||
|
||||
Tests BERT classification performance in different GPU, TPU configurations.
|
||||
The naming convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
|
||||
`benchmark_(topology)_tpu_(dataset type)` for TPUs.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
||||
super(BertClassifyBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir, tpu=tpu)
|
||||
|
||||
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
|
||||
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
|
||||
self.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
|
||||
|
||||
# Since we only care about performance metrics, we limit
|
||||
# the number of training steps and epochs to prevent unnecessarily
|
||||
# long tests.
|
||||
self.num_steps_per_epoch = 100
|
||||
self.num_epochs = 1
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
training_summary_path,
|
||||
min_accuracy=0,
|
||||
max_accuracy=1,
|
||||
use_ds=True):
|
||||
"""Starts BERT performance benchmark test."""
|
||||
start_time_sec = time.time()
|
||||
self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
|
||||
summary = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
# Since we do not load from any pretrained checkpoints, we ignore all
|
||||
# accuracy metrics.
|
||||
summary.pop('eval_metrics', None)
|
||||
summary['start_time_sec'] = start_time_sec
|
||||
|
||||
super(BertClassifyBenchmarkReal, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=min_accuracy,
|
||||
max_accuracy=max_accuracy)
|
||||
|
||||
def benchmark_1_gpu_mrpc(self):
|
||||
"""Test BERT model performance with 1 GPU."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.eval_batch_size = 4
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
def benchmark_1_gpu_mrpc_xla(self):
|
||||
"""Test BERT model performance with 1 GPU."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.eval_batch_size = 4
|
||||
FLAGS.enable_xla = True
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
def benchmark_1_gpu_mrpc_no_dist_strat(self):
|
||||
"""Test BERT model performance with 1 GPU, no distribution strategy."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.eval_batch_size = 4
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path, use_ds=False)
|
||||
|
||||
def benchmark_8_gpu_mrpc(self):
|
||||
"""Test BERT model performance with 8 GPUs."""
|
||||
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
|
||||
"""Performance for 1 GPU no DS with automatic mixed precision."""
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_1_gpu_amp_mrpc_no_dist_strat')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.eval_batch_size = 4
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path, use_ds=False)
|
||||
|
||||
def benchmark_8_gpu_amp_mrpc(self):
|
||||
"""Test BERT model performance with 8 GPUs with automatic mixed precision.
|
||||
"""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.eval_batch_size = 32
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path, use_ds=False)
|
||||
|
||||
def benchmark_2x2_tpu_mrpc(self):
|
||||
"""Test BERT model performance with 2x2 TPU."""
|
||||
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.eval_batch_size = 32
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path, use_ds=False)
|
||||
|
||||
|
||||
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
|
||||
"""Short accuracy test for BERT model.
|
||||
|
||||
Tests BERT classification task model accuracy. The naming
|
||||
convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu_(dataset type)` format.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, **kwargs):
|
||||
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
|
||||
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
|
||||
self.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
|
||||
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
|
||||
|
||||
super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
training_summary_path,
|
||||
min_accuracy=0.84,
|
||||
max_accuracy=0.88):
|
||||
"""Starts BERT accuracy benchmark test."""
|
||||
|
||||
start_time_sec = time.time()
|
||||
self._run_bert_classifier(callbacks=[self.timer_callback])
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
|
||||
summary = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
super(BertClassifyAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=min_accuracy,
|
||||
max_accuracy=max_accuracy)
|
||||
|
||||
def _setup(self):
|
||||
super(BertClassifyAccuracy, self)._setup()
|
||||
FLAGS.train_data_path = self.train_data_path
|
||||
FLAGS.eval_data_path = self.eval_data_path
|
||||
FLAGS.input_meta_data_path = self.input_meta_data_path
|
||||
FLAGS.bert_config_file = self.bert_config_file
|
||||
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
|
||||
|
||||
def benchmark_8_gpu_mrpc(self):
|
||||
"""Run BERT model accuracy test with 8 GPUs.
|
||||
|
||||
Due to comparatively small cardinality of MRPC dataset, training
|
||||
accuracy metric has high variance between trainings. As so, we
|
||||
set the wide range of allowed accuracy (84% to 88%).
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
def benchmark_8_gpu_mrpc_xla(self):
|
||||
"""Run BERT model accuracy test with 8 GPUs with XLA."""
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
|
||||
FLAGS.enable_xla = True
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+126
@@ -0,0 +1,126 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions or classes shared between BERT benchmarks."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import numpy as np
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
|
||||
"""Callback that records time it takes to run each batch."""
|
||||
|
||||
def __init__(self, num_batches_to_skip=10):
|
||||
super(BenchmarkTimerCallback, self).__init__()
|
||||
self.batch_start_times = {}
|
||||
self.batch_stop_times = {}
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
self.batch_start_times[batch] = time.time()
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
# If there are multiple steps_per_loop, the end batch index will not be the
|
||||
# same as the starting index. Use the last starting index instead.
|
||||
if batch not in self.batch_start_times:
|
||||
batch = max(self.batch_start_times.keys())
|
||||
|
||||
self.batch_stop_times[batch] = time.time()
|
||||
|
||||
def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
|
||||
batch_durations = []
|
||||
for batch in self.batch_start_times:
|
||||
if batch in self.batch_stop_times and batch >= num_batches_to_skip:
|
||||
batch_durations.append(self.batch_stop_times[batch] -
|
||||
self.batch_start_times[batch])
|
||||
return batch_size / np.mean(batch_durations)
|
||||
|
||||
def get_startup_time(self, program_start_time):
|
||||
return self.batch_start_times[0] - program_start_time
|
||||
|
||||
|
||||
class BertBenchmarkBase(PerfZeroBenchmark):
|
||||
"""Base class to hold methods common to test classes."""
|
||||
local_flags = None
|
||||
|
||||
def __init__(self, output_dir=None):
|
||||
super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
|
||||
self.num_gpus = 8
|
||||
self.timer_callback = None
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up and resets flags before each test."""
|
||||
super(BertBenchmarkBase, self)._setup()
|
||||
self.timer_callback = BenchmarkTimerCallback()
|
||||
|
||||
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Args:
|
||||
stats: dict returned from BERT models with known entries.
|
||||
wall_time_sec: the during of the benchmark execution in seconds
|
||||
min_accuracy: Minimum classification accuracy constraint to verify
|
||||
correctness of the model.
|
||||
max_accuracy: Maximum classification accuracy constraint to verify
|
||||
correctness of the model.
|
||||
"""
|
||||
metrics = [{
|
||||
'name': 'training_loss',
|
||||
'value': stats['train_loss'],
|
||||
}]
|
||||
if self.timer_callback:
|
||||
metrics.append({
|
||||
'name':
|
||||
'exp_per_second',
|
||||
'value':
|
||||
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
|
||||
FLAGS.steps_per_loop)
|
||||
})
|
||||
else:
|
||||
metrics.append({
|
||||
'name': 'exp_per_second',
|
||||
'value': 0.0,
|
||||
})
|
||||
if self.timer_callback and 'start_time_sec' in stats:
|
||||
metrics.append({
|
||||
'name': 'startup_time',
|
||||
'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
|
||||
})
|
||||
|
||||
if 'eval_metrics' in stats:
|
||||
metrics.append({
|
||||
'name': 'eval_accuracy',
|
||||
'value': stats['eval_metrics'],
|
||||
'min_value': min_accuracy,
|
||||
'max_value': max_accuracy,
|
||||
})
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(
|
||||
iters=stats['total_training_steps'],
|
||||
wall_time=wall_time_sec,
|
||||
metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
+654
@@ -0,0 +1,654 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes BERT SQuAD benchmarks and accuracy tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
||||
from official.nlp.bert import run_squad
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
from official.benchmark import benchmark_wrappers
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
|
||||
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
|
||||
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
|
||||
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
|
||||
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
|
||||
SQUAD_LONG_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_long_meta_data'
|
||||
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
|
||||
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
TMP_DIR = os.getenv('TMPDIR')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
|
||||
"""Base class to hold methods common to test classes in the module."""
|
||||
|
||||
def __init__(self, output_dir=None, tpu=None):
|
||||
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir)
|
||||
self.tpu = tpu
|
||||
|
||||
def _read_training_summary_from_file(self):
|
||||
"""Reads the training summary from a file."""
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
|
||||
return json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
def _read_input_meta_data_from_file(self):
|
||||
"""Reads the input metadata from a file."""
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
return json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
def _get_distribution_strategy(self, ds_type='mirrored'):
|
||||
"""Gets the distribution strategy.
|
||||
|
||||
Args:
|
||||
ds_type: String, the distribution strategy type to be used. Can be
|
||||
'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
|
||||
|
||||
Returns:
|
||||
A `tf.distribute.DistibutionStrategy` object.
|
||||
"""
|
||||
if self.tpu or ds_type == 'tpu':
|
||||
return distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy='tpu', tpu_address=self.tpu)
|
||||
elif ds_type == 'multi_worker_mirrored':
|
||||
# Configures cluster spec for multi-worker distribution strategy.
|
||||
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
|
||||
FLAGS.task_index)
|
||||
return distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=ds_type,
|
||||
num_gpus=self.num_gpus,
|
||||
all_reduce_alg=FLAGS.all_reduce_alg)
|
||||
|
||||
def _init_gpu_and_data_threads(self):
|
||||
"""Set env variables before any TF calls."""
|
||||
if FLAGS.tf_gpu_thread_mode:
|
||||
keras_utils.set_gpu_thread_mode_and_count(
|
||||
per_gpu_thread_count=FLAGS.per_gpu_thread_count,
|
||||
gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
|
||||
num_gpus=self.num_gpus,
|
||||
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
|
||||
"""Runs BERT SQuAD training. Uses mirrored strategy by default."""
|
||||
self._init_gpu_and_data_threads()
|
||||
input_meta_data = self._read_input_meta_data_from_file()
|
||||
strategy = self._get_distribution_strategy(ds_type)
|
||||
|
||||
run_squad.train_squad(
|
||||
strategy=strategy,
|
||||
input_meta_data=input_meta_data,
|
||||
run_eagerly=run_eagerly,
|
||||
custom_callbacks=[self.timer_callback])
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def _evaluate_squad(self, ds_type='mirrored'):
|
||||
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
|
||||
self._init_gpu_and_data_threads()
|
||||
input_meta_data = self._read_input_meta_data_from_file()
|
||||
strategy = self._get_distribution_strategy(ds_type)
|
||||
|
||||
if input_meta_data.get('version_2_with_negative', False):
|
||||
logging.error('In memory evaluation result for SQuAD v2 is not accurate')
|
||||
eval_metrics = run_squad.eval_squad(strategy=strategy,
|
||||
input_meta_data=input_meta_data)
|
||||
# Use F1 score as reported evaluation metric.
|
||||
self.eval_metrics = eval_metrics['final_f1']
|
||||
|
||||
|
||||
class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
|
||||
"""Short benchmark performance tests for BERT SQuAD model.
|
||||
|
||||
Tests BERT SQuAD performance in different GPU configurations.
|
||||
The naming convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu` format for GPUs and
|
||||
`benchmark_(topology)_tpu` format for TPUs.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
||||
super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up the benchmark and SQuAD flags."""
|
||||
super(BertSquadBenchmarkReal, self)._setup()
|
||||
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
||||
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
||||
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
||||
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
FLAGS.num_train_epochs = 1
|
||||
FLAGS.steps_per_loop = 100
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
run_eagerly=False,
|
||||
ds_type='mirrored'):
|
||||
"""Runs the benchmark and reports various metrics."""
|
||||
if FLAGS.train_batch_size <= 4 or run_eagerly:
|
||||
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
|
||||
else:
|
||||
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
|
||||
start_time_sec = time.time()
|
||||
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
summary = self._read_training_summary_from_file()
|
||||
summary['start_time_sec'] = start_time_sec
|
||||
|
||||
super(BertSquadBenchmarkReal, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=0,
|
||||
max_accuracy=1)
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
|
||||
FLAGS.train_batch_size = 4
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_eager(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
|
||||
FLAGS.train_batch_size = 2
|
||||
|
||||
self._run_and_report_benchmark(run_eagerly=True)
|
||||
|
||||
def benchmark_1_gpu_xla(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with XLA."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
|
||||
# XLA runs out of memory when running with batch size 4.
|
||||
FLAGS.train_batch_size = 3
|
||||
FLAGS.enable_xla = True
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU without DS."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
|
||||
FLAGS.train_batch_size = 4
|
||||
|
||||
self._run_and_report_benchmark(ds_type='off')
|
||||
|
||||
def benchmark_1_gpu_eager_no_dist_strat(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with eager execution."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_1_gpu_eager_no_dist_strat_squad')
|
||||
FLAGS.train_batch_size = 4
|
||||
|
||||
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
|
||||
|
||||
def benchmark_2_gpu(self):
|
||||
"""Tests BERT SQuAD model performance with 2 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 2
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
|
||||
FLAGS.train_batch_size = 8
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_4_gpu(self):
|
||||
"""Tests BERT SQuAD model performance with 4 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 4
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
|
||||
FLAGS.train_batch_size = 16
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Tests BERT SQuAD model performance with 8 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
|
||||
FLAGS.train_batch_size = 24
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_fp16_eager(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16_eager')
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark(run_eagerly=True)
|
||||
|
||||
def benchmark_1_gpu_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_xla_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with XLA and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad_fp16')
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpu_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 2 GPUs and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 2
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_4_gpu_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 4 GPUs and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 4
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 16
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 8 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_xla_fp16(self):
|
||||
"""Tests BERT SQuAD model performance with 8 GPUs with XLA."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_amp(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
|
||||
FLAGS.train_batch_size = 4
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_4_gpu_amp(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 4
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_amp_squad')
|
||||
FLAGS.train_batch_size = 16
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_amp(self):
|
||||
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2x2_tpu(self):
|
||||
"""Tests BERT SQuAD model performance with 2x2 TPU."""
|
||||
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
|
||||
FLAGS.train_batch_size = 48
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
class BertSquadAccuracy(BertSquadBenchmarkBase):
|
||||
"""Short accuracy test for BERT SQuAD model.
|
||||
|
||||
Tests BERT SQuAD accuracy. The naming convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu` format for GPUs and
|
||||
`benchmark_(topology)_tpu` format for TPUs.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=None, tpu=None, **kwargs):
|
||||
super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up the benchmark and SQuAD flags."""
|
||||
super(BertSquadAccuracy, self)._setup()
|
||||
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
||||
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
||||
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
||||
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
||||
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
|
||||
FLAGS.num_train_epochs = 2
|
||||
FLAGS.steps_per_loop = 100
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
run_eagerly=False,
|
||||
ds_type='mirrored'):
|
||||
"""Runs the benchmark and reports various metrics."""
|
||||
start_time_sec = time.time()
|
||||
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
|
||||
self._evaluate_squad(ds_type=ds_type)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
summary = self._read_training_summary_from_file()
|
||||
summary['eval_metrics'] = self.eval_metrics
|
||||
summary['start_time_sec'] = start_time_sec
|
||||
|
||||
super(BertSquadAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=0.900,
|
||||
max_accuracy=0.920)
|
||||
|
||||
def benchmark_1_gpu_eager(self):
|
||||
"""Tests BERT SQuAD model accuracy with 1 GPU with eager execution."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 1
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
|
||||
FLAGS.train_batch_size = 4
|
||||
|
||||
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
|
||||
FLAGS.train_batch_size = 24
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 'dynamic'
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_xla(self):
|
||||
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
|
||||
|
||||
self._setup()
|
||||
self.num_gpus = 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
|
||||
FLAGS.train_batch_size = 32
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2x2_tpu(self):
|
||||
"""Tests BERT SQuAD model accuracy with 2x2 TPU."""
|
||||
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
|
||||
FLAGS.train_batch_size = 48
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
|
||||
"""BERT SQuAD distributed accuracy tests with multiple workers."""
|
||||
|
||||
def __init__(self, output_dir=None, tpu=None, **kwargs):
|
||||
super(BertSquadMultiWorkerAccuracy, self).__init__(
|
||||
output_dir=output_dir, tpu=tpu)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up the benchmark and SQuAD flags."""
|
||||
super(BertSquadMultiWorkerAccuracy, self)._setup()
|
||||
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
||||
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
||||
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
||||
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
||||
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
|
||||
FLAGS.num_train_epochs = 2
|
||||
FLAGS.steps_per_loop = 100
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
use_ds=True,
|
||||
run_eagerly=False):
|
||||
"""Runs the benchmark and reports various metrics."""
|
||||
start_time_sec = time.time()
|
||||
self._train_squad(run_eagerly=run_eagerly,
|
||||
ds_type='multi_worker_mirrored')
|
||||
self._evaluate_squad(ds_type='multi_worker_mirrored')
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
summary = self._read_training_summary_from_file()
|
||||
summary['eval_metrics'] = self.eval_metrics
|
||||
|
||||
super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=0.900,
|
||||
max_accuracy=0.920)
|
||||
|
||||
def _benchmark_common(self, num_workers, all_reduce_alg):
|
||||
"""Common to all benchmarks in this class."""
|
||||
self._setup()
|
||||
|
||||
num_gpus = 8
|
||||
FLAGS.num_gpus = num_gpus
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.enable_xla = False
|
||||
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
FLAGS.datasets_num_private_threads = 32
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
||||
num_workers, all_reduce_alg))
|
||||
FLAGS.train_batch_size = 4 * num_gpus * num_workers
|
||||
FLAGS.all_reduce_alg = all_reduce_alg
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
|
||||
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
||||
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
|
||||
|
||||
def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
||||
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
||||
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
|
||||
|
||||
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
|
||||
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
||||
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
|
||||
|
||||
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
||||
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
||||
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
|
||||
|
||||
|
||||
class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
|
||||
"""BERT SQuAD distributed benchmark tests with multiple workers."""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
||||
super(BertSquadMultiWorkerBenchmark, self).__init__(
|
||||
output_dir=output_dir, tpu=tpu)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up the benchmark and SQuAD flags."""
|
||||
super(BertSquadMultiWorkerBenchmark, self)._setup()
|
||||
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
||||
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
||||
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
||||
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
||||
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
||||
FLAGS.num_train_epochs = 1
|
||||
FLAGS.steps_per_loop = 100
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
use_ds=True,
|
||||
run_eagerly=False):
|
||||
"""Runs the benchmark and reports various metrics."""
|
||||
if FLAGS.train_batch_size <= 4 * 8:
|
||||
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
|
||||
else:
|
||||
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
||||
start_time_sec = time.time()
|
||||
self._train_squad(run_eagerly=run_eagerly,
|
||||
ds_type='multi_worker_mirrored')
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
summary = self._read_training_summary_from_file()
|
||||
summary['start_time_sec'] = start_time_sec
|
||||
|
||||
super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=0,
|
||||
max_accuracy=1)
|
||||
|
||||
def _benchmark_common(self, num_workers, all_reduce_alg):
|
||||
"""Common to all benchmarks in this class."""
|
||||
self._setup()
|
||||
|
||||
num_gpus = 8
|
||||
FLAGS.num_gpus = num_gpus
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.enable_xla = False
|
||||
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
||||
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
||||
FLAGS.datasets_num_private_threads = 32
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
||||
num_workers, all_reduce_alg))
|
||||
FLAGS.train_batch_size = 4 * num_gpus * num_workers
|
||||
FLAGS.all_reduce_alg = all_reduce_alg
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
|
||||
"""8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
|
||||
self._benchmark_common(num_workers=1, all_reduce_alg='ring')
|
||||
|
||||
def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
|
||||
"""8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
|
||||
self._benchmark_common(num_workers=1, all_reduce_alg='nccl')
|
||||
|
||||
def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
|
||||
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
||||
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
|
||||
|
||||
def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
||||
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
||||
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
|
||||
|
||||
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
|
||||
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
||||
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
|
||||
|
||||
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
||||
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
||||
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
[
|
||||
{
|
||||
"description": "The ID of the benchmark run, where this metric should tie to.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "run_id",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The unit of the metric. E.g. MB per sec.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "unit",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the metric.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"description": "The timestamp when the metric is recorded.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "timestamp",
|
||||
"type": "TIMESTAMP"
|
||||
},
|
||||
{
|
||||
"description": "The global step when this metric is recorded.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "global_step",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"description": "Free format metadata for the extra information about the metric.",
|
||||
"mode": "REPEATED",
|
||||
"name": "extras",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+368
@@ -0,0 +1,368 @@
|
||||
[
|
||||
{
|
||||
"description": "The UUID of the run for the benchmark.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "model_id",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "model_name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The date when the test of the model is started",
|
||||
"mode": "REQUIRED",
|
||||
"name": "run_date",
|
||||
"type": "TIMESTAMP"
|
||||
},
|
||||
{
|
||||
"description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "test_id",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The tensorflow version information.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "Version of the tensorflow. E.g. 1.7.0-rc0",
|
||||
"mode": "REQUIRED",
|
||||
"name": "version",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Git Hash of the tensorflow",
|
||||
"mode": "NULLABLE",
|
||||
"name": "git_hash",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "channel",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "build_type",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REQUIRED",
|
||||
"name": "tensorflow_version",
|
||||
"type": "RECORD"
|
||||
},
|
||||
{
|
||||
"description": "The arbitrary attribute of the model.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the attribute.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the attribute.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "attribute",
|
||||
"type": "RECORD"
|
||||
},
|
||||
{
|
||||
"description": "Environment variables when the benchmark run is executed.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the variable.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the variable.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "environment_variable",
|
||||
"type": "RECORD"
|
||||
},
|
||||
{
|
||||
"description": "TF Environment variables when the benchmark run is executed.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the variable.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the variable.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "tensorflow_environment_variables",
|
||||
"type": "RECORD"
|
||||
},
|
||||
{
|
||||
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the parameter.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The string value of the parameter.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "string_value",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The bool value of the parameter.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "bool_value",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The int/long value of the parameter.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "long_value",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"description": "The double/float value of parameter.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "float_value",
|
||||
"type": "FLOAT"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "run_parameters",
|
||||
"type": "RECORD"
|
||||
},
|
||||
{
|
||||
"description": "The dataset that run with the benchmark.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "dataset",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The arbitrary attribute of the dataset.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the attribute.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the attribute.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "attribute",
|
||||
"type": "RECORD"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
|
||||
"mode": "NULLABLE",
|
||||
"name": "test_environment",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The machine configuration of the benchmark run.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "machine_config",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The platform information of the benchmark run.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "platform_info",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"description": "Eg: 64bit.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "bits",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Eg: ELF.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "linkage",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Eg: i386.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "machine",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Eg: 3.13.0-76-generic.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "release",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Eg: Linux.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "system",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "version",
|
||||
"type": "STRING"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "The CPU information of the benchmark run.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "cpu_info",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "num_cores",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "num_cores_allowed",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"description" : "How fast are those CPUs.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "mhz_per_cpu",
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
|
||||
"mode": "NULLABLE",
|
||||
"name": "cpu_info",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "cpu_governor",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "Cache size of the CPUs.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "cache_size",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "level",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "size",
|
||||
"type": "INTEGER"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "gpu_info",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "count",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "model",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "cuda_version",
|
||||
"type": "STRING"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "The cloud instance inforation if the benchmark run is executed on cloud",
|
||||
"mode": "NULLABLE",
|
||||
"name": "cloud_info",
|
||||
"type": "RECORD",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The instance type, E.g. n1-standard-4.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "instance_type",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The arbitrary attribute of the cloud info.",
|
||||
"fields": [
|
||||
{
|
||||
"description": "The name of the attribute.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "name",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The value of the attribute.",
|
||||
"mode": "NULLABLE",
|
||||
"name": "value",
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"mode": "REPEATED",
|
||||
"name": "attribute",
|
||||
"type": "RECORD"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "memory_total",
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"mode": "NULLABLE",
|
||||
"name": "memory_available",
|
||||
"type": "STRING"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
[
|
||||
{
|
||||
"description": "The UUID of the run for the benchmark.",
|
||||
"mode": "REQUIRED",
|
||||
"name": "run_id",
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"description": "The status of the run for the benchmark. Eg, running, failed, success",
|
||||
"mode": "REQUIRED",
|
||||
"name": "status",
|
||||
"type": "STRING"
|
||||
}
|
||||
]
|
||||
+98
@@ -0,0 +1,98 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes Keras benchmarks and accuracy tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
from official.utils.flags import core as flags_core
|
||||
|
||||
|
||||
class KerasBenchmark(PerfZeroBenchmark):
|
||||
"""Base benchmark class with methods to simplify testing."""
|
||||
|
||||
def __init__(self,
|
||||
output_dir=None,
|
||||
default_flags=None,
|
||||
flag_methods=None,
|
||||
tpu=None):
|
||||
super(KerasBenchmark, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
flag_methods=flag_methods,
|
||||
tpu=tpu)
|
||||
|
||||
def _report_benchmark(self,
|
||||
stats,
|
||||
wall_time_sec,
|
||||
top_1_max=None,
|
||||
top_1_min=None,
|
||||
log_steps=None,
|
||||
total_batch_size=None,
|
||||
warmup=1,
|
||||
start_time_sec=None):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Args:
|
||||
stats: dict returned from keras models with known entries.
|
||||
wall_time_sec: the during of the benchmark execution in seconds
|
||||
top_1_max: highest passing level for top_1 accuracy.
|
||||
top_1_min: lowest passing level for top_1 accuracy.
|
||||
log_steps: How often the log was created for stats['step_timestamp_log'].
|
||||
total_batch_size: Global batch-size.
|
||||
warmup: number of entries in stats['step_timestamp_log'] to ignore.
|
||||
start_time_sec: the start time of the program in seconds since epoch
|
||||
"""
|
||||
|
||||
metrics = []
|
||||
if 'accuracy_top_1' in stats:
|
||||
metrics.append({'name': 'accuracy_top_1',
|
||||
'value': stats['accuracy_top_1'],
|
||||
'min_value': top_1_min,
|
||||
'max_value': top_1_max})
|
||||
metrics.append({'name': 'top_1_train_accuracy',
|
||||
'value': stats['training_accuracy_top_1']})
|
||||
|
||||
if (warmup and 'step_timestamp_log' in stats and
|
||||
len(stats['step_timestamp_log']) > warmup):
|
||||
# first entry in the time_log is start of step 1. The rest of the
|
||||
# entries are the end of each step recorded
|
||||
time_log = stats['step_timestamp_log']
|
||||
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
|
||||
num_examples = (
|
||||
total_batch_size * log_steps * (len(time_log) - warmup - 1))
|
||||
examples_per_sec = num_examples / elapsed
|
||||
metrics.append({'name': 'exp_per_second',
|
||||
'value': examples_per_sec})
|
||||
|
||||
if 'avg_exp_per_second' in stats:
|
||||
metrics.append({'name': 'avg_exp_per_second',
|
||||
'value': stats['avg_exp_per_second']})
|
||||
|
||||
if start_time_sec and 'step_timestamp_log' in stats:
|
||||
time_log = stats['step_timestamp_log']
|
||||
# time_log[0] is recorded at the beginning of the first step.
|
||||
startup_time = time_log[0].timestamp - start_time_sec
|
||||
metrics.append({'name': 'startup_time', 'value': startup_time})
|
||||
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(
|
||||
iters=-1,
|
||||
wall_time=wall_time_sec,
|
||||
metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
+402
@@ -0,0 +1,402 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes Keras benchmarks and accuracy tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
from absl import flags
|
||||
import tensorflow as tf # pylint: disable=g-bad-import-order
|
||||
|
||||
from official.benchmark import keras_benchmark
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.benchmark.models import resnet_cifar_main
|
||||
|
||||
MIN_TOP_1_ACCURACY = 0.929
|
||||
MAX_TOP_1_ACCURACY = 0.938
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
|
||||
|
||||
|
||||
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
|
||||
"""Accuracy tests for ResNet56 Keras CIFAR-10."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
"""A benchmark class.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more
|
||||
named arguments before updating the constructor.
|
||||
"""
|
||||
|
||||
self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
|
||||
flag_methods = [resnet_cifar_main.define_cifar_flags]
|
||||
|
||||
super(Resnet56KerasAccuracy, self).__init__(
|
||||
output_dir=output_dir, flag_methods=flag_methods)
|
||||
|
||||
def _setup(self):
|
||||
super(Resnet56KerasAccuracy, self)._setup()
|
||||
FLAGS.use_tensor_lr = False
|
||||
|
||||
def benchmark_graph_1_gpu(self):
|
||||
"""Test keras based model with Keras fit and distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Test keras based model with eager and distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu(self):
|
||||
"""Test keras based model on CPU."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_dist_strat(self):
|
||||
"""Test keras based model on CPU without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_dist_strat_run_eagerly(self):
|
||||
"""Test keras based model on CPU w/forced eager and no dist_strat."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_cpu_no_dist_strat_run_eagerly')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat(self):
|
||||
"""Test keras based model with eager and no dist strat."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
||||
"""Test keras based model w/forced eager and no dist_strat."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_1_gpu_no_dist_strat(self):
|
||||
"""Test keras based model with Keras fit but not distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
|
||||
FLAGS.dtype = 'fp32'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpu(self):
|
||||
"""Test keras based model with eager and distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_2_gpu(self):
|
||||
"""Test keras based model with Keras fit and distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.train_epochs = 182
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self):
|
||||
start_time_sec = time.time()
|
||||
stats = resnet_cifar_main.run(FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
super(Resnet56KerasAccuracy, self)._report_benchmark(
|
||||
stats,
|
||||
wall_time_sec,
|
||||
top_1_min=MIN_TOP_1_ACCURACY,
|
||||
top_1_max=MAX_TOP_1_ACCURACY,
|
||||
total_batch_size=FLAGS.batch_size,
|
||||
log_steps=100)
|
||||
|
||||
|
||||
class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
|
||||
"""Short performance tests for ResNet56 via Keras and CIFAR-10."""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None):
|
||||
flag_methods = [resnet_cifar_main.define_cifar_flags]
|
||||
|
||||
super(Resnet56KerasBenchmarkBase, self).__init__(
|
||||
output_dir=output_dir,
|
||||
flag_methods=flag_methods,
|
||||
default_flags=default_flags)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self):
|
||||
start_time_sec = time.time()
|
||||
stats = resnet_cifar_main.run(FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
|
||||
stats,
|
||||
wall_time_sec,
|
||||
total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Test 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_xla(self):
|
||||
"""Test 1 gpu with xla enabled."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = False
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_1_gpu(self):
|
||||
"""Test 1 gpu graph."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.enable_eager = False
|
||||
FLAGS.run_eagerly = False
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat(self):
|
||||
"""Test 1 gpu without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_1_gpu_no_dist_strat(self):
|
||||
"""Test 1 gpu graph mode without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.enable_eager = False
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
||||
"""Test 1 gpu without distribution strategy and forced eager."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
||||
FLAGS.dtype = 'fp32'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpu(self):
|
||||
"""Test 2 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = False
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
|
||||
FLAGS.batch_size = 128 * 2 # 2 GPUs
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_2_gpu(self):
|
||||
"""Test 2 gpu graph mode."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.enable_eager = False
|
||||
FLAGS.run_eagerly = False
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
|
||||
FLAGS.batch_size = 128 * 2 # 2 GPUs
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu(self):
|
||||
"""Test cpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_cpu(self):
|
||||
"""Test cpu graph mode."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.enable_eager = False
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu')
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_dist_strat_run_eagerly(self):
|
||||
"""Test cpu without distribution strategy and forced eager."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_cpu_no_dist_strat_run_eagerly')
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_dist_strat(self):
|
||||
"""Test cpu without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.enable_eager = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_graph_cpu_no_dist_strat(self):
|
||||
"""Test cpu graph mode without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.enable_eager = False
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu_no_dist_strat')
|
||||
FLAGS.batch_size = 128
|
||||
FLAGS.data_format = 'channels_last'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
|
||||
"""Synthetic benchmarks for ResNet56 and Keras."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
default_flags = {}
|
||||
default_flags['skip_eval'] = True
|
||||
default_flags['use_synthetic_data'] = True
|
||||
default_flags['train_steps'] = 110
|
||||
default_flags['log_steps'] = 10
|
||||
default_flags['use_tensor_lr'] = False
|
||||
|
||||
super(Resnet56KerasBenchmarkSynth, self).__init__(
|
||||
output_dir=output_dir, default_flags=default_flags)
|
||||
|
||||
|
||||
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
|
||||
"""Real data benchmarks for ResNet56 and Keras."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
default_flags = {}
|
||||
default_flags['skip_eval'] = True
|
||||
default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
|
||||
default_flags['train_steps'] = 110
|
||||
default_flags['log_steps'] = 10
|
||||
default_flags['use_tensor_lr'] = False
|
||||
|
||||
super(Resnet56KerasBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir, default_flags=default_flags)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+1685
File diff suppressed because it is too large
Load Diff
+287
@@ -0,0 +1,287 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Runs a ResNet model on the Cifar-10 dataset."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from official.benchmark.models import resnet_cifar_model
|
||||
from official.benchmark.models import synthetic_util
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.utils.logs import logger
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
from official.vision.image_classification.resnet import cifar_preprocessing
|
||||
from official.vision.image_classification.resnet import common
|
||||
|
||||
|
||||
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
|
||||
(0.1, 91), (0.01, 136), (0.001, 182)
|
||||
]
|
||||
|
||||
|
||||
def learning_rate_schedule(current_epoch,
|
||||
current_batch,
|
||||
batches_per_epoch,
|
||||
batch_size):
|
||||
"""Handles linear scaling rule and LR decay.
|
||||
|
||||
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
|
||||
provided scaling factor.
|
||||
|
||||
Args:
|
||||
current_epoch: integer, current epoch indexed from 0.
|
||||
current_batch: integer, current batch in the current epoch, indexed from 0.
|
||||
batches_per_epoch: integer, number of steps in an epoch.
|
||||
batch_size: integer, total batch sized.
|
||||
|
||||
Returns:
|
||||
Adjusted learning rate.
|
||||
"""
|
||||
del current_batch, batches_per_epoch # not used
|
||||
initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
|
||||
learning_rate = initial_learning_rate
|
||||
for mult, start_epoch in LR_SCHEDULE:
|
||||
if current_epoch >= start_epoch:
|
||||
learning_rate = initial_learning_rate * mult
|
||||
else:
|
||||
break
|
||||
return learning_rate
|
||||
|
||||
|
||||
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
|
||||
"""Callback to update learning rate on every batch (not epoch boundaries).
|
||||
|
||||
N.B. Only support Keras optimizers, not TF optimizers.
|
||||
|
||||
Attributes:
|
||||
schedule: a function that takes an epoch index and a batch index as input
|
||||
(both integer, indexed from 0) and returns a new learning rate as
|
||||
output (float).
|
||||
"""
|
||||
|
||||
def __init__(self, schedule, batch_size, steps_per_epoch):
|
||||
super(LearningRateBatchScheduler, self).__init__()
|
||||
self.schedule = schedule
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.batch_size = batch_size
|
||||
self.epochs = -1
|
||||
self.prev_lr = -1
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
if not hasattr(self.model.optimizer, 'learning_rate'):
|
||||
raise ValueError('Optimizer must have a "learning_rate" attribute.')
|
||||
self.epochs += 1
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
"""Executes before step begins."""
|
||||
lr = self.schedule(self.epochs,
|
||||
batch,
|
||||
self.steps_per_epoch,
|
||||
self.batch_size)
|
||||
if not isinstance(lr, (float, np.float32, np.float64)):
|
||||
raise ValueError('The output of the "schedule" function should be float.')
|
||||
if lr != self.prev_lr:
|
||||
self.model.optimizer.learning_rate = lr # lr should be a float here
|
||||
self.prev_lr = lr
|
||||
logging.debug(
|
||||
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
|
||||
'change learning rate to %s.', self.epochs, batch, lr)
|
||||
|
||||
|
||||
def run(flags_obj):
|
||||
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
|
||||
|
||||
Args:
|
||||
flags_obj: An object containing parsed flag values.
|
||||
|
||||
Raises:
|
||||
ValueError: If fp16 is passed as it is not currently supported.
|
||||
|
||||
Returns:
|
||||
Dictionary of training and eval stats.
|
||||
"""
|
||||
keras_utils.set_session_config(
|
||||
enable_eager=flags_obj.enable_eager,
|
||||
enable_xla=flags_obj.enable_xla)
|
||||
|
||||
# Execute flag override logic for better model performance
|
||||
if flags_obj.tf_gpu_thread_mode:
|
||||
keras_utils.set_gpu_thread_mode_and_count(
|
||||
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
|
||||
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
|
||||
num_gpus=flags_obj.num_gpus,
|
||||
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
|
||||
common.set_cudnn_batchnorm_mode()
|
||||
|
||||
dtype = flags_core.get_tf_dtype(flags_obj)
|
||||
if dtype == 'fp16':
|
||||
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
|
||||
'value(fp32).')
|
||||
|
||||
data_format = flags_obj.data_format
|
||||
if data_format is None:
|
||||
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
|
||||
else 'channels_last')
|
||||
tf.keras.backend.set_image_data_format(data_format)
|
||||
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=flags_obj.distribution_strategy,
|
||||
num_gpus=flags_obj.num_gpus,
|
||||
all_reduce_alg=flags_obj.all_reduce_alg,
|
||||
num_packs=flags_obj.num_packs)
|
||||
|
||||
if strategy:
|
||||
# flags_obj.enable_get_next_as_optional controls whether enabling
|
||||
# get_next_as_optional behavior in DistributedIterator. If true, last
|
||||
# partial batch can be supported.
|
||||
strategy.extended.experimental_enable_get_next_as_optional = (
|
||||
flags_obj.enable_get_next_as_optional
|
||||
)
|
||||
|
||||
strategy_scope = distribution_utils.get_strategy_scope(strategy)
|
||||
|
||||
if flags_obj.use_synthetic_data:
|
||||
synthetic_util.set_up_synthetic_data()
|
||||
input_fn = common.get_synth_input_fn(
|
||||
height=cifar_preprocessing.HEIGHT,
|
||||
width=cifar_preprocessing.WIDTH,
|
||||
num_channels=cifar_preprocessing.NUM_CHANNELS,
|
||||
num_classes=cifar_preprocessing.NUM_CLASSES,
|
||||
dtype=flags_core.get_tf_dtype(flags_obj),
|
||||
drop_remainder=True)
|
||||
else:
|
||||
synthetic_util.undo_set_up_synthetic_data()
|
||||
input_fn = cifar_preprocessing.input_fn
|
||||
|
||||
train_input_dataset = input_fn(
|
||||
is_training=True,
|
||||
data_dir=flags_obj.data_dir,
|
||||
batch_size=flags_obj.batch_size,
|
||||
parse_record_fn=cifar_preprocessing.parse_record,
|
||||
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
|
||||
dtype=dtype,
|
||||
# Setting drop_remainder to avoid the partial batch logic in normalization
|
||||
# layer, which triggers tf.where and leads to extra memory copy of input
|
||||
# sizes between host and GPU.
|
||||
drop_remainder=(not flags_obj.enable_get_next_as_optional))
|
||||
|
||||
eval_input_dataset = None
|
||||
if not flags_obj.skip_eval:
|
||||
eval_input_dataset = input_fn(
|
||||
is_training=False,
|
||||
data_dir=flags_obj.data_dir,
|
||||
batch_size=flags_obj.batch_size,
|
||||
parse_record_fn=cifar_preprocessing.parse_record)
|
||||
|
||||
steps_per_epoch = (
|
||||
cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
|
||||
lr_schedule = 0.1
|
||||
if flags_obj.use_tensor_lr:
|
||||
initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
|
||||
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
||||
boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
|
||||
values=[initial_learning_rate] +
|
||||
list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
|
||||
|
||||
with strategy_scope:
|
||||
optimizer = common.get_optimizer(lr_schedule)
|
||||
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
|
||||
model.compile(
|
||||
loss='sparse_categorical_crossentropy',
|
||||
optimizer=optimizer,
|
||||
metrics=(['sparse_categorical_accuracy']
|
||||
if flags_obj.report_accuracy_metrics else None),
|
||||
run_eagerly=flags_obj.run_eagerly)
|
||||
|
||||
train_epochs = flags_obj.train_epochs
|
||||
|
||||
callbacks = common.get_callbacks(steps_per_epoch)
|
||||
|
||||
if not flags_obj.use_tensor_lr:
|
||||
lr_callback = LearningRateBatchScheduler(
|
||||
schedule=learning_rate_schedule,
|
||||
batch_size=flags_obj.batch_size,
|
||||
steps_per_epoch=steps_per_epoch)
|
||||
callbacks.append(lr_callback)
|
||||
|
||||
# if mutliple epochs, ignore the train_steps flag.
|
||||
if train_epochs <= 1 and flags_obj.train_steps:
|
||||
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
|
||||
train_epochs = 1
|
||||
|
||||
num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
|
||||
flags_obj.batch_size)
|
||||
|
||||
validation_data = eval_input_dataset
|
||||
if flags_obj.skip_eval:
|
||||
if flags_obj.set_learning_phase_to_train:
|
||||
# TODO(haoyuzhang): Understand slowdown of setting learning phase when
|
||||
# not using distribution strategy.
|
||||
tf.keras.backend.set_learning_phase(1)
|
||||
num_eval_steps = None
|
||||
validation_data = None
|
||||
|
||||
if not strategy and flags_obj.explicit_gpu_placement:
|
||||
# TODO(b/135607227): Add device scope automatically in Keras training loop
|
||||
# when not using distribition strategy.
|
||||
no_dist_strat_device = tf.device('/device:GPU:0')
|
||||
no_dist_strat_device.__enter__()
|
||||
|
||||
history = model.fit(train_input_dataset,
|
||||
epochs=train_epochs,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
callbacks=callbacks,
|
||||
validation_steps=num_eval_steps,
|
||||
validation_data=validation_data,
|
||||
validation_freq=flags_obj.epochs_between_evals,
|
||||
verbose=2)
|
||||
eval_output = None
|
||||
if not flags_obj.skip_eval:
|
||||
eval_output = model.evaluate(eval_input_dataset,
|
||||
steps=num_eval_steps,
|
||||
verbose=2)
|
||||
|
||||
if not strategy and flags_obj.explicit_gpu_placement:
|
||||
no_dist_strat_device.__exit__()
|
||||
|
||||
stats = common.build_stats(history, eval_output, callbacks)
|
||||
return stats
|
||||
|
||||
|
||||
def define_cifar_flags():
|
||||
common.define_keras_flags(dynamic_loss_scale=False)
|
||||
|
||||
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
|
||||
model_dir='/tmp/cifar10_model',
|
||||
epochs_between_evals=10,
|
||||
batch_size=128)
|
||||
|
||||
|
||||
def main(_):
|
||||
with logger.benchmark_context(flags.FLAGS):
|
||||
return run(flags.FLAGS)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.set_verbosity(logging.INFO)
|
||||
define_cifar_flags()
|
||||
app.run(main)
|
||||
+262
@@ -0,0 +1,262 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.
|
||||
|
||||
# Reference:
|
||||
- [Deep Residual Learning for Image Recognition](
|
||||
https://arxiv.org/abs/1512.03385)
|
||||
Adapted from code contributed by BigMoyan.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import regularizers
|
||||
|
||||
|
||||
BATCH_NORM_DECAY = 0.997
|
||||
BATCH_NORM_EPSILON = 1e-5
|
||||
L2_WEIGHT_DECAY = 2e-4
|
||||
|
||||
|
||||
def identity_building_block(input_tensor,
|
||||
kernel_size,
|
||||
filters,
|
||||
stage,
|
||||
block,
|
||||
training=None):
|
||||
"""The identity block is the block that has no conv layer at shortcut.
|
||||
|
||||
Arguments:
|
||||
input_tensor: input tensor
|
||||
kernel_size: default 3, the kernel size of
|
||||
middle conv layer at main path
|
||||
filters: list of integers, the filters of 3 conv layer at main path
|
||||
stage: integer, current stage label, used for generating layer names
|
||||
block: current block label, used for generating layer names
|
||||
training: Only used if training keras model with Estimator. In other
|
||||
scenarios it is handled automatically.
|
||||
|
||||
Returns:
|
||||
Output tensor for the block.
|
||||
"""
|
||||
filters1, filters2 = filters
|
||||
if backend.image_data_format() == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
|
||||
x = layers.Conv2D(filters1, kernel_size,
|
||||
padding='same', use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name=conv_name_base + '2a')(input_tensor)
|
||||
x = layers.BatchNormalization(
|
||||
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
|
||||
name=bn_name_base + '2a')(x, training=training)
|
||||
x = layers.Activation('relu')(x)
|
||||
|
||||
x = layers.Conv2D(filters2, kernel_size,
|
||||
padding='same', use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name=conv_name_base + '2b')(x)
|
||||
x = layers.BatchNormalization(
|
||||
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
|
||||
name=bn_name_base + '2b')(x, training=training)
|
||||
|
||||
x = layers.add([x, input_tensor])
|
||||
x = layers.Activation('relu')(x)
|
||||
return x
|
||||
|
||||
|
||||
def conv_building_block(input_tensor,
|
||||
kernel_size,
|
||||
filters,
|
||||
stage,
|
||||
block,
|
||||
strides=(2, 2),
|
||||
training=None):
|
||||
"""A block that has a conv layer at shortcut.
|
||||
|
||||
Arguments:
|
||||
input_tensor: input tensor
|
||||
kernel_size: default 3, the kernel size of
|
||||
middle conv layer at main path
|
||||
filters: list of integers, the filters of 3 conv layer at main path
|
||||
stage: integer, current stage label, used for generating layer names
|
||||
block: current block label, used for generating layer names
|
||||
strides: Strides for the first conv layer in the block.
|
||||
training: Only used if training keras model with Estimator. In other
|
||||
scenarios it is handled automatically.
|
||||
|
||||
Returns:
|
||||
Output tensor for the block.
|
||||
|
||||
Note that from stage 3,
|
||||
the first conv layer at main path is with strides=(2, 2)
|
||||
And the shortcut should have strides=(2, 2) as well
|
||||
"""
|
||||
filters1, filters2 = filters
|
||||
if tf.keras.backend.image_data_format() == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
|
||||
x = layers.Conv2D(filters1, kernel_size, strides=strides,
|
||||
padding='same', use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name=conv_name_base + '2a')(input_tensor)
|
||||
x = layers.BatchNormalization(
|
||||
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
|
||||
name=bn_name_base + '2a')(x, training=training)
|
||||
x = layers.Activation('relu')(x)
|
||||
|
||||
x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name=conv_name_base + '2b')(x)
|
||||
x = layers.BatchNormalization(
|
||||
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
|
||||
name=bn_name_base + '2b')(x, training=training)
|
||||
|
||||
shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name=conv_name_base + '1')(input_tensor)
|
||||
shortcut = layers.BatchNormalization(
|
||||
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
|
||||
name=bn_name_base + '1')(shortcut, training=training)
|
||||
|
||||
x = layers.add([x, shortcut])
|
||||
x = layers.Activation('relu')(x)
|
||||
return x
|
||||
|
||||
|
||||
def resnet_block(input_tensor,
|
||||
size,
|
||||
kernel_size,
|
||||
filters,
|
||||
stage,
|
||||
conv_strides=(2, 2),
|
||||
training=None):
|
||||
"""A block which applies conv followed by multiple identity blocks.
|
||||
|
||||
Arguments:
|
||||
input_tensor: input tensor
|
||||
size: integer, number of constituent conv/identity building blocks.
|
||||
A conv block is applied once, followed by (size - 1) identity blocks.
|
||||
kernel_size: default 3, the kernel size of
|
||||
middle conv layer at main path
|
||||
filters: list of integers, the filters of 3 conv layer at main path
|
||||
stage: integer, current stage label, used for generating layer names
|
||||
conv_strides: Strides for the first conv layer in the block.
|
||||
training: Only used if training keras model with Estimator. In other
|
||||
scenarios it is handled automatically.
|
||||
|
||||
Returns:
|
||||
Output tensor after applying conv and identity blocks.
|
||||
"""
|
||||
|
||||
x = conv_building_block(input_tensor, kernel_size, filters, stage=stage,
|
||||
strides=conv_strides, block='block_0',
|
||||
training=training)
|
||||
for i in range(size - 1):
|
||||
x = identity_building_block(x, kernel_size, filters, stage=stage,
|
||||
block='block_%d' % (i + 1), training=training)
|
||||
return x
|
||||
|
||||
|
||||
def resnet(num_blocks, classes=10, training=None):
|
||||
"""Instantiates the ResNet architecture.
|
||||
|
||||
Arguments:
|
||||
num_blocks: integer, the number of conv/identity blocks in each block.
|
||||
The ResNet contains 3 blocks with each block containing one conv block
|
||||
followed by (layers_per_block - 1) number of idenity blocks. Each
|
||||
conv/idenity block has 2 convolutional layers. With the input
|
||||
convolutional layer and the pooling layer towards the end, this brings
|
||||
the total size of the network to (6*num_blocks + 2)
|
||||
classes: optional number of classes to classify images into
|
||||
training: Only used if training keras model with Estimator. In other
|
||||
scenarios it is handled automatically.
|
||||
|
||||
Returns:
|
||||
A Keras model instance.
|
||||
"""
|
||||
|
||||
input_shape = (32, 32, 3)
|
||||
img_input = layers.Input(shape=input_shape)
|
||||
|
||||
if backend.image_data_format() == 'channels_first':
|
||||
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
|
||||
name='transpose')(img_input)
|
||||
bn_axis = 1
|
||||
else: # channel_last
|
||||
x = img_input
|
||||
bn_axis = 3
|
||||
|
||||
x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
|
||||
x = layers.Conv2D(16, (3, 3),
|
||||
strides=(1, 1),
|
||||
padding='valid', use_bias=False,
|
||||
kernel_initializer='he_normal',
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name='conv1')(x)
|
||||
x = layers.BatchNormalization(axis=bn_axis,
|
||||
momentum=BATCH_NORM_DECAY,
|
||||
epsilon=BATCH_NORM_EPSILON,
|
||||
name='bn_conv1',)(x, training=training)
|
||||
x = layers.Activation('relu')(x)
|
||||
|
||||
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
|
||||
stage=2, conv_strides=(1, 1), training=training)
|
||||
|
||||
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32],
|
||||
stage=3, conv_strides=(2, 2), training=training)
|
||||
|
||||
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
|
||||
stage=4, conv_strides=(2, 2), training=training)
|
||||
|
||||
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
|
||||
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
|
||||
x = layers.Dense(classes,
|
||||
activation='softmax',
|
||||
kernel_initializer=initializers.RandomNormal(stddev=0.01),
|
||||
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
|
||||
name='fc10')(x)
|
||||
|
||||
inputs = img_input
|
||||
# Create model.
|
||||
model = tf.keras.models.Model(inputs, x, name='resnet56')
|
||||
|
||||
return model
|
||||
|
||||
|
||||
resnet20 = functools.partial(resnet, num_blocks=3)
|
||||
resnet32 = functools.partial(resnet, num_blocks=5)
|
||||
resnet56 = functools.partial(resnet, num_blocks=9)
|
||||
resnet10 = functools.partial(resnet, num_blocks=110)
|
||||
+187
@@ -0,0 +1,187 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test the keras ResNet model with Cifar data."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.platform import googletest
|
||||
from official.benchmark.models import resnet_cifar_main
|
||||
from official.utils.misc import keras_utils
|
||||
from official.utils.testing import integration
|
||||
from official.vision.image_classification.resnet import cifar_preprocessing
|
||||
|
||||
|
||||
class KerasCifarTest(googletest.TestCase):
|
||||
"""Unit tests for Keras ResNet with Cifar."""
|
||||
|
||||
_extra_flags = [
|
||||
"-batch_size", "4",
|
||||
"-train_steps", "1",
|
||||
"-use_synthetic_data", "true"
|
||||
]
|
||||
_tempdir = None
|
||||
|
||||
def get_temp_dir(self):
|
||||
if not self._tempdir:
|
||||
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
|
||||
return self._tempdir
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls): # pylint: disable=invalid-name
|
||||
super(KerasCifarTest, cls).setUpClass()
|
||||
resnet_cifar_main.define_cifar_flags()
|
||||
|
||||
def setUp(self):
|
||||
super(KerasCifarTest, self).setUp()
|
||||
cifar_preprocessing.NUM_IMAGES["validation"] = 4
|
||||
|
||||
def tearDown(self):
|
||||
super(KerasCifarTest, self).tearDown()
|
||||
tf.io.gfile.rmtree(self.get_temp_dir())
|
||||
|
||||
def test_end_to_end_no_dist_strat(self):
|
||||
"""Test Keras model with 1 GPU, no distribution strategy."""
|
||||
config = keras_utils.get_config_proto_v1()
|
||||
tf.compat.v1.enable_eager_execution(config=config)
|
||||
|
||||
extra_flags = [
|
||||
"-distribution_strategy", "off",
|
||||
"-model_dir", "keras_cifar_no_dist_strat",
|
||||
"-data_format", "channels_last",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
def test_end_to_end_graph_no_dist_strat(self):
|
||||
"""Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
|
||||
extra_flags = [
|
||||
"-enable_eager", "false",
|
||||
"-distribution_strategy", "off",
|
||||
"-model_dir", "keras_cifar_graph_no_dist_strat",
|
||||
"-data_format", "channels_last",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
def test_end_to_end_1_gpu(self):
|
||||
"""Test Keras model with 1 GPU."""
|
||||
config = keras_utils.get_config_proto_v1()
|
||||
tf.compat.v1.enable_eager_execution(config=config)
|
||||
|
||||
if context.num_gpus() < 1:
|
||||
self.skipTest(
|
||||
"{} GPUs are not available for this test. {} GPUs are available".
|
||||
format(1, context.num_gpus()))
|
||||
|
||||
extra_flags = [
|
||||
"-num_gpus", "1",
|
||||
"-distribution_strategy", "mirrored",
|
||||
"-model_dir", "keras_cifar_1_gpu",
|
||||
"-data_format", "channels_last",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
def test_end_to_end_graph_1_gpu(self):
|
||||
"""Test Keras model in legacy graph mode with 1 GPU."""
|
||||
if context.num_gpus() < 1:
|
||||
self.skipTest(
|
||||
"{} GPUs are not available for this test. {} GPUs are available".
|
||||
format(1, context.num_gpus()))
|
||||
|
||||
extra_flags = [
|
||||
"-num_gpus", "1",
|
||||
"-noenable_eager",
|
||||
"-distribution_strategy", "mirrored",
|
||||
"-model_dir", "keras_cifar_graph_1_gpu",
|
||||
"-data_format", "channels_last",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
def test_end_to_end_2_gpu(self):
|
||||
"""Test Keras model with 2 GPUs."""
|
||||
config = keras_utils.get_config_proto_v1()
|
||||
tf.compat.v1.enable_eager_execution(config=config)
|
||||
|
||||
if context.num_gpus() < 2:
|
||||
self.skipTest(
|
||||
"{} GPUs are not available for this test. {} GPUs are available".
|
||||
format(2, context.num_gpus()))
|
||||
|
||||
extra_flags = [
|
||||
"-num_gpus", "2",
|
||||
"-distribution_strategy", "mirrored",
|
||||
"-model_dir", "keras_cifar_2_gpu",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
def test_end_to_end_graph_2_gpu(self):
|
||||
"""Test Keras model in legacy graph mode with 2 GPUs."""
|
||||
if context.num_gpus() < 2:
|
||||
self.skipTest(
|
||||
"{} GPUs are not available for this test. {} GPUs are available".
|
||||
format(2, context.num_gpus()))
|
||||
|
||||
extra_flags = [
|
||||
"-num_gpus", "2",
|
||||
"-enable_eager", "false",
|
||||
"-distribution_strategy", "mirrored",
|
||||
"-model_dir", "keras_cifar_graph_2_gpu",
|
||||
]
|
||||
extra_flags = extra_flags + self._extra_flags
|
||||
|
||||
integration.run_synthetic(
|
||||
main=resnet_cifar_main.run,
|
||||
tmp_root=self.get_temp_dir(),
|
||||
extra_flags=extra_flags
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
+31
@@ -0,0 +1,31 @@
|
||||
# Shakespeare character LSTM model
|
||||
|
||||
This is an implemention of a simple character LSTM used to generate text.
|
||||
|
||||
## Instructions
|
||||
|
||||
First download the source data:
|
||||
|
||||
```
|
||||
wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
|
||||
```
|
||||
|
||||
Note that files other than shakepeare.txt can also be used to train the model to generater other text.
|
||||
|
||||
Then train the model:
|
||||
|
||||
```python
|
||||
python3 shakespeare_main.py --training_data shakespeare.txt \
|
||||
--model_dir /tmp/shakespeare
|
||||
```
|
||||
|
||||
This will place model checkpoints in `/tmp/shakespeare`, so that we can use them to make predictions.
|
||||
|
||||
Then generate predictions:
|
||||
|
||||
```python
|
||||
python3 shakespeare_main.py --training_data shakespeare.txt \
|
||||
--model_dir /tmp/shakespeare --notrain --predict_context=ROMEO:
|
||||
```
|
||||
|
||||
Change `--predict_context` and `--predict_length` to suit your needs.
|
||||
+1
@@ -0,0 +1 @@
|
||||
|
||||
+316
@@ -0,0 +1,316 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Runs a character LSTM model trained on Shakespeare."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
# pylint: disable=wrong-import-order
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# pylint: enable=wrong-import-order
|
||||
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
|
||||
EMBEDDING_DIM = 256
|
||||
RNN_UNITS = 1024
|
||||
SEQ_LENGTH = 100
|
||||
# Calculated by running batch_size=1
|
||||
BATCHES_PER_EPOCH = 11043
|
||||
|
||||
|
||||
def define_flags():
|
||||
"""Define the flags for the Shakespeare character LSTM."""
|
||||
flags_core.define_base(data_dir=False,
|
||||
clean=False,
|
||||
train_epochs=True,
|
||||
epochs_between_evals=False,
|
||||
stop_threshold=False,
|
||||
num_gpu=True,
|
||||
hooks=False,
|
||||
export_dir=False,
|
||||
run_eagerly=True,
|
||||
distribution_strategy=True)
|
||||
|
||||
flags_core.define_performance(num_parallel_calls=False,
|
||||
inter_op=False,
|
||||
intra_op=False,
|
||||
synthetic_data=False,
|
||||
max_train_steps=False,
|
||||
dtype=True,
|
||||
loss_scale=True,
|
||||
enable_xla=True)
|
||||
|
||||
flags_core.set_defaults(train_epochs=43,
|
||||
batch_size=64)
|
||||
|
||||
flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
|
||||
flags.DEFINE_boolean(
|
||||
name='train', default=True,
|
||||
help='If true trains the model.')
|
||||
flags.DEFINE_string(
|
||||
name='predict_context', default=None,
|
||||
help='If set, makes a prediction with the given context.')
|
||||
flags.DEFINE_integer(
|
||||
name='predict_length', default=1000,
|
||||
help='Length of the predicted text including the context.')
|
||||
flags.DEFINE_integer(name='train_steps', default=None,
|
||||
help='Overrides train_steps per epoch if not None.')
|
||||
flags.DEFINE_integer(
|
||||
name='log_steps', default=100,
|
||||
help='For every log_steps, we log the timing information such as '
|
||||
'examples per second.')
|
||||
flags.DEFINE_string(
|
||||
name='training_data', default=None,
|
||||
help='Path to file containing the training data.')
|
||||
flags.DEFINE_boolean(name='cudnn', default=True, help='Use CuDNN LSTM.')
|
||||
|
||||
|
||||
def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
|
||||
"""Creates a dataset from a given text file.
|
||||
|
||||
Args:
|
||||
path_to_file: The path to the training data.
|
||||
batch_size: Batch size to use.
|
||||
seq_length: The length of the LSTM sequence.
|
||||
|
||||
Returns:
|
||||
A tuple, consisting of the Dataset and the class to character mapping
|
||||
and character to class mapping.
|
||||
"""
|
||||
with tf.io.gfile.GFile(path_to_file, 'rb') as train_data:
|
||||
text = train_data.read().decode(encoding='utf-8')
|
||||
|
||||
# Create vocab
|
||||
vocab = sorted(set(text))
|
||||
char2idx = {u: i for i, u in enumerate(vocab)}
|
||||
idx2char = np.array(vocab)
|
||||
|
||||
# Split text into sequence length + 1 chucks to create examples
|
||||
text_as_int = np.array([char2idx[c] for c in text])
|
||||
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
|
||||
sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
|
||||
|
||||
def split_input_target(chunk):
|
||||
input_text = chunk[:-1]
|
||||
target_text = chunk[1:]
|
||||
return input_text, tf.one_hot(target_text, len(vocab))
|
||||
dataset = sequences.map(split_input_target)
|
||||
dataset = dataset.shuffle(10000).repeat()
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
|
||||
return dataset, idx2char, char2idx
|
||||
|
||||
|
||||
def build_model(vocab_size,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
rnn_units=RNN_UNITS,
|
||||
batch_size=None,
|
||||
stateful=False,
|
||||
use_cudnn=True):
|
||||
"""Builds the Shakespeare model.
|
||||
|
||||
Args:
|
||||
vocab_size: The number of character classes in the input.
|
||||
embedding_dim: The dimension of the embedding space for each class.
|
||||
rnn_units: The number of RNN units in the layer.
|
||||
batch_size: When predicting, the batch size of the predictions.
|
||||
stateful: If true, the LSTM is stateful.
|
||||
|
||||
Returns:
|
||||
A Keras Model.
|
||||
"""
|
||||
assert keras_utils.is_v2_0()
|
||||
LSTM = functools.partial(tf.keras.layers.LSTM, implementation=2)
|
||||
|
||||
# By indirecting the activation through a lambda layer, the logic to dispatch
|
||||
# to CuDNN in V2 doesn't trigger and we force the LSTM to run in non-CuDNN
|
||||
# mode.
|
||||
lstm_activation = ('tanh' if use_cudnn else
|
||||
lambda x: tf.math.tanh(x))
|
||||
|
||||
batch_shape = [batch_size if stateful else None, None]
|
||||
return tf.keras.Sequential([
|
||||
tf.keras.layers.Embedding(vocab_size, embedding_dim,
|
||||
batch_input_shape=batch_shape),
|
||||
LSTM(rnn_units,
|
||||
activation=lstm_activation,
|
||||
return_sequences=True,
|
||||
stateful=stateful,
|
||||
recurrent_initializer='glorot_uniform'),
|
||||
tf.keras.layers.Dense(vocab_size),
|
||||
tf.keras.layers.Softmax(dtype=tf.float32)])
|
||||
|
||||
|
||||
def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
|
||||
"""Trains a Shakespeare model.
|
||||
|
||||
Args:
|
||||
flags_obj: An object containing parsed flag values.s
|
||||
dataset: the training data set.
|
||||
vocab_size: the number of unique character classes.
|
||||
strategy: distribution strategy to use.
|
||||
checkpoint_dir: if not None, the directory in which to make checkpoints.
|
||||
|
||||
Returns:
|
||||
The training history and callbacks.
|
||||
"""
|
||||
if flags_obj.train_steps:
|
||||
train_steps = flags_obj.train_steps
|
||||
else:
|
||||
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
|
||||
strategy_scope = distribution_utils.get_strategy_scope(strategy)
|
||||
|
||||
with strategy_scope:
|
||||
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
|
||||
use_cudnn=flags_obj.cudnn)
|
||||
|
||||
# When keras_use_ctl is False, Model.fit() automatically applies
|
||||
# loss scaling so we don't need to create a LossScaleOptimizer.
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
loss=tf.keras.losses.CategoricalCrossentropy(),
|
||||
metrics=[tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
|
||||
tf.keras.metrics.Recall(top_k=5, name='RecallAt5')],
|
||||
run_eagerly=flags_obj.run_eagerly)
|
||||
|
||||
callbacks = []
|
||||
if checkpoint_dir:
|
||||
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=checkpoint_prefix,
|
||||
save_weights_only=True)
|
||||
callbacks.append(checkpoint_callback)
|
||||
time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
|
||||
flags_obj.log_steps)
|
||||
callbacks.append(time_callback)
|
||||
history = model.fit(dataset,
|
||||
epochs=flags_obj.train_epochs,
|
||||
steps_per_epoch=train_steps,
|
||||
callbacks=callbacks,
|
||||
verbose=2)
|
||||
return history, callbacks
|
||||
|
||||
|
||||
def make_prediction(checkpoint_dir, length, context, idx2char, char2idx):
|
||||
"""Make predictions from a Shakespeare model.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: the directory from which to load checkpoints
|
||||
length: the total length of the generated text (including the context).
|
||||
context: the initial text with which the LSTM is primed.
|
||||
idx2char: the character class to character mapping.
|
||||
char2idx: the character to character class mapping.
|
||||
|
||||
Returns:
|
||||
A generated string of text of the given length.
|
||||
"""
|
||||
prediction_model = build_model(
|
||||
vocab_size=len(idx2char), batch_size=1, stateful=True)
|
||||
prediction_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
|
||||
prediction_model.build(tf.TensorShape([1, None]))
|
||||
|
||||
input_eval = [char2idx[s] for s in context]
|
||||
input_eval = tf.expand_dims(input_eval, 0)
|
||||
|
||||
text_generated = []
|
||||
|
||||
prediction_model.reset_states()
|
||||
for _ in range(length - len(context)):
|
||||
predictions = prediction_model(input_eval)
|
||||
predictions = tf.squeeze(predictions, 0)
|
||||
|
||||
# We applied a softmax to the output of the model so that
|
||||
# tf.keras.metrics.Recall would work. We need logits for
|
||||
# tf.random.categorical, so we convert the probabilities back to log odds
|
||||
predictions = tf.math.log(predictions / (1 - predictions))
|
||||
|
||||
random_output = tf.random.categorical(predictions, num_samples=1)
|
||||
selected_id = random_output[-1, 0].numpy()
|
||||
input_eval = tf.expand_dims([selected_id], 0)
|
||||
text_generated.append(idx2char[selected_id])
|
||||
|
||||
return context + ''.join(text_generated)
|
||||
|
||||
|
||||
def run(flags_obj):
|
||||
"""Run Shakespeare training and predict.
|
||||
|
||||
Args:
|
||||
flags_obj: An object containing parsed flag values.
|
||||
|
||||
Returns:
|
||||
Dictionary with status from the run.
|
||||
"""
|
||||
if not flags_obj.training_data:
|
||||
raise ValueError(
|
||||
'Must set the path to a training data file. e.g download the following '
|
||||
'https://storage.googleapis.com/download.tensorflow.org/data/'
|
||||
'shakespeare.txt')
|
||||
|
||||
if flags_obj.dtype == 'fp16':
|
||||
policy = tf.keras.mixed_precision.experimental.Policy(
|
||||
'mixed_float16',
|
||||
loss_scale=flags_core.get_loss_scale(flags_obj,
|
||||
default_for_fp16='dynamic'))
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
|
||||
keras_utils.set_session_config(
|
||||
enable_eager=flags_obj.enable_eager,
|
||||
enable_xla=flags_obj.enable_xla)
|
||||
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=flags_obj.distribution_strategy,
|
||||
num_gpus=flags_obj.num_gpus)
|
||||
|
||||
dataset, idx2char, char2idx = get_dataset(flags_obj.training_data,
|
||||
batch_size=flags_obj.batch_size)
|
||||
stats = {}
|
||||
if flags_obj.train:
|
||||
history, callbacks = train_model(flags_obj, dataset,
|
||||
len(idx2char), strategy,
|
||||
checkpoint_dir=flags_obj.model_dir)
|
||||
|
||||
stats['history'] = history.history
|
||||
stats['callbacks'] = callbacks
|
||||
|
||||
if flags_obj.predict_context:
|
||||
if not flags_obj.model_dir:
|
||||
raise ValueError('Must set model_dir to get predictions.')
|
||||
print(make_prediction(flags_obj.model_dir,
|
||||
flags_obj.predict_length,
|
||||
flags_obj.predict_context,
|
||||
idx2char,
|
||||
char2idx))
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def main(_):
|
||||
flags_obj = flags.FLAGS
|
||||
run(flags_obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
define_flags()
|
||||
app.run(main)
|
||||
+129
@@ -0,0 +1,129 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper functions to generate data directly on devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# The `SyntheticDataset` is a temporary solution for generating synthetic data
|
||||
# directly on devices. It is only useful for Keras with Distribution
|
||||
# Strategies. We will have better support in `tf.data` or Distribution Strategy
|
||||
# later.
|
||||
class SyntheticDataset(object):
|
||||
"""A dataset that generates synthetic data on each device."""
|
||||
|
||||
def __init__(self, dataset, split_by=1):
|
||||
# dataset.take(1) doesn't have GPU kernel.
|
||||
with tf.device('device:CPU:0'):
|
||||
tensor = tf.data.experimental.get_single_element(dataset.take(1))
|
||||
flat_tensor = tf.nest.flatten(tensor)
|
||||
variable_data = []
|
||||
initializers = []
|
||||
for t in flat_tensor:
|
||||
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
|
||||
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
|
||||
v = tf.compat.v1.get_local_variable(self._random_name(),
|
||||
initializer=rebatched_t)
|
||||
variable_data.append(v)
|
||||
initializers.append(v.initializer)
|
||||
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
|
||||
self._iterator = SyntheticIterator(input_data, initializers)
|
||||
|
||||
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
|
||||
return ''.join(random.choice(chars) for _ in range(size))
|
||||
|
||||
def __iter__(self):
|
||||
return self._iterator
|
||||
|
||||
def make_one_shot_iterator(self):
|
||||
return self._iterator
|
||||
|
||||
def make_initializable_iterator(self):
|
||||
return self._iterator
|
||||
|
||||
|
||||
class SyntheticIterator(object):
|
||||
"""A dataset that generates synthetic data on each device."""
|
||||
|
||||
def __init__(self, input_data, initializers):
|
||||
self._input_data = input_data
|
||||
self._initializers = initializers
|
||||
|
||||
def get_next(self):
|
||||
return self._input_data
|
||||
|
||||
def next(self):
|
||||
return self.__next__()
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_next()
|
||||
except tf.errors.OutOfRangeError:
|
||||
raise StopIteration
|
||||
|
||||
def initialize(self):
|
||||
if tf.executing_eagerly():
|
||||
return tf.no_op()
|
||||
else:
|
||||
return self._initializers
|
||||
|
||||
|
||||
def _monkey_patch_dataset_method(strategy):
|
||||
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
|
||||
def make_dataset(self, dataset):
|
||||
logging.info('Using pure synthetic data.')
|
||||
with self.scope():
|
||||
if self.extended._global_batch_size: # pylint: disable=protected-access
|
||||
return SyntheticDataset(dataset, self.num_replicas_in_sync)
|
||||
else:
|
||||
return SyntheticDataset(dataset)
|
||||
|
||||
def make_iterator(self, dataset):
|
||||
dist_dataset = make_dataset(self, dataset)
|
||||
return iter(dist_dataset)
|
||||
|
||||
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
|
||||
strategy.make_dataset_iterator = make_iterator
|
||||
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
|
||||
strategy.experimental_distribute_dataset = make_dataset
|
||||
|
||||
|
||||
def _undo_monkey_patch_dataset_method(strategy):
|
||||
if hasattr(strategy, 'orig_make_dataset_iterator'):
|
||||
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
|
||||
if hasattr(strategy, 'orig_distribute_dataset'):
|
||||
strategy.make_dataset_iterator = strategy.orig_distribute_dataset
|
||||
|
||||
|
||||
def set_up_synthetic_data():
|
||||
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
|
||||
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
|
||||
_monkey_patch_dataset_method(
|
||||
tf.distribute.experimental.MultiWorkerMirroredStrategy)
|
||||
|
||||
|
||||
def undo_set_up_synthetic_data():
|
||||
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
|
||||
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
|
||||
_undo_monkey_patch_dataset_method(
|
||||
tf.distribute.experimental.MultiWorkerMirroredStrategy)
|
||||
+457
@@ -0,0 +1,457 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes Keras benchmarks and accuracy tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.recommendation import ncf_common
|
||||
from official.recommendation import ncf_keras_main
|
||||
from official.utils.flags import core
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
NCF_DATA_DIR_NAME = 'movielens_data'
|
||||
NCF_TF_DATA_1M_BATCH_DIR_NAME = 'gs://tf-perfzero-data/movielens_data/ncf_8gpu_1M_batch'
|
||||
|
||||
|
||||
class NCFKerasBenchmarkBase(tf.test.Benchmark):
|
||||
"""Base class for NCF model benchmark."""
|
||||
local_flags = None
|
||||
|
||||
def __init__(self,
|
||||
output_dir=None,
|
||||
default_flags=None,
|
||||
**kwargs):
|
||||
self.output_dir = output_dir
|
||||
self.default_flags = default_flags or {}
|
||||
# Run all benchmarks with ml_perf flag.
|
||||
self.default_flags['ml_perf'] = True
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up and resets flags before each test."""
|
||||
logging.set_verbosity(logging.INFO)
|
||||
if NCFKerasBenchmarkBase.local_flags is None:
|
||||
ncf_common.define_ncf_flags()
|
||||
# Loads flags to get defaults to then override. List cannot be empty.
|
||||
flags.FLAGS(['foo'])
|
||||
core.set_defaults(**self.default_flags)
|
||||
saved_flag_values = flagsaver.save_flag_values()
|
||||
NCFKerasBenchmarkBase.local_flags = saved_flag_values
|
||||
else:
|
||||
flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
|
||||
start_time_sec = time.time()
|
||||
stats = ncf_keras_main.run_ncf(FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
metrics = []
|
||||
metrics.append({'name': 'exp_per_second',
|
||||
'value': stats['avg_exp_per_second']})
|
||||
|
||||
if hr_at_10_min > 0:
|
||||
metrics.append({'name': 'hr_at_10',
|
||||
'value': stats['eval_hit_rate'],
|
||||
'min_value': hr_at_10_min,
|
||||
'max_value': hr_at_10_max})
|
||||
|
||||
metrics.append({'name': 'train_loss',
|
||||
'value': stats['loss']})
|
||||
|
||||
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
|
||||
|
||||
|
||||
class NCFKerasAccuracy(NCFKerasBenchmarkBase):
|
||||
"""Benchmark NCF model using real data."""
|
||||
|
||||
def __init__(self,
|
||||
output_dir=None,
|
||||
root_data_dir=None,
|
||||
default_flags=None,
|
||||
**kwargs):
|
||||
root_data_dir = root_data_dir if root_data_dir else ''
|
||||
default_flags = {}
|
||||
default_flags['dataset'] = 'ml-20m'
|
||||
default_flags['num_gpus'] = 1
|
||||
default_flags['train_epochs'] = 10
|
||||
default_flags['clean'] = True
|
||||
default_flags['batch_size'] = 99000
|
||||
default_flags['learning_rate'] = 0.00382059
|
||||
default_flags['beta1'] = 0.783529
|
||||
default_flags['beta2'] = 0.909003
|
||||
default_flags['epsilon'] = 1.45439e-07
|
||||
default_flags['layers'] = [256, 256, 128, 64]
|
||||
default_flags['num_factors'] = 64
|
||||
default_flags['hr_threshold'] = 0.635
|
||||
default_flags['ml_perf'] = True
|
||||
default_flags['use_synthetic_data'] = False
|
||||
default_flags['data_dir'] = os.path.join(root_data_dir, NCF_DATA_DIR_NAME)
|
||||
|
||||
super(NCFKerasAccuracy, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
**kwargs)
|
||||
|
||||
def _run_and_report_benchmark_mlperf_like(self):
|
||||
"""Run test and report results.
|
||||
|
||||
Note: MLPerf like tests are not tuned to hit a specific hr@10 value, but
|
||||
we want it recorded.
|
||||
"""
|
||||
self._run_and_report_benchmark(hr_at_10_min=0.61)
|
||||
|
||||
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
|
||||
"""Run test and report results.
|
||||
|
||||
Note: Target is 0.635, but some runs are below that level. Until we have
|
||||
multi-run tests, we have to accept a lower target.
|
||||
|
||||
Args:
|
||||
hr_at_10_min: Minimum acceptable hr@10 value.
|
||||
hr_at_10_max: Maximum acceptable hr@10 value.
|
||||
"""
|
||||
super(NCFKerasAccuracy, self)._run_and_report_benchmark(
|
||||
hr_at_10_min=hr_at_10_min,
|
||||
hr_at_10_max=hr_at_10_max)
|
||||
|
||||
def benchmark_1_gpu_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.early_stopping = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.early_stopping = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_run_eagerly_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.run_eagerly = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_ctl_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.early_stopping = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.run_eagerly = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu_ctl_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpus_early_stop(self):
|
||||
self._setup()
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.eval_batch_size = 160000
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpus_ctl_early_stop(self):
|
||||
"""NCF with custom training loop. Works only in TF 2.0."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.early_stopping = True
|
||||
FLAGS.num_gpus = 2
|
||||
FLAGS.eval_batch_size = 160000
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
#############################################
|
||||
# Tests below with mlperf in the test name are of two types:
|
||||
# 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
|
||||
# 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
|
||||
#
|
||||
# The purpose of both is to get a number to compare to existing results. To do
|
||||
# this the number of epochs is held constant rather than a race to a given
|
||||
# accuracy. The accuracy validation is done by the "early_stop" tests.
|
||||
#############################################
|
||||
|
||||
def benchmark_1_gpu_mlperf_like(self):
|
||||
"""1 GPU using keras fit/compile."""
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_mlperf_like(self):
|
||||
"""1 GPU using compile/fit without dist_strat."""
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_run_eagerly_mlperf_like(self):
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.run_eagerly = True
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_xla_1_gpu_mlperf_like(self):
|
||||
"""1 GPU using compile/fit with XLA."""
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_ctl_mlperf_like(self):
|
||||
"""1 GPU using CTL."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.train_epochs = 7
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
|
||||
"""1 GPU using CTL and FP16."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_fp16_mlperf_like(self):
|
||||
"""1 GPU using FP16."""
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_ctl_fp16_graph_rewrite_mlperf_like(self):
|
||||
"""1 GPU using CTL and FP16 graph rewrite."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_fp16_graph_rewrite_mlperf_like(self):
|
||||
"""1 GPU using FP16 graph rewrite."""
|
||||
self._setup()
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
|
||||
"""1 GPU using CTL with eager and distribution strategy."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.train_epochs = 7
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu_ctl_mlperf_like(self):
|
||||
"""1 GPU using CTL with XLA."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.train_epochs = 7
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_xla_1_gpu_fp16_mlperf_like(self):
|
||||
"""1 GPU using with XLA and FP16."""
|
||||
self._setup()
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
|
||||
"""1 GPU using CTL with XLA and FP16."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.train_epochs = 7
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_mlperf_like(self):
|
||||
"""8 GPU using keras fit/compile."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 160000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_ctl_mlperf_like(self):
|
||||
"""8 GPU using CTL."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 160000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_tf_data_ctl_mlperf_like(self):
|
||||
"""8 GPU using CTL."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 1048000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
|
||||
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
|
||||
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_tf_data_fp16_mlperf_like(self):
|
||||
"""8 GPU FP16"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 1048000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
|
||||
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
|
||||
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_tf_data_ctl_fp16_mlperf_like(self):
|
||||
"""8 GPU FP16 using CTL"""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 1048000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.loss_scale = 8192
|
||||
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
|
||||
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
|
||||
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
def benchmark_8_gpu_tf_data_ctl_fp16_graph_rewrite_mlperf_like(self):
|
||||
"""8 GPU FP16 graph rewrite using CTL."""
|
||||
self._setup()
|
||||
FLAGS.keras_use_ctl = True
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.train_epochs = 17
|
||||
FLAGS.batch_size = 1048576
|
||||
FLAGS.eval_batch_size = 1048000
|
||||
FLAGS.learning_rate = 0.0045
|
||||
FLAGS.beta1 = 0.25
|
||||
FLAGS.beta2 = 0.5
|
||||
FLAGS.epsilon = 1e-8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.loss_scale = 8192
|
||||
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
|
||||
'training_cycle_*/*')
|
||||
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
|
||||
'eval_data/*')
|
||||
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
|
||||
'meta_data.json')
|
||||
self._run_and_report_benchmark_mlperf_like()
|
||||
|
||||
|
||||
class NCFKerasSynth(NCFKerasBenchmarkBase):
|
||||
"""Benchmark NCF model using synthetic data."""
|
||||
|
||||
def __init__(self,
|
||||
output_dir=None,
|
||||
default_flags=None,
|
||||
**kwargs):
|
||||
|
||||
default_flags = {}
|
||||
default_flags['dataset'] = 'ml-20m'
|
||||
default_flags['num_gpus'] = 1
|
||||
default_flags['train_epochs'] = 8
|
||||
default_flags['batch_size'] = 99000
|
||||
default_flags['eval_batch_size'] = 160000
|
||||
default_flags['learning_rate'] = 0.00382059
|
||||
default_flags['beta1'] = 0.783529
|
||||
default_flags['beta2'] = 0.909003
|
||||
default_flags['epsilon'] = 1.45439e-07
|
||||
default_flags['layers'] = [256, 256, 128, 64]
|
||||
default_flags['num_factors'] = 64
|
||||
default_flags['hr_threshold'] = 0.635
|
||||
default_flags['use_synthetic_data'] = True
|
||||
|
||||
super(NCFKerasSynth, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
**kwargs)
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
self._setup()
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_2_gpus(self):
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 2
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+91
@@ -0,0 +1,91 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils for creating PerfZero benchmarks."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class PerfZeroBenchmark(tf.test.Benchmark):
|
||||
"""Common methods used in PerfZero Benchmarks.
|
||||
|
||||
Handles the resetting of flags between tests, loading of default_flags,
|
||||
overriding of defaults. PerfZero (OSS) runs each test in a separate
|
||||
process reducing some need to reset the flags.
|
||||
"""
|
||||
local_flags = None
|
||||
|
||||
def __init__(self,
|
||||
output_dir=None,
|
||||
default_flags=None,
|
||||
flag_methods=None,
|
||||
tpu=None):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
output_dir: Base directory to store all output for the test.
|
||||
default_flags: Set of flags to pass to model.
|
||||
flag_methods: Set of flag methods to run during setup.
|
||||
tpu: (optional) TPU name to use in a TPU benchmark.
|
||||
"""
|
||||
if os.getenv('BENCHMARK_OUTPUT_DIR'):
|
||||
self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
|
||||
elif output_dir:
|
||||
self.output_dir = output_dir
|
||||
else:
|
||||
self.output_dir = '/tmp'
|
||||
self.default_flags = default_flags or {}
|
||||
self.flag_methods = flag_methods or {}
|
||||
|
||||
if os.getenv('BENCHMARK_TPU'):
|
||||
resolved_tpu = os.getenv('BENCHMARK_TPU')
|
||||
elif tpu:
|
||||
resolved_tpu = tpu
|
||||
else:
|
||||
resolved_tpu = None
|
||||
|
||||
if resolved_tpu:
|
||||
# TPU models are expected to accept a --tpu=name flag. PerfZero creates
|
||||
# the TPU at runtime and passes the TPU's name to this flag.
|
||||
self.default_flags['tpu'] = resolved_tpu
|
||||
|
||||
def _get_model_dir(self, folder_name):
|
||||
"""Returns directory to store info, e.g. saved model and event log."""
|
||||
return os.path.join(self.output_dir, folder_name)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up and resets flags before each test."""
|
||||
logging.set_verbosity(logging.INFO)
|
||||
if PerfZeroBenchmark.local_flags is None:
|
||||
for flag_method in self.flag_methods:
|
||||
flag_method()
|
||||
# Loads flags to get defaults to then override. List cannot be empty.
|
||||
flags.FLAGS(['foo'])
|
||||
# Overrides flag values with defaults for the class of tests.
|
||||
for k, v in self.default_flags.items():
|
||||
setattr(FLAGS, k, v)
|
||||
saved_flag_values = flagsaver.save_flag_values()
|
||||
PerfZeroBenchmark.local_flags = saved_flag_values
|
||||
else:
|
||||
flagsaver.restore_flag_values(PerfZeroBenchmark.local_flags)
|
||||
+412
@@ -0,0 +1,412 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes CTL benchmarks and accuracy tests."""
|
||||
# pylint: disable=line-too-long,g-bad-import-order
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
|
||||
from official.vision.image_classification.resnet import common
|
||||
from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.utils.flags import core as flags_core
|
||||
|
||||
MIN_TOP_1_ACCURACY = 0.76
|
||||
MAX_TOP_1_ACCURACY = 0.77
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class CtlBenchmark(PerfZeroBenchmark):
|
||||
"""Base benchmark class with methods to simplify testing."""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
|
||||
self.output_dir = output_dir
|
||||
self.default_flags = default_flags or {}
|
||||
self.flag_methods = flag_methods or {}
|
||||
super(CtlBenchmark, self).__init__(
|
||||
output_dir=self.output_dir,
|
||||
default_flags=self.default_flags,
|
||||
flag_methods=self.flag_methods)
|
||||
|
||||
def _report_benchmark(self,
|
||||
stats,
|
||||
wall_time_sec,
|
||||
top_1_max=None,
|
||||
top_1_min=None,
|
||||
total_batch_size=None,
|
||||
log_steps=None,
|
||||
warmup=1,
|
||||
start_time_sec=None):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Args:
|
||||
stats: dict returned from keras models with known entries.
|
||||
wall_time_sec: the during of the benchmark execution in seconds
|
||||
top_1_max: highest passing level for top_1 accuracy.
|
||||
top_1_min: lowest passing level for top_1 accuracy.
|
||||
total_batch_size: Global batch-size.
|
||||
log_steps: How often the log was created for stats['step_timestamp_log'].
|
||||
warmup: number of entries in stats['step_timestamp_log'] to ignore.
|
||||
start_time_sec: the start time of the program in seconds since epoch.
|
||||
"""
|
||||
|
||||
metrics = []
|
||||
if 'eval_acc' in stats:
|
||||
metrics.append({
|
||||
'name': 'accuracy_top_1',
|
||||
'value': stats['eval_acc'],
|
||||
'min_value': top_1_min,
|
||||
'max_value': top_1_max
|
||||
})
|
||||
metrics.append({'name': 'eval_loss', 'value': stats['eval_loss']})
|
||||
|
||||
metrics.append({
|
||||
'name': 'top_1_train_accuracy',
|
||||
'value': stats['train_acc']
|
||||
})
|
||||
metrics.append({'name': 'train_loss', 'value': stats['train_loss']})
|
||||
|
||||
if (warmup and 'step_timestamp_log' in stats and
|
||||
len(stats['step_timestamp_log']) > warmup + 1):
|
||||
# first entry in the time_log is start of step 0. The rest of the
|
||||
# entries are the end of each step recorded
|
||||
time_log = stats['step_timestamp_log']
|
||||
steps_elapsed = time_log[-1].batch_index - time_log[warmup].batch_index
|
||||
time_elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
|
||||
examples_per_sec = total_batch_size * (steps_elapsed / time_elapsed)
|
||||
metrics.append({'name': 'exp_per_second', 'value': examples_per_sec})
|
||||
|
||||
if 'avg_exp_per_second' in stats:
|
||||
metrics.append({
|
||||
'name': 'avg_exp_per_second',
|
||||
'value': stats['avg_exp_per_second']
|
||||
})
|
||||
|
||||
if start_time_sec and 'step_timestamp_log' in stats:
|
||||
time_log = stats['step_timestamp_log']
|
||||
# time_log[0] is recorded at the beginning of the first step.
|
||||
startup_time = time_log[0].timestamp - start_time_sec
|
||||
metrics.append({'name': 'startup_time', 'value': startup_time})
|
||||
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(
|
||||
iters=-1,
|
||||
wall_time=wall_time_sec,
|
||||
metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
|
||||
|
||||
class Resnet50CtlAccuracy(CtlBenchmark):
|
||||
"""Benchmark accuracy tests for ResNet50 in CTL."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
"""A benchmark class.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more named
|
||||
arguments before updating the constructor.
|
||||
"""
|
||||
|
||||
flag_methods = [common.define_keras_flags]
|
||||
|
||||
self.data_dir = os.path.join(root_data_dir, 'imagenet')
|
||||
super(Resnet50CtlAccuracy, self).__init__(
|
||||
output_dir=output_dir, flag_methods=flag_methods)
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Test Keras model with eager, dist_strat and 8 GPUs."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 128 * 8
|
||||
FLAGS.train_epochs = 90
|
||||
FLAGS.epochs_between_evals = 10
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
||||
FLAGS.dtype = 'fp32'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Test Keras model with eager, 8 GPUs with tf.keras mixed precision."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 256 * 8
|
||||
FLAGS.train_epochs = 90
|
||||
FLAGS.epochs_between_evals = 10
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_amp(self):
|
||||
"""Test Keras model with 8 GPUs and mixed precision via graph rewrite."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.data_dir
|
||||
FLAGS.batch_size = 256 * 8
|
||||
FLAGS.train_epochs = 90
|
||||
FLAGS.epochs_between_evals = 10
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self):
|
||||
start_time_sec = time.time()
|
||||
stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
super(Resnet50CtlAccuracy, self)._report_benchmark(
|
||||
stats,
|
||||
wall_time_sec,
|
||||
top_1_min=MIN_TOP_1_ACCURACY,
|
||||
top_1_max=MAX_TOP_1_ACCURACY,
|
||||
total_batch_size=FLAGS.batch_size,
|
||||
log_steps=100,
|
||||
start_time_sec=start_time_sec)
|
||||
|
||||
def _get_model_dir(self, folder_name):
|
||||
return os.path.join(self.output_dir, folder_name)
|
||||
|
||||
|
||||
class Resnet50CtlBenchmarkBase(CtlBenchmark):
|
||||
"""Resnet50 benchmarks."""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None):
|
||||
flag_methods = [common.define_keras_flags]
|
||||
|
||||
super(Resnet50CtlBenchmarkBase, self).__init__(
|
||||
output_dir=output_dir,
|
||||
flag_methods=flag_methods,
|
||||
default_flags=default_flags)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self):
|
||||
start_time_sec = time.time()
|
||||
stats = resnet_ctl_imagenet_main.run(FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
# Number of logged step time entries that are excluded in performance
|
||||
# report. We keep results from last 100 batches in this case.
|
||||
warmup = (FLAGS.train_steps - 100) // FLAGS.log_steps
|
||||
|
||||
super(Resnet50CtlBenchmarkBase, self)._report_benchmark(
|
||||
stats,
|
||||
wall_time_sec,
|
||||
total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
warmup=warmup,
|
||||
start_time_sec=start_time_sec)
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat(self):
|
||||
"""Test Keras model with 1 GPU, no distribution strategy."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Test Keras model with 1 GPU."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_fp16(self):
|
||||
"""Test Keras model with 1 GPU with tf.keras mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
|
||||
FLAGS.batch_size = 256
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_amp(self):
|
||||
"""Test Keras model with 1 GPU with automatic mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
|
||||
FLAGS.batch_size = 256
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu_amp(self):
|
||||
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
|
||||
FLAGS.batch_size = 256
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_eager(self):
|
||||
"""Test Keras model with 1 GPU in pure eager mode."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
|
||||
FLAGS.batch_size = 120
|
||||
FLAGS.use_tf_function = False
|
||||
FLAGS.use_tf_while_loop = False
|
||||
FLAGS.single_l2_loss_op = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_fp16_eager(self):
|
||||
"""Test Keras model with 1 GPU with fp16 and pure eager mode."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'one_device'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
|
||||
FLAGS.batch_size = 240
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.use_tf_function = False
|
||||
FLAGS.use_tf_while_loop = False
|
||||
FLAGS.single_l2_loss_op = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Test Keras model with 8 GPUs."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
||||
FLAGS.batch_size = 128 * 8 # 8 GPUs
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Test Keras model with 8 GPUs with tf.keras mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
|
||||
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_eager(self):
|
||||
"""Test Keras model with 8 GPUs, eager, fp32."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.use_tf_function = False
|
||||
FLAGS.use_tf_while_loop = False
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_eager_fp16(self):
|
||||
"""Test Keras model with 8 GPUs, eager, fp16."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.use_tf_function = False
|
||||
FLAGS.use_tf_while_loop = False
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16')
|
||||
FLAGS.batch_size = 128
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_amp(self):
|
||||
"""Test Keras model with 8 GPUs with automatic mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
|
||||
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_8_gpu_amp(self):
|
||||
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
|
||||
self._setup()
|
||||
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.distribution_strategy = 'mirrored'
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
|
||||
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def fill_report_object(self, stats):
|
||||
super(Resnet50CtlBenchmarkBase, self).fill_report_object(
|
||||
stats, total_batch_size=FLAGS.batch_size, log_steps=FLAGS.log_steps)
|
||||
|
||||
|
||||
class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
|
||||
"""Resnet50 synthetic benchmark tests."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
def_flags = {}
|
||||
def_flags['skip_eval'] = True
|
||||
def_flags['use_synthetic_data'] = True
|
||||
def_flags['train_steps'] = 110
|
||||
def_flags['steps_per_loop'] = 20
|
||||
def_flags['log_steps'] = 10
|
||||
|
||||
super(Resnet50CtlBenchmarkSynth, self).__init__(
|
||||
output_dir=output_dir, default_flags=def_flags)
|
||||
|
||||
|
||||
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
|
||||
"""Resnet50 real data benchmark tests."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
def_flags = {}
|
||||
def_flags['skip_eval'] = True
|
||||
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
|
||||
def_flags['train_steps'] = 110
|
||||
def_flags['steps_per_loop'] = 20
|
||||
def_flags['log_steps'] = 10
|
||||
|
||||
super(Resnet50CtlBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir, default_flags=def_flags)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+294
@@ -0,0 +1,294 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes RetinaNet benchmarks and accuracy tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.vision.detection import main as detection
|
||||
|
||||
TMP_DIR = os.getenv('TMPDIR')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
COCO_TRAIN_DATA = 'gs://tf-perfzero-data/coco/train*'
|
||||
COCO_EVAL_DATA = 'gs://tf-perfzero-data/coco/val*'
|
||||
COCO_EVAL_JSON = 'gs://tf-perfzero-data/coco/instances_val2017.json'
|
||||
RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoint-2018-02-07'
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
class DetectionBenchmarkBase(tf.test.Benchmark):
|
||||
"""Base class to hold methods common to test classes."""
|
||||
local_flags = None
|
||||
|
||||
def __init__(self, output_dir=None):
|
||||
self.num_gpus = 8
|
||||
|
||||
if not output_dir:
|
||||
output_dir = '/tmp'
|
||||
self.output_dir = output_dir
|
||||
self.timer_callback = None
|
||||
|
||||
def _get_model_dir(self, folder_name):
|
||||
"""Returns directory to store info, e.g. saved model and event log."""
|
||||
return os.path.join(self.output_dir, folder_name)
|
||||
|
||||
def _setup(self):
|
||||
"""Sets up and resets flags before each test."""
|
||||
self.timer_callback = benchmark_utils.BenchmarkTimerCallback()
|
||||
|
||||
if DetectionBenchmarkBase.local_flags is None:
|
||||
# Loads flags to get defaults to then override. List cannot be empty.
|
||||
flags.FLAGS(['foo'])
|
||||
saved_flag_values = flagsaver.save_flag_values()
|
||||
DetectionBenchmarkBase.local_flags = saved_flag_values
|
||||
else:
|
||||
flagsaver.restore_flag_values(DetectionBenchmarkBase.local_flags)
|
||||
|
||||
def _report_benchmark(self,
|
||||
stats,
|
||||
wall_time_sec,
|
||||
min_ap,
|
||||
max_ap,
|
||||
train_batch_size=None):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Args:
|
||||
stats: dict returned from Detection models with known entries.
|
||||
wall_time_sec: the during of the benchmark execution in seconds
|
||||
min_ap: Minimum detection AP constraint to verify correctness of the
|
||||
model.
|
||||
max_ap: Maximum detection AP accuracy constraint to verify correctness of
|
||||
the model.
|
||||
train_batch_size: Train batch size. It is needed for computing
|
||||
exp_per_second.
|
||||
"""
|
||||
metrics = [{
|
||||
'name': 'total_loss',
|
||||
'value': stats['total_loss'],
|
||||
}]
|
||||
if self.timer_callback:
|
||||
metrics.append({
|
||||
'name': 'exp_per_second',
|
||||
'value': self.timer_callback.get_examples_per_sec(train_batch_size)
|
||||
})
|
||||
else:
|
||||
metrics.append({
|
||||
'name': 'exp_per_second',
|
||||
'value': 0.0,
|
||||
})
|
||||
|
||||
if 'eval_metrics' in stats:
|
||||
metrics.append({
|
||||
'name': 'AP',
|
||||
'value': stats['AP'],
|
||||
'min_value': min_ap,
|
||||
'max_value': max_ap,
|
||||
})
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(
|
||||
iters=stats['total_steps'],
|
||||
wall_time=wall_time_sec,
|
||||
metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
|
||||
|
||||
class RetinanetBenchmarkBase(DetectionBenchmarkBase):
|
||||
"""Base class to hold methods common to test classes in the module."""
|
||||
|
||||
def __init__(self, output_dir=None, **kwargs):
|
||||
self.train_data_path = COCO_TRAIN_DATA
|
||||
self.eval_data_path = COCO_EVAL_DATA
|
||||
self.eval_json_path = COCO_EVAL_JSON
|
||||
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
|
||||
|
||||
super(RetinanetBenchmarkBase, self).__init__(output_dir=output_dir)
|
||||
|
||||
def _run_detection_main(self):
|
||||
"""Starts detection job."""
|
||||
if self.timer_callback:
|
||||
return detection.run(callbacks=[self.timer_callback])
|
||||
else:
|
||||
return detection.run()
|
||||
|
||||
|
||||
class RetinanetAccuracy(RetinanetBenchmarkBase):
|
||||
"""Accuracy test for RetinaNet model.
|
||||
|
||||
Tests RetinaNet detection task model accuracy. The naming
|
||||
convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu_(dataset type)` format.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, **kwargs):
|
||||
super(RetinanetAccuracy, self).__init__(output_dir=output_dir)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
|
||||
"""Starts RetinaNet accuracy benchmark test."""
|
||||
|
||||
start_time_sec = time.time()
|
||||
FLAGS.mode = 'train'
|
||||
summary, _ = self._run_detection_main()
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
FLAGS.mode = 'eval'
|
||||
eval_metrics = self._run_detection_main()
|
||||
summary.update(eval_metrics)
|
||||
|
||||
summary['train_batch_size'] = self.params_override['train']['batch_size']
|
||||
summary['total_steps'] = self.params_override['train']['total_steps']
|
||||
super(RetinanetAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_ap=min_ap,
|
||||
max_ap=max_ap,
|
||||
train_batch_size=self.params_override['train']['batch_size'])
|
||||
|
||||
def _setup(self):
|
||||
super(RetinanetAccuracy, self)._setup()
|
||||
FLAGS.strategy_type = 'mirrored'
|
||||
FLAGS.model = 'retinanet'
|
||||
|
||||
self.params_override = {
|
||||
'train': {
|
||||
'batch_size': 64,
|
||||
'iterations_per_loop': 100,
|
||||
'total_steps': 22500,
|
||||
'train_file_pattern': self.train_data_path,
|
||||
'checkpoint': {
|
||||
'path': self.resnet_checkpoint_path,
|
||||
'prefix': 'resnet50/'
|
||||
},
|
||||
},
|
||||
'eval': {
|
||||
'batch_size': 8,
|
||||
'eval_samples': 5000,
|
||||
'val_json_file': self.eval_json_path,
|
||||
'eval_file_pattern': self.eval_data_path,
|
||||
},
|
||||
}
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def benchmark_8_gpu_coco(self):
|
||||
"""Run RetinaNet model accuracy test with 8 GPUs."""
|
||||
self._setup()
|
||||
params = copy.deepcopy(self.params_override)
|
||||
FLAGS.params_override = json.dumps(params)
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_coco')
|
||||
# Sets timer_callback to None as we do not use it now.
|
||||
self.timer_callback = None
|
||||
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
class RetinanetBenchmarkReal(RetinanetAccuracy):
|
||||
"""Short benchmark performance tests for RetinaNet model.
|
||||
|
||||
Tests RetinaNet performance in different GPU configurations.
|
||||
The naming convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu` format.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, **kwargs):
|
||||
super(RetinanetBenchmarkReal, self).__init__(output_dir=output_dir)
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def benchmark_8_gpu_coco(self):
|
||||
"""Run RetinaNet model accuracy test with 8 GPUs."""
|
||||
self.num_gpus = 8
|
||||
self._setup()
|
||||
params = copy.deepcopy(self.params_override)
|
||||
params['train']['total_steps'] = 1875 # One epoch.
|
||||
# The iterations_per_loop must be one, otherwise the number of examples per
|
||||
# second would be wrong. Currently only support calling callback per batch
|
||||
# when each loop only runs on one batch, i.e. host loop for one step. The
|
||||
# performance of this situation might be lower than the case of
|
||||
# iterations_per_loop > 1.
|
||||
# Related bug: b/135933080
|
||||
params['train']['iterations_per_loop'] = 1
|
||||
params['eval']['eval_samples'] = 8
|
||||
FLAGS.num_gpus = self.num_gpus
|
||||
FLAGS.params_override = json.dumps(params)
|
||||
FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
|
||||
# Use negative value to avoid saving checkpoints.
|
||||
FLAGS.save_checkpoint_freq = -1
|
||||
if self.timer_callback is None:
|
||||
logging.error('Cannot measure performance without timer callback')
|
||||
else:
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def benchmark_1_gpu_coco(self):
|
||||
"""Run RetinaNet model accuracy test with 1 GPU."""
|
||||
self.num_gpus = 1
|
||||
self._setup()
|
||||
params = copy.deepcopy(self.params_override)
|
||||
params['train']['batch_size'] = 8
|
||||
params['train']['total_steps'] = 200
|
||||
params['train']['iterations_per_loop'] = 1
|
||||
params['eval']['eval_samples'] = 8
|
||||
FLAGS.num_gpus = self.num_gpus
|
||||
FLAGS.params_override = json.dumps(params)
|
||||
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
|
||||
FLAGS.strategy_type = 'one_device'
|
||||
# Use negative value to avoid saving checkpoints.
|
||||
FLAGS.save_checkpoint_freq = -1
|
||||
if self.timer_callback is None:
|
||||
logging.error('Cannot measure performance without timer callback')
|
||||
else:
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def benchmark_xla_1_gpu_coco(self):
|
||||
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
|
||||
self.num_gpus = 1
|
||||
self._setup()
|
||||
params = copy.deepcopy(self.params_override)
|
||||
params['train']['batch_size'] = 8
|
||||
params['train']['total_steps'] = 200
|
||||
params['train']['iterations_per_loop'] = 1
|
||||
params['eval']['eval_samples'] = 8
|
||||
FLAGS.num_gpus = self.num_gpus
|
||||
FLAGS.params_override = json.dumps(params)
|
||||
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
|
||||
FLAGS.strategy_type = 'one_device'
|
||||
FLAGS.enable_xla = True
|
||||
# Use negative value to avoid saving checkpoints.
|
||||
FLAGS.save_checkpoint_freq = -1
|
||||
if self.timer_callback is None:
|
||||
logging.error('Cannot measure performance without timer callback')
|
||||
else:
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+359
@@ -0,0 +1,359 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes Shakespeare (LSTM) benchmark and accuracy tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf # pylint: disable=g-bad-import-order
|
||||
|
||||
from official.benchmark.models.shakespeare import shakespeare_main
|
||||
from official.utils.flags import core as flags_core
|
||||
from official.utils.misc import keras_utils
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
|
||||
SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
|
||||
TMP_DIR = os.getenv('TMPDIR')
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class ShakespeareBenchmarkBase(PerfZeroBenchmark):
|
||||
"""Base class for Shakespeare (LSTM) benchmark and accuracy tests."""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None):
|
||||
super(ShakespeareBenchmarkBase, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
flag_methods=[shakespeare_main.define_flags])
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
top_1_train_min=0.91,
|
||||
top_1_train_max=0.94,
|
||||
warmup=1,
|
||||
log_steps=100):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Average epoch time is calculated by skipping the first epoch. This average
|
||||
ignores time spent between epoch and is recorded by begin and end epoch. To
|
||||
skip accuracy check set `top_1_train_min=None`.
|
||||
|
||||
Args:
|
||||
top_1_train_min: lowest passing value.
|
||||
top_1_train_max: highest passing value.
|
||||
warmup: number of entries in `timestamp_log` to ignore.
|
||||
log_steps: How often the log was created for `timestamp_log`.
|
||||
"""
|
||||
total_batch_size = FLAGS.batch_size
|
||||
metrics = []
|
||||
start_time_sec = time.time()
|
||||
stats = shakespeare_main.run(FLAGS)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
if top_1_train_min:
|
||||
metrics.append({'name': 'accuracy_top_1_train',
|
||||
'value': stats['history']['RecallAt1'][-1],
|
||||
'min_value': top_1_train_min,
|
||||
'max_value': top_1_train_max})
|
||||
|
||||
# Look for the time history callback which was used during keras.fit
|
||||
for callback in stats['callbacks']:
|
||||
if isinstance(callback, keras_utils.TimeHistory):
|
||||
epoch_timings = callback.epoch_runtime_log
|
||||
if len(epoch_timings) > 1:
|
||||
average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
|
||||
metrics.append({'name': 'avg_epoch_time',
|
||||
'value': average_time})
|
||||
|
||||
# First entry in timestamp_log is the start of step 1. The rest of the
|
||||
# entries are the end of each step recorded.
|
||||
time_log = callback.timestamp_log
|
||||
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
|
||||
num_examples = (
|
||||
total_batch_size * log_steps * (len(time_log) - warmup - 1))
|
||||
if elapsed > 0:
|
||||
examples_per_sec = num_examples / elapsed
|
||||
metrics.append({'name': 'exp_per_second',
|
||||
'value': examples_per_sec})
|
||||
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(iters=-1, wall_time=wall_time_sec,
|
||||
metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
|
||||
|
||||
class ShakespeareAccuracy(ShakespeareBenchmarkBase):
|
||||
"""Shakespeare accuracy tests.
|
||||
|
||||
This is not an ideal test. The best we can use for the accuracy check is to
|
||||
validate top_1 of the training set. At batch size 64 the top_1 training
|
||||
stabilizes to ~0.92 around 40-45 epochs.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
"""Shakespeare accuracy tests.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more
|
||||
named arguments before updating the constructor.
|
||||
"""
|
||||
self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
|
||||
super(ShakespeareAccuracy, self).__init__(
|
||||
output_dir=output_dir, root_data_dir=root_data_dir)
|
||||
|
||||
def benchmark_cpu(self):
|
||||
"""Benchmark cpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_ds_run_eagerly(self):
|
||||
"""Benchmark cpu without distribution strategies and run eagerly."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Benchmark 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_ds(self):
|
||||
"""Benchmark 1 gpu without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_ds_run_eagerly(self):
|
||||
"""Benchmark 1 gpu without distribution strategies and run eagerly."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu(self):
|
||||
"""Benchmark 1 gpu w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Benchmark 8 gpu.
|
||||
|
||||
This is test is for accuracy not scaling. The batch-size is not scaled to
|
||||
the number of gpus.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.training_data = self.train_data
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.train_epochs = 43
|
||||
FLAGS.model_dir = ''
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
|
||||
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
|
||||
"""Benchmark accuracy tests."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=TMP_DIR, **kwargs):
|
||||
"""Benchmark tests w/Keras.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more
|
||||
named arguments before updating the constructor.
|
||||
"""
|
||||
self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
|
||||
|
||||
def_flags = {}
|
||||
def_flags['training_data'] = self.train_data
|
||||
def_flags['model_dir'] = ''
|
||||
def_flags['train_epochs'] = 4
|
||||
def_flags['log_steps'] = 50
|
||||
|
||||
super(ShakespeareKerasBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir,
|
||||
root_data_dir=root_data_dir,
|
||||
default_flags=def_flags)
|
||||
|
||||
def benchmark_cpu(self):
|
||||
"""Benchmark cpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.batch_size = 64
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_ds_run_eagerly(self):
|
||||
"""Benchmark cpu without distribution strategy and run eagerly."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.run_eagerly = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_ds(self):
|
||||
"""Benchmark cpu without distribution strategy."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_cpu_no_ds_force_v2(self):
|
||||
"""Benchmark cpu no ds, and force v2."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 0
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Benchmark 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_cudnn(self):
|
||||
"""Benchmark 1 gpu with CuDNN disabled."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.cudnn = False
|
||||
FLAGS.enable_eager = keras_utils.is_v2_0()
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_ds(self):
|
||||
"""Benchmark 1 gpu without distribution strategies."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_1_gpu_no_ds_run_eagerly(self):
|
||||
"""Benchmark 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.run_eagerly = True
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu(self):
|
||||
"""Benchmark 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_1_gpu_no_cudnn(self):
|
||||
"""Benchmark 1 gpu w/xla and CuDNN disabled."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64
|
||||
FLAGS.cudnn = False
|
||||
FLAGS.enable_eager = keras_utils.is_v2_0()
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Benchmark 8 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.batch_size = 64 * 8
|
||||
FLAGS.log_steps = 10
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_8_gpu_no_cudnn(self):
|
||||
"""Benchmark 8 gpu with CuDNN disabled."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.batch_size = 64 * 8
|
||||
FLAGS.log_steps = 10
|
||||
FLAGS.cudnn = False
|
||||
FLAGS.enable_eager = keras_utils.is_v2_0()
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_8_gpu(self):
|
||||
"""Benchmark 8 gpu w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = 64 * 8
|
||||
FLAGS.log_steps = 10
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def benchmark_xla_8_gpu_no_cudnn(self):
|
||||
"""Benchmark 8 gpu w/xla and CuDNN disabled."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.batch_size = 64 * 8
|
||||
FLAGS.log_steps = 10
|
||||
FLAGS.cudnn = False
|
||||
FLAGS.enable_eager = keras_utils.is_v2_0()
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark()
|
||||
|
||||
def _run_and_report_benchmark(self):
|
||||
"""Run and report benchmark."""
|
||||
super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark(
|
||||
top_1_train_min=None, log_steps=FLAGS.log_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+69
@@ -0,0 +1,69 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Runs a memory usage benchmark for a Tensorflow Hub model.
|
||||
|
||||
Loads a SavedModel and records memory usage.
|
||||
"""
|
||||
import functools
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class TfHubMemoryUsageBenchmark(PerfZeroBenchmark):
|
||||
"""A benchmark measuring memory usage for a given TF Hub SavedModel."""
|
||||
|
||||
def __init__(self,
|
||||
hub_model_handle_list=None,
|
||||
output_dir=None,
|
||||
default_flags=None,
|
||||
root_data_dir=None,
|
||||
**kwargs):
|
||||
super(TfHubMemoryUsageBenchmark, self).__init__(
|
||||
output_dir=output_dir, default_flags=default_flags, **kwargs)
|
||||
if hub_model_handle_list:
|
||||
for hub_model_handle in hub_model_handle_list.split(';'):
|
||||
# Converts a model handle of the form
|
||||
# https://tfhub.dev/google/nnlm-en-dim128/1 to valid python method name
|
||||
# like google_nnlm_en_dim128_1.
|
||||
hub_model_method_name = hub_model_handle.replace(
|
||||
'https://tfhub.dev',
|
||||
'').replace('/', '_').replace('-', '_').strip('_')
|
||||
setattr(
|
||||
self, 'benchmark_' + hub_model_method_name,
|
||||
functools.partial(self.benchmark_memory_usage, hub_model_handle))
|
||||
|
||||
def benchmark_memory_usage(
|
||||
self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'):
|
||||
start_time_sec = time.time()
|
||||
self.load_model(hub_model_handle)
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
metrics = []
|
||||
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
|
||||
|
||||
def load_model(self, hub_model_handle):
|
||||
"""Loads a TF Hub module."""
|
||||
hub.load(hub_model_handle)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+681
@@ -0,0 +1,681 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes Transformer w/Keras benchmark and accuracy tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
from official.benchmark import benchmark_wrappers
|
||||
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
||||
from official.nlp.transformer import misc
|
||||
from official.nlp.transformer import transformer_main as transformer_main
|
||||
from official.utils.flags import core as flags_core
|
||||
|
||||
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
|
||||
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
|
||||
FLAGS = flags.FLAGS
|
||||
TMP_DIR = os.getenv('TMPDIR')
|
||||
|
||||
|
||||
class TransformerBenchmark(PerfZeroBenchmark):
|
||||
"""Methods common to executing transformer w/keras tests.
|
||||
|
||||
Code under test for the Transformer Keras models report the same data and
|
||||
require the same FLAG setup.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
|
||||
flag_methods=None):
|
||||
root_data_dir = root_data_dir if root_data_dir else ''
|
||||
|
||||
self.train_data_dir = os.path.join(root_data_dir,
|
||||
TRANSFORMER_EN2DE_DATA_DIR_NAME)
|
||||
|
||||
self.vocab_file = os.path.join(root_data_dir,
|
||||
TRANSFORMER_EN2DE_DATA_DIR_NAME,
|
||||
'vocab.ende.32768')
|
||||
|
||||
self.bleu_source = os.path.join(root_data_dir,
|
||||
EN2DE_2014_BLEU_DATA_DIR_NAME,
|
||||
'newstest2014.en')
|
||||
|
||||
self.bleu_ref = os.path.join(root_data_dir,
|
||||
EN2DE_2014_BLEU_DATA_DIR_NAME,
|
||||
'newstest2014.de')
|
||||
|
||||
if default_flags is None:
|
||||
default_flags = {}
|
||||
default_flags['data_dir'] = self.train_data_dir
|
||||
default_flags['vocab_file'] = self.vocab_file
|
||||
|
||||
super(TransformerBenchmark, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
flag_methods=flag_methods)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
bleu_max=None,
|
||||
bleu_min=None,
|
||||
log_steps=None,
|
||||
total_batch_size=None,
|
||||
warmup=1):
|
||||
"""Report benchmark results by writing to local protobuf file.
|
||||
|
||||
Args:
|
||||
bleu_max: highest passing level for bleu score.
|
||||
bleu_min: lowest passing level for bleu score.
|
||||
log_steps: How often the log was created for stats['step_timestamp_log'].
|
||||
total_batch_size: Global batch-size.
|
||||
warmup: number of entries in stats['step_timestamp_log'] to ignore.
|
||||
"""
|
||||
start_time_sec = time.time()
|
||||
task = transformer_main.TransformerTask(FLAGS)
|
||||
stats = task.train()
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
metrics = []
|
||||
if 'bleu_uncased' in stats:
|
||||
if 'bleu_uncased_history' in stats:
|
||||
bleu_uncased_best = max(stats['bleu_uncased_history'],
|
||||
key=lambda x: x[1])
|
||||
metrics.append({'name': 'bleu_uncased',
|
||||
'value': bleu_uncased_best[1],
|
||||
'min_value': bleu_min,
|
||||
'max_value': bleu_max})
|
||||
metrics.append({'name': 'bleu_best_score_iteration',
|
||||
'value': bleu_uncased_best[0]})
|
||||
metrics.append({'name': 'bleu_uncased_last',
|
||||
'value': stats['bleu_uncased']})
|
||||
else:
|
||||
metrics.append({'name': 'bleu_uncased',
|
||||
'value': stats['bleu_uncased'],
|
||||
'min_value': bleu_min,
|
||||
'max_value': bleu_max})
|
||||
|
||||
if (warmup and 'step_timestamp_log' in stats and
|
||||
len(stats['step_timestamp_log']) > warmup):
|
||||
# first entry in the time_log is start of step 1. The rest of the
|
||||
# entries are the end of each step recorded
|
||||
time_log = stats['step_timestamp_log']
|
||||
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
|
||||
num_examples = (
|
||||
total_batch_size * log_steps * (len(time_log) - warmup - 1))
|
||||
examples_per_sec = num_examples / elapsed
|
||||
metrics.append({'name': 'exp_per_second',
|
||||
'value': examples_per_sec})
|
||||
|
||||
if 'avg_exp_per_second' in stats:
|
||||
metrics.append({'name': 'avg_exp_per_second',
|
||||
'value': stats['avg_exp_per_second']})
|
||||
|
||||
flags_str = flags_core.get_nondefault_flags_as_str()
|
||||
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics,
|
||||
extras={'flags': flags_str})
|
||||
|
||||
|
||||
class TransformerBaseKerasAccuracy(TransformerBenchmark):
|
||||
"""Benchmark accuracy tests for Transformer Base model w/ Keras."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
"""Benchmark accuracy tests for Transformer Base model w/ Keras.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more
|
||||
named arguments before updating the constructor.
|
||||
"""
|
||||
flag_methods = [misc.define_transformer_flags]
|
||||
|
||||
super(TransformerBaseKerasAccuracy, self).__init__(
|
||||
output_dir=output_dir, root_data_dir=root_data_dir,
|
||||
flag_methods=flag_methods)
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Benchmark 1 gpu.
|
||||
|
||||
The paper uses 8 GPUs and a much larger effective batch size, this is will
|
||||
not converge to the 27.3 BLEU (uncased) SOTA.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'base'
|
||||
FLAGS.batch_size = 2048
|
||||
FLAGS.train_steps = 1000
|
||||
FLAGS.steps_between_evals = 500
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
||||
# These bleu scores are based on test runs after at this limited
|
||||
# number of steps and batch size after verifying SOTA at 8xV100s.
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=25.3,
|
||||
bleu_max=26)
|
||||
|
||||
def benchmark_1_gpu_static_batch(self):
|
||||
"""Benchmark 1 gpu with static_batch.
|
||||
|
||||
The paper uses 8 GPUs and a much larger effective batch size, this is will
|
||||
not converge to the 27.3 BLEU (uncased) SOTA.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'base'
|
||||
FLAGS.batch_size = 4096
|
||||
FLAGS.train_steps = 100000
|
||||
FLAGS.steps_between_evals = 5000
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
|
||||
# These bleu scores are based on test runs after at this limited
|
||||
# number of steps and batch size after verifying SOTA at 8xV100s.
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=25.3,
|
||||
bleu_max=26)
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Benchmark 8 gpu.
|
||||
|
||||
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'base'
|
||||
FLAGS.batch_size = 4096*8
|
||||
FLAGS.train_steps = 100000
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=27,
|
||||
bleu_max=28)
|
||||
|
||||
def benchmark_8_gpu_static_batch(self):
|
||||
"""Benchmark 8 gpu.
|
||||
|
||||
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'base'
|
||||
FLAGS.batch_size = 4096*8
|
||||
FLAGS.train_steps = 100000
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.steps_between_evals = 5000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=27,
|
||||
bleu_max=28)
|
||||
|
||||
|
||||
class TransformerBigKerasAccuracy(TransformerBenchmark):
|
||||
"""Benchmark accuracy tests for Transformer Big model w/ Keras."""
|
||||
|
||||
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
||||
"""Benchmark accuracy tests for Transformer Big model w/ Keras.
|
||||
|
||||
Args:
|
||||
output_dir: directory where to output e.g. log files
|
||||
root_data_dir: directory under which to look for dataset
|
||||
**kwargs: arbitrary named arguments. This is needed to make the
|
||||
constructor forward compatible in case PerfZero provides more
|
||||
named arguments before updating the constructor.
|
||||
"""
|
||||
flag_methods = [misc.define_transformer_flags]
|
||||
|
||||
super(TransformerBigKerasAccuracy, self).__init__(
|
||||
output_dir=output_dir, root_data_dir=root_data_dir,
|
||||
flag_methods=flag_methods)
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Benchmark 8 gpu.
|
||||
|
||||
Over 6 runs with eval every 20K steps the average highest value was 28.195
|
||||
(bleu uncased). 28.424 was the highest and 27.96 the lowest. The values are
|
||||
the highest value seen during a run and occurred at a median of iteration 9.
|
||||
Iterations are not epochs, an iteration is a number of steps between evals.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.train_steps = 20000 * 12
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=27.9,
|
||||
bleu_max=29.2)
|
||||
|
||||
def benchmark_8_gpu_static_batch(self):
|
||||
"""Benchmark 8 gpu.
|
||||
|
||||
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.train_steps = 20000 * 12
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=28,
|
||||
bleu_max=29.2)
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Benchmark 8 gpu with dynamic batch and fp16.
|
||||
|
||||
Over 6 runs with eval every 20K steps the average highest value was 28.247
|
||||
(bleu uncased). 28.424 was the highest and 28.09 the lowest. The values are
|
||||
the highest value seen during a run and occurred at a median of iteration
|
||||
11. While this could be interpreted as worse than FP32, if looking at the
|
||||
first iteration at which 28 is passed FP16 performs equal and possibly
|
||||
better. Although not part of the initial test runs, the highest value
|
||||
recorded with the arguments below was 28.9 at iteration 12. Iterations are
|
||||
not epochs, an iteration is a number of steps between evals.
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.train_steps = 20000 * 12
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=28,
|
||||
bleu_max=29.2)
|
||||
|
||||
def benchmark_8_gpu_fp16_amp(self):
|
||||
"""Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
|
||||
|
||||
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.fp16_implementation = 'graph_rewrite'
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.train_steps = 20000 * 12
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=28,
|
||||
bleu_max=29)
|
||||
|
||||
def benchmark_8_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 8 gpu with static batch and fp16.
|
||||
|
||||
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.train_steps = 400000
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch_fp16')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=28,
|
||||
bleu_max=29.2)
|
||||
|
||||
def benchmark_xla_8_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 8 gpu with static batch, XLA, and FP16.
|
||||
|
||||
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
|
||||
"""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.data_dir = self.train_data_dir
|
||||
FLAGS.vocab_file = self.vocab_file
|
||||
# Sets values directly to avoid validation check.
|
||||
FLAGS['bleu_source'].value = self.bleu_source
|
||||
FLAGS['bleu_ref'].value = self.bleu_ref
|
||||
FLAGS.param_set = 'big'
|
||||
FLAGS.batch_size = 3072*8
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.train_steps = 400000
|
||||
FLAGS.steps_between_evals = 20000
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_xla_8_gpu_static_batch_fp16')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
bleu_min=28,
|
||||
bleu_max=29.2)
|
||||
|
||||
|
||||
class TransformerKerasBenchmark(TransformerBenchmark):
|
||||
"""Benchmarks for Transformer (Base and Big) using Keras."""
|
||||
|
||||
def __init__(self, output_dir=None, default_flags=None,
|
||||
root_data_dir=None, batch_per_gpu=4096):
|
||||
"""Initialize.
|
||||
|
||||
Args:
|
||||
output_dir: Based directory for saving artifacts, e.g. checkpoints.
|
||||
default_flags: default flags to use for all tests.
|
||||
root_data_dir: root directory for data, e.g. training.
|
||||
batch_per_gpu: batch size to use per gpu.
|
||||
"""
|
||||
flag_methods = [misc.define_transformer_flags]
|
||||
self.batch_per_gpu = batch_per_gpu
|
||||
|
||||
super(TransformerKerasBenchmark, self).__init__(
|
||||
output_dir=output_dir,
|
||||
default_flags=default_flags,
|
||||
root_data_dir=root_data_dir,
|
||||
flag_methods=flag_methods)
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat(self):
|
||||
"""Benchmark 1 gpu without distribution strategy."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu_no_dist_strat_static_batch(self):
|
||||
"""Benchmark 1 gpu without distribution strategy with static batch."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.distribution_strategy = 'off'
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_ds_sb')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu(self):
|
||||
"""Benchmark 1 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu_fp16(self):
|
||||
"""Benchmark 1 gpu FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_1_gpu(self):
|
||||
"""Benchmark 1 gpu w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_1_gpu_fp16(self):
|
||||
"""Benchmark 1 gpu w/xla and FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu_static_batch(self):
|
||||
"""Benchmark 1 gpu with static batch."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_1_gpu_static_batch(self):
|
||||
"""Benchmark 1 gpu with static batch w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_static_batch')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.enable_xla = True
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_1_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 1 gpu with static batch FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_1_gpu_static_batch_fp16')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_1_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 1 gpu with static batch w/xla and FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 1
|
||||
FLAGS.batch_size = self.batch_per_gpu
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_xla_1_gpu_static_batch_fp16')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_8_gpu(self):
|
||||
"""Benchmark 8 gpu."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_8_gpu_fp16(self):
|
||||
"""Benchmark 8 gpu FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_8_gpu(self):
|
||||
"""Benchmark 8 gpu w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_8_gpu_fp16(self):
|
||||
"""Benchmark 8 gpu w/xla and FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_8_gpu_static_batch(self):
|
||||
"""Benchmark 8 gpu with static batch."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_8_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 8 gpu with static batch FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_8_gpu_static_batch_fp16')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_8_gpu_static_batch(self):
|
||||
"""Benchmark 8 gpu with static batch w/xla."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_static_batch')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
def benchmark_xla_8_gpu_static_batch_fp16(self):
|
||||
"""Benchmark 8 gpu with static batch w/xla and FP16."""
|
||||
self._setup()
|
||||
FLAGS.num_gpus = 8
|
||||
FLAGS.enable_xla = True
|
||||
FLAGS.dtype = 'fp16'
|
||||
FLAGS.batch_size = self.batch_per_gpu * 8
|
||||
FLAGS.model_dir = self._get_model_dir(
|
||||
'benchmark_xla_8_gpu_static_batch_fp16')
|
||||
FLAGS.static_batch = True
|
||||
FLAGS.max_length = 64
|
||||
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
|
||||
log_steps=FLAGS.log_steps)
|
||||
|
||||
|
||||
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
|
||||
"""Transformer based version real data benchmark tests."""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
|
||||
def_flags = {}
|
||||
def_flags['param_set'] = 'base'
|
||||
def_flags['train_steps'] = 50
|
||||
def_flags['log_steps'] = 10
|
||||
|
||||
super(TransformerBaseKerasBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir, default_flags=def_flags,
|
||||
root_data_dir=root_data_dir, batch_per_gpu=4096)
|
||||
|
||||
|
||||
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
|
||||
"""Transformer based version real data benchmark tests."""
|
||||
|
||||
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
|
||||
def_flags = {}
|
||||
def_flags['param_set'] = 'big'
|
||||
def_flags['train_steps'] = 50
|
||||
def_flags['log_steps'] = 10
|
||||
|
||||
super(TransformerBigKerasBenchmarkReal, self).__init__(
|
||||
output_dir=output_dir, default_flags=def_flags,
|
||||
root_data_dir=root_data_dir, batch_per_gpu=3072)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+216
@@ -0,0 +1,216 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Executes XLNet benchmarks and accuracy tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from absl import flags
|
||||
from absl.testing import flagsaver
|
||||
import tensorflow as tf
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
||||
from official.nlp.xlnet import run_classifier
|
||||
from official.nlp.xlnet import run_squad
|
||||
from official.benchmark import benchmark_wrappers
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
|
||||
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
|
||||
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
|
||||
SQUAD_DATA_PATH = 'gs://tf-perfzero-data/xlnet/squadv2_cased/'
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class XLNetBenchmarkBase(benchmark_utils.BertBenchmarkBase):
|
||||
"""Base class to hold methods common to test classes in the module."""
|
||||
|
||||
def __init__(self, output_dir=None):
|
||||
super(XLNetBenchmarkBase, self).__init__(output_dir)
|
||||
self.num_epochs = None
|
||||
self.num_steps_per_epoch = None
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def _run_xlnet_classifier(self):
|
||||
"""Starts XLNet classification task."""
|
||||
run_classifier.main(unused_argv=None)
|
||||
|
||||
@flagsaver.flagsaver
|
||||
def _run_xlnet_squad(self):
|
||||
"""Starts XLNet classification task."""
|
||||
run_squad.main(unused_argv=None)
|
||||
|
||||
|
||||
class XLNetClassifyAccuracy(XLNetBenchmarkBase):
|
||||
"""Short accuracy test for XLNet classifier model.
|
||||
|
||||
Tests XLNet classification task model accuracy. The naming
|
||||
convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu_(dataset type)` format.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=None, **kwargs):
|
||||
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
|
||||
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
|
||||
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
|
||||
|
||||
super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
training_summary_path,
|
||||
min_accuracy=0.95,
|
||||
max_accuracy=0.97):
|
||||
"""Starts XLNet accuracy benchmark test."""
|
||||
|
||||
start_time_sec = time.time()
|
||||
self._run_xlnet_classifier()
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
|
||||
summary = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
super(XLNetClassifyAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=min_accuracy,
|
||||
max_accuracy=max_accuracy)
|
||||
|
||||
def _setup(self):
|
||||
super(XLNetClassifyAccuracy, self)._setup()
|
||||
FLAGS.test_data_size = 25024
|
||||
FLAGS.train_batch_size = 16
|
||||
FLAGS.seq_len = 512
|
||||
FLAGS.mem_len = 0
|
||||
FLAGS.n_layer = 24
|
||||
FLAGS.d_model = 1024
|
||||
FLAGS.d_embed = 1024
|
||||
FLAGS.n_head = 16
|
||||
FLAGS.d_head = 64
|
||||
FLAGS.d_inner = 4096
|
||||
FLAGS.untie_r = True
|
||||
FLAGS.n_class = 2
|
||||
FLAGS.ff_activation = 'gelu'
|
||||
FLAGS.strategy_type = 'mirror'
|
||||
FLAGS.learning_rate = 2e-5
|
||||
FLAGS.train_steps = 4000
|
||||
FLAGS.warmup_steps = 500
|
||||
FLAGS.iterations = 200
|
||||
FLAGS.bi_data = False
|
||||
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
|
||||
FLAGS.train_tfrecord_path = self.train_data_path
|
||||
FLAGS.test_tfrecord_path = self.eval_data_path
|
||||
|
||||
def benchmark_8_gpu_imdb(self):
|
||||
"""Run XLNet model accuracy test with 8 GPUs."""
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_imdb')
|
||||
# Sets timer_callback to None as we do not use it now.
|
||||
self.timer_callback = None
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
|
||||
class XLNetSquadAccuracy(XLNetBenchmarkBase):
|
||||
"""Short accuracy test for XLNet squad model.
|
||||
|
||||
Tests XLNet squad task model accuracy. The naming
|
||||
convention of below test cases follow
|
||||
`benchmark_(number of gpus)_gpu_(dataset type)` format.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir=None, **kwargs):
|
||||
self.train_data_path = SQUAD_DATA_PATH
|
||||
self.predict_file = os.path.join(SQUAD_DATA_PATH, "dev-v2.0.json")
|
||||
self.test_data_path = os.path.join(SQUAD_DATA_PATH, "12048.eval.tf_record")
|
||||
self.spiece_model_file = os.path.join(SQUAD_DATA_PATH, "spiece.cased.model")
|
||||
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
|
||||
|
||||
super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)
|
||||
|
||||
@benchmark_wrappers.enable_runtime_flags
|
||||
def _run_and_report_benchmark(self,
|
||||
training_summary_path,
|
||||
min_accuracy=87.0,
|
||||
max_accuracy=89.0):
|
||||
"""Starts XLNet accuracy benchmark test."""
|
||||
|
||||
start_time_sec = time.time()
|
||||
self._run_xlnet_squad()
|
||||
wall_time_sec = time.time() - start_time_sec
|
||||
|
||||
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
|
||||
summary = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
super(XLNetSquadAccuracy, self)._report_benchmark(
|
||||
stats=summary,
|
||||
wall_time_sec=wall_time_sec,
|
||||
min_accuracy=min_accuracy,
|
||||
max_accuracy=max_accuracy)
|
||||
|
||||
def _setup(self):
|
||||
super(XLNetSquadAccuracy, self)._setup()
|
||||
FLAGS.train_batch_size = 16
|
||||
FLAGS.seq_len = 512
|
||||
FLAGS.mem_len = 0
|
||||
FLAGS.n_layer = 24
|
||||
FLAGS.d_model = 1024
|
||||
FLAGS.d_embed = 1024
|
||||
FLAGS.n_head = 16
|
||||
FLAGS.d_head = 64
|
||||
FLAGS.d_inner = 4096
|
||||
FLAGS.untie_r = True
|
||||
FLAGS.ff_activation = 'gelu'
|
||||
FLAGS.strategy_type = 'mirror'
|
||||
FLAGS.learning_rate = 3e-5
|
||||
FLAGS.train_steps = 8000
|
||||
FLAGS.warmup_steps = 1000
|
||||
FLAGS.iterations = 1000
|
||||
FLAGS.bi_data = False
|
||||
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
|
||||
FLAGS.train_tfrecord_path = self.train_data_path
|
||||
FLAGS.test_tfrecord_path = self.test_data_path
|
||||
FLAGS.spiece_model_file = self.spiece_model_file
|
||||
FLAGS.predict_file = self.predict_file
|
||||
FLAGS.adam_epsilon=1e-6
|
||||
FLAGS.lr_layer_decay_rate=0.75
|
||||
|
||||
def benchmark_8_gpu_squadv2(self):
|
||||
"""Run XLNet model squad v2 accuracy test with 8 GPUs."""
|
||||
self._setup()
|
||||
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
|
||||
FLAGS.predict_dir = FLAGS.model_dir
|
||||
# Sets timer_callback to None as we do not use it now.
|
||||
self.timer_callback = None
|
||||
|
||||
summary_path = os.path.join(FLAGS.model_dir,
|
||||
'summaries/training_summary.txt')
|
||||
self._run_and_report_benchmark(summary_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+383
@@ -0,0 +1,383 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "How-to Guide: Using a PIP package for fine-tuning a BERT model",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "5T_-iFRIqliG",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## How-to Guide: Using a PIP package for fine-tuning a BERT model\n",
|
||||
"\n",
|
||||
"Author: [Chen Chen](https://github.com/chenGitHuber)\n",
|
||||
"\n",
|
||||
"In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "mY1vX5VAq4SS",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## License\n",
|
||||
"\n",
|
||||
"Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"you may not use this file except in compliance with the License.\n",
|
||||
"You may obtain a copy of the License at\n",
|
||||
"\n",
|
||||
" http://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software\n",
|
||||
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"See the License for the specific language governing permissions and\n",
|
||||
"limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "XV3k63bt0ihl",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Learning objectives\n",
|
||||
"\n",
|
||||
"In this Colab notebook, you will learn how to fine-tune a BERT model using the TensorFlow Model Garden PIP package."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "MHA-RWherfG4",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Enable the GPU acceleration\n",
|
||||
"Please enable GPU for better performance.\n",
|
||||
"* Navigate to Edit 🡒 Notebook settings\n",
|
||||
"* Select GPU from the \"Hardware Accelerator\" drop-down list\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "l2B4N5Djrs2l",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## Install the Model Garden PIP package\n",
|
||||
"\n",
|
||||
"Install the Model Garden PIP package (tf-models-nightly) and other necessary PIP packages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "VMYZDly6rx97",
|
||||
"colab_type": "code",
|
||||
"outputId": "146956ab-4568-4de6-c78e-cf25f115a5a8",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 1000
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"pip install tf-models-nightly"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting tf-models-nightly\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/7c/1390d4e05d4d370e91d32dd9700d3a462dbc560c7f4e95a6477592b17def/tf_models_nightly-2.2.0.dev20200326-py2.py3-none-any.whl (710kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 716kB 2.8MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.4.1)\n",
|
||||
"Collecting opencv-python-headless\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/0b/23/5f10b30a48b218a4884bc84188c14381ac71288b210f6f8079a54f7a05e8/opencv_python_headless-4.2.0.32-cp36-cp36m-manylinux1_x86_64.whl (21.6MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 21.6MB 1.3MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: tensorflow-hub>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7.0)\n",
|
||||
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.6.6)\n",
|
||||
"Collecting tensorflow-model-optimization>=0.2.1\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8f/c4/4c3d011e432bd9c19f0323f7da7d3f783402615e4c3b5a98416c7da9cb05/tensorflow_model_optimization-0.2.1-py2.py3-none-any.whl (93kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 102kB 10.2MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.18.2)\n",
|
||||
"Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.25.3)\n",
|
||||
"Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.3.0)\n",
|
||||
"Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.21.0)\n",
|
||||
"Collecting mlperf-compliance==0.0.10\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/f4/08/f2febd8cbd5c9371f7dab311e90400d83238447ba7609b3bf0145b4cb2a2/mlperf_compliance-0.0.10-py3-none-any.whl\n",
|
||||
"Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.5.6)\n",
|
||||
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.12.0)\n",
|
||||
"Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.7.12)\n",
|
||||
"Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.8.3)\n",
|
||||
"Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (5.4.8)\n",
|
||||
"Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (2.1.0)\n",
|
||||
"Collecting sentencepiece\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 1.0MB 58.3MB/s \n",
|
||||
"\u001b[?25hCollecting tf-nightly\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/39/1c/4408b4c4b0d8008a7de62162e35089d59d19cc7543cfd1b23a70121f3086/tf_nightly-2.2.0.dev20200325-cp36-cp36m-manylinux2010_x86_64.whl (516.1MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 516.1MB 21kB/s \n",
|
||||
"\u001b[?25hCollecting py-cpuinfo>=3.3.0\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/60/63f28a5401da733043abe7053e7d9591491b4784c4f87c339bf51215aa0a/py-cpuinfo-5.0.0.tar.gz (82kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 92kB 13.7MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (7.0.0)\n",
|
||||
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.13)\n",
|
||||
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.2.1)\n",
|
||||
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7)\n",
|
||||
"Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.29.15)\n",
|
||||
"Requirement already satisfied: oauth2client>=4.1.2 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (4.1.3)\n",
|
||||
"Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.6.0->tf-models-nightly) (3.10.0)\n",
|
||||
"Collecting enum34~=1.1\n",
|
||||
" Downloading https://files.pythonhosted.org/packages/63/f6/ccb1c83687756aeabbf3ca0f213508fcfb03883ff200d201b3a4c60cedcc/enum34-1.1.10-py3-none-any.whl\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2.8.1)\n",
|
||||
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2018.9)\n",
|
||||
"Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (0.4.1)\n",
|
||||
"Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.0.3)\n",
|
||||
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (1.24.3)\n",
|
||||
"Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2019.11.28)\n",
|
||||
"Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.0.0)\n",
|
||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.38.0)\n",
|
||||
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2.21.0)\n",
|
||||
"Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.0.3)\n",
|
||||
"Requirement already satisfied: httplib2<1dev,>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.17.0)\n",
|
||||
"Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (3.0.1)\n",
|
||||
"Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (1.7.2)\n",
|
||||
"Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly) (2.7.1)\n",
|
||||
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.16.0)\n",
|
||||
"Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.9.0)\n",
|
||||
"Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.12.1)\n",
|
||||
"Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (2.3)\n",
|
||||
"Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (19.3.0)\n",
|
||||
"Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.1.0)\n",
|
||||
"Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.21.1)\n",
|
||||
"Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.3.1.1)\n",
|
||||
"Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.3.3)\n",
|
||||
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.34.2)\n",
|
||||
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.27.2)\n",
|
||||
"Collecting tb-nightly<2.3.0a0,>=2.2.0a0\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/52/b6/aa559ea9edbc6129a64ff752dbf6567bcd62fba34566defca00fdff4345e/tb_nightly-2.2.0a20200324-py3-none-any.whl (2.8MB)\n",
|
||||
"\u001b[K |████████████████████████████████| 2.8MB 51.9MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (2.10.0)\n",
|
||||
"Collecting tf-estimator-nightly\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bf/eb/0d6c06d1181cd9b52d7c82535073887d68d224f3dbeeb00adefd04762a9c/tf_estimator_nightly-2.3.0.dev2020032501-py2.py3-none-any.whl (455kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 460kB 53.5MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.2.0)\n",
|
||||
"Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.6.3)\n",
|
||||
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (3.2.0)\n",
|
||||
"Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.1.0)\n",
|
||||
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (2.4.6)\n",
|
||||
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (1.1.0)\n",
|
||||
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (0.10.0)\n",
|
||||
"Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.2.8)\n",
|
||||
"Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (4.0)\n",
|
||||
"Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.4.8)\n",
|
||||
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.4.0->tensorflow-hub>=0.6.0->tf-models-nightly) (46.0.0)\n",
|
||||
"Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.16.0)\n",
|
||||
"Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly) (1.3)\n",
|
||||
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (2.8)\n",
|
||||
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (3.0.4)\n",
|
||||
"Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly) (3.1.1)\n",
|
||||
"Requirement already satisfied: googleapis-common-protos in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets->tf-models-nightly) (1.51.0)\n",
|
||||
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.2.1)\n",
|
||||
"Collecting tensorboard-plugin-wit>=1.6.0\n",
|
||||
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/41/ec/3da49289b93963bd8b32d29ed108f1809436ff3d9cd4e29c90bac4a7292f/tensorboard_plugin_wit-1.6.0.post2-py3-none-any.whl (775kB)\n",
|
||||
"\u001b[K |████████████████████████████████| 778kB 43.8MB/s \n",
|
||||
"\u001b[?25hRequirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (0.4.1)\n",
|
||||
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.0.0)\n",
|
||||
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.3.0)\n",
|
||||
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.1.0)\n",
|
||||
"Building wheels for collected packages: py-cpuinfo\n",
|
||||
" Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
||||
" Created wheel for py-cpuinfo: filename=py_cpuinfo-5.0.0-cp36-none-any.whl size=18684 sha256=093853bc49757be8f9facd8d38eab95b3cc2de13514dafa221ad7898725491ff\n",
|
||||
" Stored in directory: /root/.cache/pip/wheels/01/7e/a9/b982d0fea22b7e4ae5619de949570cde5ad55420cec16e86a5\n",
|
||||
"Successfully built py-cpuinfo\n",
|
||||
"Installing collected packages: opencv-python-headless, enum34, tensorflow-model-optimization, mlperf-compliance, sentencepiece, tensorboard-plugin-wit, tb-nightly, tf-estimator-nightly, tf-nightly, py-cpuinfo, tf-models-nightly\n",
|
||||
"Successfully installed enum34-1.1.10 mlperf-compliance-0.0.10 opencv-python-headless-4.2.0.32 py-cpuinfo-5.0.0 sentencepiece-0.1.85 tb-nightly-2.2.0a20200324 tensorboard-plugin-wit-1.6.0.post2 tensorflow-model-optimization-0.2.1 tf-estimator-nightly-2.3.0.dev2020032501 tf-models-nightly-2.2.0.dev20200326 tf-nightly-2.2.0.dev20200325\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.colab-display-data+json": {
|
||||
"pip_warning": {
|
||||
"packages": [
|
||||
"enum"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "w4oyRCTji-aa",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"## BERT Fine-tuning\n",
|
||||
"\n",
|
||||
"The following code import necessary modules for fine-tuning a BERT model on a classification task.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "rV8MIX7g078-",
|
||||
"colab_type": "code",
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"%tensorflow_version 2.x\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"import math\n",
|
||||
"\n",
|
||||
"from official.utils.misc import distribution_utils\n",
|
||||
"from official.nlp import optimization\n",
|
||||
"from official.nlp.bert import bert_models\n",
|
||||
"from official.nlp.bert import configs as bert_configs\n",
|
||||
"from official.nlp.bert import run_classifier\n",
|
||||
"from official.modeling import activations\n",
|
||||
"from official.nlp.modeling import networks\n",
|
||||
"from official.nlp.modeling.models import bert_classifier"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "DIuS8nYD08n3",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"This section of code performs the following tasks:\n",
|
||||
"* Load data for fine-tuning\n",
|
||||
"* Fine-tune a BERT model\n",
|
||||
"* Save the fine-tuned model to a TensorFlow SavedModel file\n",
|
||||
"\n",
|
||||
"Please check [create_finetuning_data.py](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) if you want to know how the train/eval data are created."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "PAby1RTCi_1e",
|
||||
"colab_type": "code",
|
||||
"outputId": "e663e830-cc9b-4b5d-99db-4504fd66d5f3",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 258
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"\n",
|
||||
"train_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_train.tf_record\"\n",
|
||||
"eval_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_eval.tf_record\"\n",
|
||||
"input_meta_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_meta_data\"\n",
|
||||
"\n",
|
||||
"bert_config_file = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json\"\n",
|
||||
"ckpt_path = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
|
||||
"\n",
|
||||
"with tf.io.gfile.GFile(input_meta_path, 'rb') as reader:\n",
|
||||
" input_meta_data = json.loads(reader.read().decode('utf-8'))\n",
|
||||
"\n",
|
||||
"max_seq_length = input_meta_data['max_seq_length']\n",
|
||||
"num_classes = input_meta_data['num_labels']\n",
|
||||
"batch_size = 32\n",
|
||||
"eval_batch_size = 32\n",
|
||||
"train_input_fn = run_classifier.get_dataset_fn(train_data_path, max_seq_length, batch_size, is_training=True)\n",
|
||||
"eval_input_fn = run_classifier.get_dataset_fn(eval_data_path, max_seq_length, eval_batch_size, is_training=False)\n",
|
||||
"\n",
|
||||
"strategy = distribution_utils.get_distribution_strategy(\n",
|
||||
" distribution_strategy='one_device', num_gpus=1)\n",
|
||||
"\n",
|
||||
"with strategy.scope():\n",
|
||||
" training_dataset = train_input_fn()\n",
|
||||
" evaluation_dataset = eval_input_fn()\n",
|
||||
" bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n",
|
||||
" classifier_model, encoder = bert_models.classifier_model(\n",
|
||||
" bert_config, num_classes, max_seq_length)\n",
|
||||
"\n",
|
||||
" checkpoint = tf.train.Checkpoint(model=encoder)\n",
|
||||
" checkpoint.restore(ckpt_path).assert_consumed()\n",
|
||||
"\n",
|
||||
" epochs = 3\n",
|
||||
" train_data_size = input_meta_data['train_data_size']\n",
|
||||
" eval_data_size = input_meta_data['eval_data_size']\n",
|
||||
" steps_per_epoch = int(train_data_size / batch_size)\n",
|
||||
" warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
|
||||
" optimizer = optimization.create_optimizer(\n",
|
||||
" 2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)\n",
|
||||
"\n",
|
||||
" def metric_fn():\n",
|
||||
" return tf.keras.metrics.SparseCategoricalAccuracy(\n",
|
||||
" 'test_accuracy', dtype=tf.float32)\n",
|
||||
"\n",
|
||||
" classifier_model.compile(optimizer=optimizer,\n",
|
||||
" loss=run_classifier.get_loss_fn(num_classes=2),\n",
|
||||
" metrics=[metric_fn()])\n",
|
||||
" classifier_model.fit(\n",
|
||||
" x=training_dataset,\n",
|
||||
" validation_data=evaluation_dataset,\n",
|
||||
" steps_per_epoch=steps_per_epoch,\n",
|
||||
" epochs=epochs,\n",
|
||||
" validation_steps=int(eval_data_size / eval_batch_size))\n",
|
||||
"\n",
|
||||
" classifier_model.save('/tmp/saved_model', include_optimizer=False, save_format='tf')"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_mask.\n",
|
||||
"Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
|
||||
"The tensor that caused the issue was: input_mask:0\n",
|
||||
"WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_type_ids.\n",
|
||||
"Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
|
||||
"The tensor that caused the issue was: input_type_ids:0\n",
|
||||
"Epoch 1/3\n",
|
||||
"114/114 [==============================] - 96s 840ms/step - loss: 0.5932 - test_accuracy: 0.6960 - val_loss: 0.5083 - val_test_accuracy: 0.7604\n",
|
||||
"Epoch 2/3\n",
|
||||
"114/114 [==============================] - 100s 878ms/step - loss: 0.4225 - test_accuracy: 0.8183 - val_loss: 0.4020 - val_test_accuracy: 0.8438\n",
|
||||
"Epoch 3/3\n",
|
||||
"114/114 [==============================] - 100s 880ms/step - loss: 0.2482 - test_accuracy: 0.9134 - val_loss: 0.4065 - val_test_accuracy: 0.8151\n",
|
||||
"INFO:tensorflow:Assets written to: /tmp/saved_model/assets\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Activations package definition."""
|
||||
from official.modeling.activations.gelu import gelu
|
||||
from official.modeling.activations.swish import hard_swish
|
||||
from official.modeling.activations.swish import identity
|
||||
from official.modeling.activations.swish import simple_swish
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Gaussian error linear unit."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Text')
|
||||
def gelu(x):
|
||||
"""Gaussian Error Linear Unit.
|
||||
|
||||
This is a smoother version of the RELU.
|
||||
Original paper: https://arxiv.org/abs/1606.08415
|
||||
Args:
|
||||
x: float Tensor to perform activation.
|
||||
|
||||
Returns:
|
||||
`x` with the GELU activation applied.
|
||||
"""
|
||||
cdf = 0.5 * (1.0 + tf.tanh(
|
||||
(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||
return x * cdf
|
||||
+38
@@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the Gaussian error linear unit."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from official.modeling import activations
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class GeluTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_gelu(self):
|
||||
expected_data = [[0.14967535, 0., -0.10032465],
|
||||
[-0.15880796, -0.04540223, 2.9963627]]
|
||||
gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
|
||||
self.assertAllClose(expected_data, gelu_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+75
@@ -0,0 +1,75 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Customized Swish activation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Text')
|
||||
def simple_swish(features):
|
||||
"""Computes the Swish activation function.
|
||||
|
||||
The tf.nn.swish operation uses a custom gradient to reduce memory usage.
|
||||
Since saving custom gradients in SavedModel is currently not supported, and
|
||||
one would not be able to use an exported TF-Hub module for fine-tuning, we
|
||||
provide this wrapper that can allow to select whether to use the native
|
||||
TensorFlow swish operation, or whether to use a customized operation that
|
||||
has uses default TensorFlow gradient computation.
|
||||
|
||||
Args:
|
||||
features: A `Tensor` representing preactivation values.
|
||||
|
||||
Returns:
|
||||
The activation value.
|
||||
"""
|
||||
features = tf.convert_to_tensor(features)
|
||||
return features * tf.nn.sigmoid(features)
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Text')
|
||||
def hard_swish(features):
|
||||
"""Computes a hard version of the swish function.
|
||||
|
||||
This operation can be used to reduce computational cost and improve
|
||||
quantization for edge devices.
|
||||
|
||||
Args:
|
||||
features: A `Tensor` representing preactivation values.
|
||||
|
||||
Returns:
|
||||
The activation value.
|
||||
"""
|
||||
features = tf.convert_to_tensor(features)
|
||||
return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Text')
|
||||
def identity(features):
|
||||
"""Computes the identity function.
|
||||
|
||||
Useful for helping in quantization.
|
||||
|
||||
Args:
|
||||
features: A `Tensor` representing preactivation values.
|
||||
|
||||
Returns:
|
||||
The activation value.
|
||||
"""
|
||||
features = tf.convert_to_tensor(features)
|
||||
return tf.identity(features)
|
||||
+49
@@ -0,0 +1,49 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the customized Swish activation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from official.modeling import activations
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CustomizedSwishTest(keras_parameterized.TestCase):
|
||||
|
||||
def _hard_swish_np(self, x):
|
||||
x = np.float32(x)
|
||||
return x * np.clip(x + 3, 0, 6) / 6
|
||||
|
||||
def test_simple_swish(self):
|
||||
features = [[.25, 0, -.25], [-1, -2, 3]]
|
||||
customized_swish_data = activations.simple_swish(features)
|
||||
swish_data = tf.nn.swish(features)
|
||||
self.assertAllClose(customized_swish_data, swish_data)
|
||||
|
||||
def test_hard_swish(self):
|
||||
features = [[.25, 0, -.25], [-1, -2, 3]]
|
||||
customized_swish_data = activations.hard_swish(features)
|
||||
swish_data = self._hard_swish_np(features)
|
||||
self.assertAllClose(customized_swish_data, swish_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+323
@@ -0,0 +1,323 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Base configurations to standardize experiments."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import functools
|
||||
from typing import Any, List, Mapping, Optional, Type
|
||||
|
||||
import dataclasses
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from official.modeling.hyperparams import params_dict
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config(params_dict.ParamsDict):
|
||||
"""The base configuration class that supports YAML/JSON based overrides.
|
||||
|
||||
* It recursively enforces a whitelist of basic types and container types, so
|
||||
it avoids surprises with copy and reuse caused by unanticipated types.
|
||||
* It converts dict to Config even within sequences,
|
||||
e.g. for config = Config({'key': [([{'a': 42}],)]),
|
||||
type(config.key[0][0][0]) is Config rather than dict.
|
||||
"""
|
||||
|
||||
# It's safe to add bytes and other immutable types here.
|
||||
IMMUTABLE_TYPES = (str, int, float, bool, type(None))
|
||||
# It's safe to add set, frozenset and other collections here.
|
||||
SEQUENCE_TYPES = (list, tuple)
|
||||
|
||||
default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
|
||||
restrictions: dataclasses.InitVar[Optional[List[str]]] = None
|
||||
|
||||
@classmethod
|
||||
def _isvalidsequence(cls, v):
|
||||
"""Check if the input values are valid sequences.
|
||||
|
||||
Args:
|
||||
v: Input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is valid. Valid sequence includes the sequence
|
||||
type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
|
||||
is dict or ParamsDict.
|
||||
"""
|
||||
if not isinstance(v, cls.SEQUENCE_TYPES):
|
||||
return False
|
||||
return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
|
||||
all(isinstance(e, dict) for e in v) or
|
||||
all(isinstance(e, params_dict.ParamsDict) for e in v))
|
||||
|
||||
@classmethod
|
||||
def _import_config(cls, v, subconfig_type):
|
||||
"""Returns v with dicts converted to Configs, recursively."""
|
||||
if not issubclass(subconfig_type, params_dict.ParamsDict):
|
||||
raise TypeError(
|
||||
'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
|
||||
subconfig_type))
|
||||
if isinstance(v, cls.IMMUTABLE_TYPES):
|
||||
return v
|
||||
elif isinstance(v, cls.SEQUENCE_TYPES):
|
||||
# Only support one layer of sequence.
|
||||
if not cls._isvalidsequence(v):
|
||||
raise TypeError(
|
||||
'Invalid sequence: only supports single level {!r} of {!r} or '
|
||||
'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
|
||||
cls.IMMUTABLE_TYPES, v))
|
||||
import_fn = functools.partial(
|
||||
cls._import_config, subconfig_type=subconfig_type)
|
||||
return type(v)(map(import_fn, v))
|
||||
elif isinstance(v, params_dict.ParamsDict):
|
||||
# Deepcopy here is a temporary solution for preserving type in nested
|
||||
# Config object.
|
||||
return copy.deepcopy(v)
|
||||
elif isinstance(v, dict):
|
||||
return subconfig_type(v)
|
||||
else:
|
||||
raise TypeError('Unknown type: {!r}'.format(type(v)))
|
||||
|
||||
@classmethod
|
||||
def _export_config(cls, v):
|
||||
"""Returns v with Configs converted to dicts, recursively."""
|
||||
if isinstance(v, cls.IMMUTABLE_TYPES):
|
||||
return v
|
||||
elif isinstance(v, cls.SEQUENCE_TYPES):
|
||||
return type(v)(map(cls._export_config, v))
|
||||
elif isinstance(v, params_dict.ParamsDict):
|
||||
return v.as_dict()
|
||||
elif isinstance(v, dict):
|
||||
raise TypeError('dict value not supported in converting.')
|
||||
else:
|
||||
raise TypeError('Unknown type: {!r}'.format(type(v)))
|
||||
|
||||
@classmethod
|
||||
def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
|
||||
"""Get element type by the field name.
|
||||
|
||||
Args:
|
||||
k: the key/name of the field.
|
||||
|
||||
Returns:
|
||||
Config as default. If a type annotation is found for `k`,
|
||||
1) returns the type of the annotation if it is subtype of ParamsDict;
|
||||
2) returns the element type if the annotation of `k` is List[SubType]
|
||||
or Tuple[SubType].
|
||||
"""
|
||||
subconfig_type = Config
|
||||
if k in cls.__annotations__:
|
||||
# Directly Config subtype.
|
||||
type_annotation = cls.__annotations__[k]
|
||||
if (isinstance(type_annotation, type) and
|
||||
issubclass(type_annotation, Config)):
|
||||
subconfig_type = cls.__annotations__[k]
|
||||
else:
|
||||
# Check if the field is a sequence of subtypes.
|
||||
field_type = getattr(type_annotation, '__origin__', type(None))
|
||||
if (isinstance(field_type, type) and
|
||||
issubclass(field_type, cls.SEQUENCE_TYPES)):
|
||||
element_type = getattr(type_annotation, '__args__', [type(None)])[0]
|
||||
subconfig_type = (
|
||||
element_type if issubclass(element_type, params_dict.ParamsDict)
|
||||
else subconfig_type)
|
||||
return subconfig_type
|
||||
|
||||
def __post_init__(self, default_params, restrictions, *args, **kwargs):
|
||||
super().__init__(default_params=default_params,
|
||||
restrictions=restrictions,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
def _set(self, k, v):
|
||||
"""Overrides same method in ParamsDict.
|
||||
|
||||
Also called by ParamsDict methods.
|
||||
|
||||
Args:
|
||||
k: key to set.
|
||||
v: value.
|
||||
|
||||
Raises:
|
||||
RuntimeError
|
||||
"""
|
||||
subconfig_type = self._get_subconfig_type(k)
|
||||
if isinstance(v, dict):
|
||||
if k not in self.__dict__ or not self.__dict__[k]:
|
||||
# If the key not exist or the value is None, a new Config-family object
|
||||
# sould be created for the key.
|
||||
self.__dict__[k] = subconfig_type(v)
|
||||
else:
|
||||
self.__dict__[k].override(v)
|
||||
else:
|
||||
self.__dict__[k] = self._import_config(v, subconfig_type)
|
||||
|
||||
def __setattr__(self, k, v):
|
||||
if k not in self.RESERVED_ATTR:
|
||||
if getattr(self, '_locked', False):
|
||||
raise ValueError('The Config has been locked. ' 'No change is allowed.')
|
||||
self._set(k, v)
|
||||
|
||||
def _override(self, override_dict, is_strict=True):
|
||||
"""Overrides same method in ParamsDict.
|
||||
|
||||
Also called by ParamsDict methods.
|
||||
|
||||
Args:
|
||||
override_dict: dictionary to write to .
|
||||
is_strict: If True, not allows to add new keys.
|
||||
|
||||
Raises:
|
||||
KeyError: overriding reserved keys or keys not exist (is_strict=True).
|
||||
"""
|
||||
for k, v in sorted(override_dict.items()):
|
||||
if k in self.RESERVED_ATTR:
|
||||
raise KeyError('The key {!r} is internally reserved. '
|
||||
'Can not be overridden.'.format(k))
|
||||
if k not in self.__dict__:
|
||||
if is_strict:
|
||||
raise KeyError('The key {!r} does not exist in {!r}. '
|
||||
'To extend the existing keys, use '
|
||||
'`override` with `is_strict` = False.'.format(
|
||||
k, type(self)))
|
||||
else:
|
||||
self._set(k, v)
|
||||
else:
|
||||
if isinstance(v, dict) and self.__dict__[k]:
|
||||
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
|
||||
elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
|
||||
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
|
||||
else:
|
||||
self._set(k, v)
|
||||
|
||||
def as_dict(self):
|
||||
"""Returns a dict representation of params_dict.ParamsDict.
|
||||
|
||||
For the nested params_dict.ParamsDict, a nested dict will be returned.
|
||||
"""
|
||||
return {
|
||||
k: self._export_config(v)
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in self.RESERVED_ATTR
|
||||
}
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Like `override`, but returns a copy with the current config unchanged."""
|
||||
params = self.__class__(self)
|
||||
params.override(kwargs, is_strict=True)
|
||||
return params
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, file_path: str):
|
||||
# Note: This only works if the Config has all default values.
|
||||
with tf.io.gfile.GFile(file_path, 'r') as f:
|
||||
loaded = yaml.load(f)
|
||||
config = cls()
|
||||
config.override(loaded)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, file_path: str):
|
||||
"""Wrapper for `from_yaml`."""
|
||||
return cls.from_yaml(file_path)
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, *args, **kwargs):
|
||||
"""Builds a config from the given list of arguments."""
|
||||
attributes = list(cls.__annotations__.keys())
|
||||
default_params = {a: p for a, p in zip(attributes, args)}
|
||||
default_params.update(kwargs)
|
||||
return cls(default_params)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RuntimeConfig(Config):
|
||||
"""High-level configurations for Runtime.
|
||||
|
||||
These include parameters that are not directly related to the experiment,
|
||||
e.g. directories, accelerator type, etc.
|
||||
|
||||
Attributes:
|
||||
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
|
||||
enable_xla: Whether or not to enable XLA.
|
||||
per_gpu_thread_count: thread count per GPU.
|
||||
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
|
||||
dataset_num_private_threads: Number of threads for a private threadpool
|
||||
created for all datasets computation.
|
||||
tpu: The address of the TPU to use, if any.
|
||||
num_gpus: The number of GPUs to use, if any.
|
||||
worker_hosts: comma-separated list of worker ip:port pairs for running
|
||||
multi-worker models with DistributionStrategy.
|
||||
task_index: If multi-worker training, the task index of this worker.
|
||||
all_reduce_alg: Defines the algorithm for performing all-reduce.
|
||||
num_packs: Sets `num_packs` in the cross device ops used in
|
||||
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
|
||||
loss_scale: The type of loss scale. This is used when setting the mixed
|
||||
precision policy.
|
||||
run_eagerly: Whether or not to run the experiment eagerly.
|
||||
|
||||
"""
|
||||
distribution_strategy: str = 'mirrored'
|
||||
enable_xla: bool = False
|
||||
gpu_thread_mode: Optional[str] = None
|
||||
dataset_num_private_threads: Optional[int] = None
|
||||
per_gpu_thread_count: int = 0
|
||||
tpu: Optional[str] = None
|
||||
num_gpus: int = 0
|
||||
worker_hosts: Optional[str] = None
|
||||
task_index: int = -1
|
||||
all_reduce_alg: Optional[str] = None
|
||||
num_packs: int = 1
|
||||
loss_scale: Optional[str] = None
|
||||
run_eagerly: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TensorboardConfig(Config):
|
||||
"""Configuration for Tensorboard.
|
||||
|
||||
Attributes:
|
||||
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
|
||||
to True.
|
||||
write_model_weights: Whether or not to write the model weights as
|
||||
images in Tensorboard. Defaults to False.
|
||||
|
||||
"""
|
||||
track_lr: bool = True
|
||||
write_model_weights: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CallbacksConfig(Config):
|
||||
"""Configuration for Callbacks.
|
||||
|
||||
Attributes:
|
||||
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
|
||||
Callback. Defaults to True.
|
||||
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
|
||||
Defaults to True.
|
||||
enable_time_history: Whether or not to enable TimeHistory Callbacks.
|
||||
Defaults to True.
|
||||
|
||||
"""
|
||||
enable_checkpoint_and_export: bool = True
|
||||
enable_tensorboard: bool = True
|
||||
enable_time_history: bool = True
|
||||
+299
@@ -0,0 +1,299 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import pprint
|
||||
from typing import List, Tuple
|
||||
|
||||
from absl.testing import parameterized
|
||||
import dataclasses
|
||||
import tensorflow as tf
|
||||
from official.modeling.hyperparams import base_config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DumpConfig1(base_config.Config):
|
||||
a: int = 1
|
||||
b: str = 'text'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DumpConfig2(base_config.Config):
|
||||
c: int = 2
|
||||
d: str = 'text'
|
||||
e: DumpConfig1 = DumpConfig1()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DumpConfig3(DumpConfig2):
|
||||
f: int = 2
|
||||
g: str = 'text'
|
||||
h: List[DumpConfig1] = dataclasses.field(
|
||||
default_factory=lambda: [DumpConfig1(), DumpConfig1()])
|
||||
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
|
||||
|
||||
|
||||
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
|
||||
|
||||
def assertHasSameTypes(self, c, d, msg=''):
|
||||
"""Checks if a Config has the same structure as a given dict.
|
||||
|
||||
Args:
|
||||
c: the Config object to be check.
|
||||
d: the reference dict object.
|
||||
msg: The error message to show when type mismatched.
|
||||
"""
|
||||
# Make sure d is not a Config. Assume d is either
|
||||
# dictionary or primitive type and c is the Config or primitive types.
|
||||
self.assertNotIsInstance(d, base_config.Config)
|
||||
if isinstance(d, base_config.Config.IMMUTABLE_TYPES):
|
||||
self.assertEqual(pprint.pformat(c), pprint.pformat(d), msg=msg)
|
||||
elif isinstance(d, base_config.Config.SEQUENCE_TYPES):
|
||||
self.assertEqual(type(c), type(d), msg=msg)
|
||||
for i, v in enumerate(d):
|
||||
self.assertHasSameTypes(c[i], v, msg='{}[{!r}]'.format(msg, i))
|
||||
elif isinstance(d, dict):
|
||||
self.assertIsInstance(c, base_config.Config, msg=msg)
|
||||
for k, v in sorted(d.items()):
|
||||
self.assertHasSameTypes(getattr(c, k), v, msg='{}[{!r}]'.format(msg, k))
|
||||
else:
|
||||
raise TypeError('Unknown type: %r' % type(d))
|
||||
|
||||
def assertImportExport(self, v):
|
||||
config = base_config.Config({'key': v})
|
||||
back = config.as_dict()['key']
|
||||
self.assertEqual(pprint.pformat(back), pprint.pformat(v))
|
||||
self.assertHasSameTypes(config.key, v, msg='=%s v' % pprint.pformat(v))
|
||||
|
||||
def test_invalid_keys(self):
|
||||
params = base_config.Config()
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = params.a
|
||||
|
||||
def test_nested_config_types(self):
|
||||
config = DumpConfig3()
|
||||
self.assertIsInstance(config.e, DumpConfig1)
|
||||
self.assertIsInstance(config.h[0], DumpConfig1)
|
||||
self.assertIsInstance(config.h[1], DumpConfig1)
|
||||
self.assertIsInstance(config.g[0], DumpConfig1)
|
||||
|
||||
config.override({'e': {'a': 2, 'b': 'new text'}})
|
||||
self.assertIsInstance(config.e, DumpConfig1)
|
||||
self.assertEqual(config.e.a, 2)
|
||||
self.assertEqual(config.e.b, 'new text')
|
||||
|
||||
config.override({'h': [{'a': 3, 'b': 'new text 2'}]})
|
||||
self.assertIsInstance(config.h[0], DumpConfig1)
|
||||
self.assertLen(config.h, 1)
|
||||
self.assertEqual(config.h[0].a, 3)
|
||||
self.assertEqual(config.h[0].b, 'new text 2')
|
||||
|
||||
config.override({'g': [{'a': 4, 'b': 'new text 3'}]})
|
||||
self.assertIsInstance(config.g[0], DumpConfig1)
|
||||
self.assertLen(config.g, 1)
|
||||
self.assertEqual(config.g[0].a, 4)
|
||||
self.assertEqual(config.g[0].b, 'new text 3')
|
||||
|
||||
@parameterized.parameters(
|
||||
('_locked', "The key '_locked' is internally reserved."),
|
||||
('_restrictions', "The key '_restrictions' is internally reserved."),
|
||||
('aa', "The key 'aa' does not exist."),
|
||||
)
|
||||
def test_key_error(self, key, msg):
|
||||
params = base_config.Config()
|
||||
with self.assertRaisesRegex(KeyError, msg):
|
||||
params.override({key: True})
|
||||
|
||||
@parameterized.parameters(
|
||||
('str data',),
|
||||
(123,),
|
||||
(1.23,),
|
||||
(None,),
|
||||
(['str', 1, 2.3, None],),
|
||||
(('str', 1, 2.3, None),),
|
||||
)
|
||||
def test_import_export_immutable_types(self, v):
|
||||
self.assertImportExport(v)
|
||||
out = base_config.Config({'key': v})
|
||||
self.assertEqual(pprint.pformat(v), pprint.pformat(out.key))
|
||||
|
||||
def test_override_is_strict_true(self):
|
||||
params = base_config.Config({
|
||||
'a': 'aa',
|
||||
'b': 2,
|
||||
'c': {
|
||||
'c1': 'cc',
|
||||
'c2': 20
|
||||
}
|
||||
})
|
||||
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
|
||||
self.assertEqual(params.a, 2)
|
||||
self.assertEqual(params.c.c1, 'ccc')
|
||||
with self.assertRaises(KeyError):
|
||||
params.override({'d': 'ddd'}, is_strict=True)
|
||||
with self.assertRaises(KeyError):
|
||||
params.override({'c': {'c3': 30}}, is_strict=True)
|
||||
|
||||
config = base_config.Config({'key': [{'a': 42}]})
|
||||
config.override({'key': [{'b': 43}]})
|
||||
self.assertEqual(config.key[0].b, 43)
|
||||
with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
|
||||
_ = config.key[0].a
|
||||
|
||||
@parameterized.parameters(
|
||||
(lambda x: x, 'Unknown type'),
|
||||
(object(), 'Unknown type'),
|
||||
(set(), 'Unknown type'),
|
||||
(frozenset(), 'Unknown type'),
|
||||
)
|
||||
def test_import_unsupport_types(self, v, msg):
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
_ = base_config.Config({'key': v})
|
||||
|
||||
@parameterized.parameters(
|
||||
({
|
||||
'a': [{
|
||||
'b': 2,
|
||||
}, {
|
||||
'c': 3,
|
||||
}]
|
||||
},),
|
||||
({
|
||||
'c': [{
|
||||
'f': 1.1,
|
||||
}, {
|
||||
'h': [1, 2],
|
||||
}]
|
||||
},),
|
||||
(({
|
||||
'a': 'aa',
|
||||
'b': 2,
|
||||
'c': {
|
||||
'c1': 10,
|
||||
'c2': 20,
|
||||
}
|
||||
},),),
|
||||
)
|
||||
def test_import_export_nested_structure(self, d):
|
||||
self.assertImportExport(d)
|
||||
|
||||
@parameterized.parameters(
|
||||
([{
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
}],),
|
||||
(({
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
},),),
|
||||
)
|
||||
def test_import_export_nested_sequences(self, v):
|
||||
self.assertImportExport(v)
|
||||
|
||||
@parameterized.parameters(
|
||||
([([{}],)],),
|
||||
([['str', 1, 2.3, None]],),
|
||||
((('str', 1, 2.3, None),),),
|
||||
([
|
||||
('str', 1, 2.3, None),
|
||||
],),
|
||||
([
|
||||
('str', 1, 2.3, None),
|
||||
],),
|
||||
([[{
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
}]],),
|
||||
([[[{
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
}]]],),
|
||||
((({
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
},),),),
|
||||
(((({
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
},),),),),
|
||||
([({
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
},)],),
|
||||
(([{
|
||||
'a': 42,
|
||||
'b': 'hello',
|
||||
'c': 1.2
|
||||
}],),),
|
||||
)
|
||||
def test_import_export_unsupport_sequence(self, v):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Invalid sequence: only supports single level'):
|
||||
_ = base_config.Config({'key': v})
|
||||
|
||||
def test_construct_subtype(self):
|
||||
pass
|
||||
|
||||
def test_import_config(self):
|
||||
params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
|
||||
self.assertLen(params.a, 2)
|
||||
self.assertEqual(params.a[0].b, 2)
|
||||
self.assertEqual(type(params.a[0]), base_config.Config)
|
||||
self.assertEqual(pprint.pformat(params.a[0].b), '2')
|
||||
self.assertEqual(type(params.a[1]), base_config.Config)
|
||||
self.assertEqual(type(params.a[1].c), base_config.Config)
|
||||
self.assertEqual(pprint.pformat(params.a[1].c.d), '3')
|
||||
|
||||
def test_override(self):
|
||||
params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
|
||||
params.override({'a': [{'b': 4}, {'c': {'d': 5}}]}, is_strict=False)
|
||||
self.assertEqual(type(params.a), list)
|
||||
self.assertEqual(type(params.a[0]), base_config.Config)
|
||||
self.assertEqual(pprint.pformat(params.a[0].b), '4')
|
||||
self.assertEqual(type(params.a[1]), base_config.Config)
|
||||
self.assertEqual(type(params.a[1].c), base_config.Config)
|
||||
self.assertEqual(pprint.pformat(params.a[1].c.d), '5')
|
||||
|
||||
@parameterized.parameters(
|
||||
([{}],),
|
||||
(({},),),
|
||||
)
|
||||
def test_config_vs_params_dict(self, v):
|
||||
d = {'key': v}
|
||||
self.assertEqual(type(base_config.Config(d).key[0]), base_config.Config)
|
||||
self.assertEqual(type(base_config.params_dict.ParamsDict(d).key[0]), dict)
|
||||
|
||||
def test_ppformat(self):
|
||||
self.assertEqual(
|
||||
pprint.pformat([
|
||||
's', 1, 1.0, True, None, {}, [], (), {
|
||||
(2,): (3, [4], {
|
||||
6: 7,
|
||||
}),
|
||||
8: 9,
|
||||
}
|
||||
]),
|
||||
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+410
@@ -0,0 +1,410 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A parameter dictionary class which supports the nest structure."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import re
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
# regex pattern that matches on key-value pairs in a comma-separated
|
||||
# key-value pair string. It splits each k-v pair on the = sign, and
|
||||
# matches on values that are within single quotes, double quotes, single
|
||||
# values (e.g. floats, ints, etc.), and a lists within brackets.
|
||||
_PARAM_RE = re.compile(r"""
|
||||
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
|
||||
\s*=\s*
|
||||
((?P<val>\'(.*?)\' # single quote
|
||||
|
|
||||
\"(.*?)\" # double quote
|
||||
|
|
||||
[^,\[]* # single value
|
||||
|
|
||||
\[[^\]]*\])) # list of values
|
||||
($|,\s*)""", re.VERBOSE)
|
||||
|
||||
|
||||
class ParamsDict(object):
|
||||
"""A hyperparameter container class."""
|
||||
|
||||
RESERVED_ATTR = ['_locked', '_restrictions']
|
||||
|
||||
def __init__(self, default_params=None, restrictions=None):
|
||||
"""Instantiate a ParamsDict.
|
||||
|
||||
Instantiate a ParamsDict given a set of default parameters and a list of
|
||||
restrictions. Upon initialization, it validates itself by checking all the
|
||||
defined restrictions, and raise error if it finds inconsistency.
|
||||
|
||||
Args:
|
||||
default_params: a Python dict or another ParamsDict object including the
|
||||
default parameters to initialize.
|
||||
restrictions: a list of strings, which define a list of restrictions to
|
||||
ensure the consistency of different parameters internally. Each
|
||||
restriction string is defined as a binary relation with a set of
|
||||
operators, including {'==', '!=', '<', '<=', '>', '>='}.
|
||||
"""
|
||||
self._locked = False
|
||||
self._restrictions = []
|
||||
if restrictions:
|
||||
self._restrictions = restrictions
|
||||
if default_params is None:
|
||||
default_params = {}
|
||||
self.override(default_params, is_strict=False)
|
||||
self.validate()
|
||||
|
||||
def _set(self, k, v):
|
||||
if isinstance(v, dict):
|
||||
self.__dict__[k] = ParamsDict(v)
|
||||
else:
|
||||
self.__dict__[k] = copy.deepcopy(v)
|
||||
|
||||
def __setattr__(self, k, v):
|
||||
"""Sets the value of the existing key.
|
||||
|
||||
Note that this does not allow directly defining a new key. Use the
|
||||
`override` method with `is_strict=False` instead.
|
||||
|
||||
Args:
|
||||
k: the key string.
|
||||
v: the value to be used to set the key `k`.
|
||||
|
||||
Raises:
|
||||
KeyError: if k is not defined in the ParamsDict.
|
||||
"""
|
||||
if k not in ParamsDict.RESERVED_ATTR:
|
||||
if k not in self.__dict__.keys():
|
||||
raise KeyError('The key `%{}` does not exist. '
|
||||
'To extend the existing keys, use '
|
||||
'`override` with `is_strict` = True.'.format(k))
|
||||
if self._locked:
|
||||
raise ValueError('The ParamsDict has been locked. '
|
||||
'No change is allowed.')
|
||||
self._set(k, v)
|
||||
|
||||
def __getattr__(self, k):
|
||||
"""Gets the value of the existing key.
|
||||
|
||||
Args:
|
||||
k: the key string.
|
||||
|
||||
Returns:
|
||||
the value of the key.
|
||||
|
||||
Raises:
|
||||
AttributeError: if k is not defined in the ParamsDict.
|
||||
"""
|
||||
if k not in self.__dict__.keys():
|
||||
raise AttributeError('The key `{}` does not exist. '.format(k))
|
||||
return self.__dict__[k]
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Implements the membership test operator."""
|
||||
return key in self.__dict__
|
||||
|
||||
def get(self, key, value=None):
|
||||
"""Accesses through built-in dictionary get method."""
|
||||
return self.__dict__.get(key, value)
|
||||
|
||||
def override(self, override_params, is_strict=True):
|
||||
"""Override the ParamsDict with a set of given params.
|
||||
|
||||
Args:
|
||||
override_params: a dict or a ParamsDict specifying the parameters to
|
||||
be overridden.
|
||||
is_strict: a boolean specifying whether override is strict or not. If
|
||||
True, keys in `override_params` must be present in the ParamsDict.
|
||||
If False, keys in `override_params` can be different from what is
|
||||
currently defined in the ParamsDict. In this case, the ParamsDict will
|
||||
be extended to include the new keys.
|
||||
"""
|
||||
if self._locked:
|
||||
raise ValueError('The ParamsDict has been locked. No change is allowed.')
|
||||
if isinstance(override_params, ParamsDict):
|
||||
override_params = override_params.as_dict()
|
||||
self._override(override_params, is_strict) # pylint: disable=protected-access
|
||||
|
||||
def _override(self, override_dict, is_strict=True):
|
||||
"""The implementation of `override`."""
|
||||
for k, v in six.iteritems(override_dict):
|
||||
if k in ParamsDict.RESERVED_ATTR:
|
||||
raise KeyError('The key `%{}` is internally reserved. '
|
||||
'Can not be overridden.')
|
||||
if k not in self.__dict__.keys():
|
||||
if is_strict:
|
||||
raise KeyError('The key `{}` does not exist. '
|
||||
'To extend the existing keys, use '
|
||||
'`override` with `is_strict` = False.'.format(k))
|
||||
else:
|
||||
self._set(k, v)
|
||||
else:
|
||||
if isinstance(v, dict):
|
||||
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
|
||||
elif isinstance(v, ParamsDict):
|
||||
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
|
||||
else:
|
||||
self.__dict__[k] = copy.deepcopy(v)
|
||||
|
||||
def lock(self):
|
||||
"""Makes the ParamsDict immutable."""
|
||||
self._locked = True
|
||||
|
||||
def as_dict(self):
|
||||
"""Returns a dict representation of ParamsDict.
|
||||
|
||||
For the nested ParamsDict, a nested dict will be returned.
|
||||
"""
|
||||
params_dict = {}
|
||||
for k, v in six.iteritems(self.__dict__):
|
||||
if k not in ParamsDict.RESERVED_ATTR:
|
||||
if isinstance(v, ParamsDict):
|
||||
params_dict[k] = v.as_dict()
|
||||
else:
|
||||
params_dict[k] = copy.deepcopy(v)
|
||||
return params_dict
|
||||
|
||||
def validate(self):
|
||||
"""Validate the parameters consistency based on the restrictions.
|
||||
|
||||
This method validates the internal consistency using the pre-defined list of
|
||||
restrictions. A restriction is defined as a string which specfiies a binary
|
||||
operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
|
||||
'>='}. Note that the meaning of these operators are consistent with the
|
||||
underlying Python immplementation. Users should make sure the define
|
||||
restrictions on their type make sense.
|
||||
|
||||
For example, for a ParamsDict like the following
|
||||
```
|
||||
a:
|
||||
a1: 1
|
||||
a2: 2
|
||||
b:
|
||||
bb:
|
||||
bb1: 10
|
||||
bb2: 20
|
||||
ccc:
|
||||
a1: 1
|
||||
a3: 3
|
||||
```
|
||||
one can define two restrictions like this
|
||||
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
|
||||
|
||||
What it enforces are:
|
||||
- a.a1 = 1 == b.ccc.a1 = 2
|
||||
- a.a2 = 2 <= b.bb.bb2 = 20
|
||||
|
||||
Raises:
|
||||
KeyError: if any of the following happens
|
||||
(1) any of parameters in any of restrictions is not defined in
|
||||
ParamsDict,
|
||||
(2) any inconsistency violating the restriction is found.
|
||||
ValueError: if the restriction defined in the string is not supported.
|
||||
"""
|
||||
def _get_kv(dotted_string, params_dict):
|
||||
tokenized_params = dotted_string.split('.')
|
||||
v = params_dict
|
||||
for t in tokenized_params:
|
||||
v = v[t]
|
||||
return tokenized_params[-1], v
|
||||
|
||||
def _get_kvs(tokens, params_dict):
|
||||
if len(tokens) != 2:
|
||||
raise ValueError('Only support binary relation in restriction.')
|
||||
stripped_tokens = [t.strip() for t in tokens]
|
||||
left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
|
||||
right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
|
||||
return left_k, left_v, right_k, right_v
|
||||
|
||||
params_dict = self.as_dict()
|
||||
for restriction in self._restrictions:
|
||||
if '==' in restriction:
|
||||
tokens = restriction.split('==')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v != right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
elif '!=' in restriction:
|
||||
tokens = restriction.split('!=')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v == right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
elif '<' in restriction:
|
||||
tokens = restriction.split('<')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v >= right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
elif '<=' in restriction:
|
||||
tokens = restriction.split('<=')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v > right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
elif '>' in restriction:
|
||||
tokens = restriction.split('>')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v <= right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
elif '>=' in restriction:
|
||||
tokens = restriction.split('>=')
|
||||
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
|
||||
if left_v < right_v:
|
||||
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
|
||||
.format(tokens[0], tokens[1]))
|
||||
else:
|
||||
raise ValueError('Unsupported relation in restriction.')
|
||||
|
||||
|
||||
def read_yaml_to_params_dict(file_path):
|
||||
"""Reads a YAML file to a ParamsDict."""
|
||||
with tf.io.gfile.GFile(file_path, 'r') as f:
|
||||
params_dict = yaml.load(f)
|
||||
return ParamsDict(params_dict)
|
||||
|
||||
|
||||
def save_params_dict_to_yaml(params, file_path):
|
||||
"""Saves the input ParamsDict to a YAML file."""
|
||||
with tf.io.gfile.GFile(file_path, 'w') as f:
|
||||
|
||||
def _my_list_rep(dumper, data):
|
||||
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
|
||||
return dumper.represent_sequence(
|
||||
u'tag:yaml.org,2002:seq', data, flow_style=True)
|
||||
yaml.add_representer(list, _my_list_rep)
|
||||
yaml.dump(params.as_dict(), f, default_flow_style=False)
|
||||
|
||||
|
||||
def nested_csv_str_to_json_str(csv_str):
|
||||
"""Converts a nested (using '.') comma-separated k=v string to a JSON string.
|
||||
|
||||
Converts a comma-separated string of key/value pairs that supports
|
||||
nesting of keys to a JSON string. Nesting is implemented using
|
||||
'.' between levels for a given key.
|
||||
|
||||
Spacing between commas and = is supported (e.g. there is no difference between
|
||||
"a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
|
||||
keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
|
||||
|
||||
Note that this will only support values supported by CSV, meaning
|
||||
values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
|
||||
supported. Strings are supported as well, e.g. "a='hello'".
|
||||
|
||||
An example conversion would be:
|
||||
|
||||
"a=1, b=2, c.a=2, c.b=3, d.a.a=5"
|
||||
|
||||
to
|
||||
|
||||
"{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
|
||||
|
||||
Args:
|
||||
csv_str: the comma separated string.
|
||||
|
||||
Returns:
|
||||
the converted JSON string.
|
||||
|
||||
Raises:
|
||||
ValueError: If csv_str is not in a comma separated string or
|
||||
if the string is formatted incorrectly.
|
||||
"""
|
||||
if not csv_str:
|
||||
return ''
|
||||
|
||||
formatted_entries = []
|
||||
nested_map = collections.defaultdict(list)
|
||||
pos = 0
|
||||
while pos < len(csv_str):
|
||||
m = _PARAM_RE.match(csv_str, pos)
|
||||
if not m:
|
||||
raise ValueError('Malformed hyperparameter value while parsing '
|
||||
'CSV string: %s' % csv_str[pos:])
|
||||
pos = m.end()
|
||||
# Parse the values.
|
||||
m_dict = m.groupdict()
|
||||
name = m_dict['name']
|
||||
v = m_dict['val']
|
||||
|
||||
# If a GCS path (e.g. gs://...) is provided, wrap this in quotes
|
||||
# as yaml.load would otherwise throw an exception
|
||||
if re.match(r'(?=[^\"\'])(?=[gs://])', v):
|
||||
v = '\'{}\''.format(v)
|
||||
|
||||
name_nested = name.split('.')
|
||||
if len(name_nested) > 1:
|
||||
grouping = name_nested[0]
|
||||
value = '.'.join(name_nested[1:]) + '=' + v
|
||||
nested_map[grouping].append(value)
|
||||
else:
|
||||
formatted_entries.append('%s : %s' % (name, v))
|
||||
|
||||
for grouping, value in nested_map.items():
|
||||
value = ','.join(value)
|
||||
value = nested_csv_str_to_json_str(value)
|
||||
formatted_entries.append('%s : %s' % (grouping, value))
|
||||
return '{' + ', '.join(formatted_entries) + '}'
|
||||
|
||||
|
||||
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
|
||||
"""Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
|
||||
|
||||
The logic of the function is outlined below:
|
||||
1. Test that the input is a dict. If not, proceed to 2.
|
||||
2. Tests that the input is a string. If not, raise unknown ValueError
|
||||
2.1. Test if the string is in a CSV format. If so, parse.
|
||||
If not, proceed to 2.2.
|
||||
2.2. Try loading the string as a YAML/JSON. If successful, parse to
|
||||
dict and use it to override. If not, proceed to 2.3.
|
||||
2.3. Try using the string as a file path and load the YAML file.
|
||||
|
||||
Args:
|
||||
params: a ParamsDict object to be overridden.
|
||||
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
|
||||
path to a YAML file specifying the parameters to be overridden.
|
||||
is_strict: a boolean specifying whether override is strict or not.
|
||||
|
||||
Returns:
|
||||
params: the overridden ParamsDict object.
|
||||
|
||||
Raises:
|
||||
ValueError: if failed to override the parameters.
|
||||
"""
|
||||
if not dict_or_string_or_yaml_file:
|
||||
return params
|
||||
if isinstance(dict_or_string_or_yaml_file, dict):
|
||||
params.override(dict_or_string_or_yaml_file, is_strict)
|
||||
elif isinstance(dict_or_string_or_yaml_file, six.string_types):
|
||||
try:
|
||||
dict_or_string_or_yaml_file = (
|
||||
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
|
||||
except ValueError:
|
||||
pass
|
||||
params_dict = yaml.load(dict_or_string_or_yaml_file)
|
||||
if isinstance(params_dict, dict):
|
||||
params.override(params_dict, is_strict)
|
||||
else:
|
||||
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
|
||||
params.override(yaml.load(f), is_strict)
|
||||
else:
|
||||
raise ValueError('Unknown input type to parse.')
|
||||
return params
|
||||
+322
@@ -0,0 +1,322 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for official.modeling.hyperparams.params_dict.py."""
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from official.modeling.hyperparams import params_dict
|
||||
|
||||
|
||||
class ParamsDictTest(tf.test.TestCase):
|
||||
|
||||
def test_init_from_an_empty_dict(self):
|
||||
params = params_dict.ParamsDict()
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = params.a
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
params.a = 'aa'
|
||||
|
||||
def test_init_from_a_dict(self):
|
||||
params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
|
||||
self.assertEqual(params.a, 'aa')
|
||||
self.assertEqual(params.b, 2)
|
||||
|
||||
def test_init_from_a_param_dict(self):
|
||||
params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
|
||||
params = params_dict.ParamsDict(params_init)
|
||||
self.assertEqual(params.a, 'aa')
|
||||
self.assertEqual(params.b, 2)
|
||||
|
||||
def test_lock(self):
|
||||
params = params_dict.ParamsDict({'a': 1, 'b': 2})
|
||||
params.lock()
|
||||
with self.assertRaises(ValueError):
|
||||
params.a = 10
|
||||
with self.assertRaises(ValueError):
|
||||
params.override({'b': 20})
|
||||
|
||||
def test_setattr(self):
|
||||
params = params_dict.ParamsDict()
|
||||
params.override(
|
||||
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
|
||||
params.c = 'ccc'
|
||||
self.assertEqual(params.a, 'aa')
|
||||
self.assertEqual(params.b, 2)
|
||||
self.assertEqual(params.c, 'ccc')
|
||||
|
||||
def test_getattr(self):
|
||||
params = params_dict.ParamsDict()
|
||||
params.override(
|
||||
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
|
||||
self.assertEqual(params.a, 'aa')
|
||||
self.assertEqual(params.b, 2)
|
||||
self.assertEqual(params.c, None)
|
||||
|
||||
def test_contains(self):
|
||||
params = params_dict.ParamsDict()
|
||||
params.override(
|
||||
{'a': 'aa'}, is_strict=False)
|
||||
self.assertIn('a', params)
|
||||
self.assertNotIn('b', params)
|
||||
|
||||
def test_get(self):
|
||||
params = params_dict.ParamsDict()
|
||||
params.override(
|
||||
{'a': 'aa'}, is_strict=False)
|
||||
self.assertEqual(params.get('a'), 'aa')
|
||||
self.assertEqual(params.get('b', 2), 2)
|
||||
self.assertEqual(params.get('b'), None)
|
||||
|
||||
def test_override_is_strict_true(self):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
|
||||
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
|
||||
self.assertEqual(params.a, 2)
|
||||
self.assertEqual(params.c.c1, 'ccc')
|
||||
with self.assertRaises(KeyError):
|
||||
params.override({'d': 'ddd'}, is_strict=True)
|
||||
with self.assertRaises(KeyError):
|
||||
params.override({'c': {'c3': 30}}, is_strict=True)
|
||||
|
||||
def test_override_is_strict_false(self):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
|
||||
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
|
||||
self.assertEqual(params.a, 2)
|
||||
self.assertEqual(params.c.c3, 3000)
|
||||
params.override({'d': 'ddd'}, is_strict=False)
|
||||
self.assertEqual(params.d, 'ddd')
|
||||
params.override({'c': {'c4': 4444}}, is_strict=False)
|
||||
self.assertEqual(params.c.c4, 4444)
|
||||
|
||||
def test_as_dict(self):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
|
||||
params_d = params.as_dict()
|
||||
self.assertEqual(params_d['a'], 'aa')
|
||||
self.assertEqual(params_d['b'], 2)
|
||||
self.assertEqual(params_d['c']['c1'], 10)
|
||||
self.assertEqual(params_d['c']['c2'], 20)
|
||||
|
||||
def test_validate(self):
|
||||
# Raise error due to the unknown parameter.
|
||||
with self.assertRaises(KeyError):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 1, 'b': {'a': 11}}, ['a == c'])
|
||||
|
||||
# OK to check equality of two nested dicts.
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
|
||||
|
||||
# Raise error due to inconsistency
|
||||
with self.assertRaises(KeyError):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 1, 'c': {'a': 10}}, ['a == c.a'])
|
||||
|
||||
# Valid rule.
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 1, 'c': {'a': 1}}, ['a == c.a'])
|
||||
|
||||
# Overridding violates the existing rule, raise error upon validate.
|
||||
params.override({'a': 11})
|
||||
with self.assertRaises(KeyError):
|
||||
params.validate()
|
||||
|
||||
|
||||
class ParamsDictIOTest(tf.test.TestCase):
|
||||
|
||||
def write_temp_file(self, filename, text):
|
||||
temp_file = os.path.join(self.get_temp_dir(), filename)
|
||||
with tf.io.gfile.GFile(temp_file, 'w') as writer:
|
||||
writer.write(text)
|
||||
return temp_file
|
||||
|
||||
def test_save_params_dict_to_yaml(self):
|
||||
params = params_dict.ParamsDict(
|
||||
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
|
||||
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
|
||||
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
|
||||
|
||||
with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
|
||||
params_d = yaml.load(f)
|
||||
self.assertEqual(params.a, params_d['a'])
|
||||
self.assertEqual(params.b, params_d['b'])
|
||||
self.assertEqual(params.c.c1, params_d['c']['c1'])
|
||||
self.assertEqual(params.c.c2, params_d['c']['c2'])
|
||||
|
||||
def test_read_yaml_to_params_dict(self):
|
||||
input_yaml_file = self.write_temp_file(
|
||||
'params.yaml', r"""
|
||||
a: 'aa'
|
||||
b: 2
|
||||
c:
|
||||
c1: 10
|
||||
c2: 20
|
||||
""")
|
||||
params = params_dict.read_yaml_to_params_dict(input_yaml_file)
|
||||
|
||||
self.assertEqual(params.a, 'aa')
|
||||
self.assertEqual(params.b, 2)
|
||||
self.assertEqual(params.c.c1, 10)
|
||||
self.assertEqual(params.c.c2, 20)
|
||||
|
||||
def test_override_params_dict_using_dict(self):
|
||||
params = params_dict.ParamsDict({
|
||||
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
|
||||
override_dict = {'b': 5.2, 'c': [30, 40]}
|
||||
params = params_dict.override_params_dict(
|
||||
params, override_dict, is_strict=True)
|
||||
self.assertEqual(1, params.a)
|
||||
self.assertEqual(5.2, params.b)
|
||||
self.assertEqual([30, 40], params.c)
|
||||
self.assertEqual('hello', params.d)
|
||||
self.assertEqual(False, params.e)
|
||||
|
||||
def test_override_params_dict_using_yaml_string(self):
|
||||
params = params_dict.ParamsDict({
|
||||
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
|
||||
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
|
||||
params = params_dict.override_params_dict(
|
||||
params, override_yaml_string, is_strict=True)
|
||||
self.assertEqual(1, params.a)
|
||||
self.assertEqual(5.2, params.b)
|
||||
self.assertEqual([30, 40], params.c)
|
||||
self.assertEqual('hello', params.d)
|
||||
self.assertEqual(False, params.e)
|
||||
|
||||
def test_override_params_dict_using_json_string(self):
|
||||
params = params_dict.ParamsDict({
|
||||
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
|
||||
'd': {'d1': {'d2': 'hello'}}, 'e': False})
|
||||
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
|
||||
params = params_dict.override_params_dict(
|
||||
params, override_json_string, is_strict=True)
|
||||
self.assertEqual(1, params.a)
|
||||
self.assertEqual(2, params.b.b1)
|
||||
self.assertEqual([3, 4], params.b.b2)
|
||||
self.assertEqual('hi', params.d.d1.d2)
|
||||
self.assertEqual(False, params.e)
|
||||
|
||||
def test_override_params_dict_using_csv_string(self):
|
||||
params = params_dict.ParamsDict({
|
||||
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
|
||||
'd': {'d1': {'d2': 'hello'}}, 'e': False})
|
||||
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
|
||||
params = params_dict.override_params_dict(
|
||||
params, override_csv_string, is_strict=True)
|
||||
self.assertEqual(1, params.a)
|
||||
self.assertEqual(2, params.b.b1)
|
||||
self.assertEqual([3, 4], params.b.b2)
|
||||
self.assertEqual('hi, world', params.d.d1.d2)
|
||||
self.assertEqual('gs://test', params.e)
|
||||
|
||||
def test_override_params_dict_using_yaml_file(self):
|
||||
params = params_dict.ParamsDict({
|
||||
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
|
||||
override_yaml_file = self.write_temp_file(
|
||||
'params.yaml', r"""
|
||||
b: 5.2
|
||||
c: [30, 40]
|
||||
""")
|
||||
params = params_dict.override_params_dict(
|
||||
params, override_yaml_file, is_strict=True)
|
||||
self.assertEqual(1, params.a)
|
||||
self.assertEqual(5.2, params.b)
|
||||
self.assertEqual([30, 40], params.c)
|
||||
self.assertEqual('hello', params.d)
|
||||
self.assertEqual(False, params.e)
|
||||
|
||||
|
||||
class IOTest(tf.test.TestCase):
|
||||
|
||||
def test_basic_csv_str_to_json_str(self):
|
||||
csv_str = 'a=1,b=2,c=3'
|
||||
json_str = '{a : 1, b : 2, c : 3}'
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
self.assertEqual(converted_csv_str, json_str)
|
||||
|
||||
def test_basic_csv_str_load(self):
|
||||
csv_str = 'a=1,b=2,c=3'
|
||||
expected_output = {'a': 1, 'b': 2, 'c': 3}
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
converted_dict = yaml.load(converted_csv_str)
|
||||
self.assertDictEqual(converted_dict, expected_output)
|
||||
|
||||
def test_basic_nested_csv_str_to_json_str(self):
|
||||
csv_str = 'a=1,b.b1=2'
|
||||
json_str = '{a : 1, b : {b1 : 2}}'
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
self.assertEqual(converted_csv_str, json_str)
|
||||
|
||||
def test_basic_nested_csv_str_load(self):
|
||||
csv_str = 'a=1,b.b1=2,c.c1=3'
|
||||
expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
converted_dict = yaml.load(converted_csv_str)
|
||||
self.assertDictEqual(converted_dict, expected_output)
|
||||
|
||||
def test_complex_nested_csv_str_to_json_str(self):
|
||||
csv_str = 'a.aa.aaa.aaaaa.a=1'
|
||||
json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
self.assertEqual(converted_csv_str, json_str)
|
||||
|
||||
def test_complex_nested_csv_str_load(self):
|
||||
csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
|
||||
expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
converted_dict = yaml.load(converted_csv_str)
|
||||
self.assertDictEqual(converted_dict, expected_output)
|
||||
|
||||
def test_csv_str_load_supported_datatypes(self):
|
||||
csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
converted_dict = yaml.load(converted_csv_str)
|
||||
self.assertEqual(converted_dict['a'], 1)
|
||||
self.assertEqual(converted_dict['b'], 2.)
|
||||
self.assertEqual(converted_dict['c'], [1, 2, 3])
|
||||
self.assertEqual(converted_dict['d'], 'hello, there')
|
||||
self.assertEqual(converted_dict['e'], 'Hi.')
|
||||
|
||||
def test_csv_str_load_unsupported_datatypes(self):
|
||||
csv_str = 'a=[[1,2,3],[4,5,6]]'
|
||||
self.assertRaises(ValueError,
|
||||
params_dict.nested_csv_str_to_json_str,
|
||||
csv_str)
|
||||
|
||||
def test_csv_str_to_json_str_spacing(self):
|
||||
csv_str1 = 'a=1,b=2,c=3'
|
||||
csv_str2 = 'a = 1, b = 2, c = 3'
|
||||
json_str = '{a : 1, b : 2, c : 3}'
|
||||
converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
|
||||
converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
|
||||
self.assertEqual(converted_csv_str1, converted_csv_str2)
|
||||
self.assertEqual(converted_csv_str1, json_str)
|
||||
self.assertEqual(converted_csv_str2, json_str)
|
||||
|
||||
def test_gcs_added_quotes(self):
|
||||
csv_str = 'a=gs://abc, b=gs://def'
|
||||
expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
|
||||
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
|
||||
self.assertEqual(converted_csv_str, expected_output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functions and classes related to training performance."""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def configure_optimizer(optimizer,
|
||||
use_float16=False,
|
||||
use_graph_rewrite=False,
|
||||
loss_scale="dynamic"):
|
||||
"""Configures optimizer object with performance options."""
|
||||
if use_float16:
|
||||
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
|
||||
# in compile() with the "mixed_float16" policy, but since we do not call
|
||||
# compile(), we must wrap the optimizer manually.
|
||||
optimizer = (
|
||||
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
|
||||
optimizer, loss_scale=loss_scale))
|
||||
if use_graph_rewrite:
|
||||
# Note: the model dtype must be 'float32', which will ensure
|
||||
# tf.ckeras.mixed_precision and
|
||||
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
|
||||
# up.
|
||||
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
|
||||
optimizer)
|
||||
return optimizer
|
||||
|
||||
|
||||
def set_mixed_precision_policy(dtype, loss_scale=None):
|
||||
"""Sets mix precision policy."""
|
||||
if dtype == tf.float16:
|
||||
policy = tf.keras.mixed_precision.experimental.Policy(
|
||||
'mixed_float16', loss_scale=loss_scale)
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
elif dtype == tf.bfloat16:
|
||||
policy = tf.keras.mixed_precision.experimental.Policy(
|
||||
'mixed_bfloat16')
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
elif dtype == tf.float32:
|
||||
tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
else:
|
||||
raise ValueError("Unexpected dtype: %s" % dtype)
|
||||
+175
@@ -0,0 +1,175 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common TF utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.util import deprecation
|
||||
from official.modeling import activations
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None,
|
||||
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
|
||||
"input tensors. pack/unpack inputs to override __call__ is no longer "
|
||||
"needed."
|
||||
)
|
||||
def pack_inputs(inputs):
|
||||
"""Pack a list of `inputs` tensors to a tuple.
|
||||
|
||||
Args:
|
||||
inputs: a list of tensors.
|
||||
|
||||
Returns:
|
||||
a tuple of tensors. if any input is None, replace it with a special constant
|
||||
tensor.
|
||||
"""
|
||||
inputs = tf.nest.flatten(inputs)
|
||||
outputs = []
|
||||
for x in inputs:
|
||||
if x is None:
|
||||
outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
|
||||
else:
|
||||
outputs.append(x)
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None,
|
||||
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
|
||||
"input tensors. pack/unpack inputs to override __call__ is no longer "
|
||||
"needed."
|
||||
)
|
||||
def unpack_inputs(inputs):
|
||||
"""unpack a tuple of `inputs` tensors to a tuple.
|
||||
|
||||
Args:
|
||||
inputs: a list of tensors.
|
||||
|
||||
Returns:
|
||||
a tuple of tensors. if any input is a special constant tensor, replace it
|
||||
with None.
|
||||
"""
|
||||
inputs = tf.nest.flatten(inputs)
|
||||
outputs = []
|
||||
for x in inputs:
|
||||
if is_special_none_tensor(x):
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(x)
|
||||
x = tuple(outputs)
|
||||
|
||||
# To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
|
||||
# from triggering.
|
||||
if len(x) == 1:
|
||||
return x[0]
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
def is_special_none_tensor(tensor):
|
||||
"""Checks if a tensor is a special None Tensor."""
|
||||
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
|
||||
|
||||
|
||||
# TODO(hongkuny): consider moving custom string-map lookup to keras api.
|
||||
def get_activation(identifier):
|
||||
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
|
||||
|
||||
It checks string first and if it is one of customized activation not in TF,
|
||||
the corresponding activation will be returned. For non-customized activation
|
||||
names and callable identifiers, always fallback to tf.keras.activations.get.
|
||||
|
||||
Args:
|
||||
identifier: String name of the activation function or callable.
|
||||
|
||||
Returns:
|
||||
A Python function corresponding to the activation function.
|
||||
"""
|
||||
if isinstance(identifier, six.string_types):
|
||||
name_to_fn = {
|
||||
"gelu": activations.gelu,
|
||||
"simple_swish": activations.simple_swish,
|
||||
"hard_swish": activations.hard_swish,
|
||||
"identity": activations.identity,
|
||||
}
|
||||
identifier = str(identifier).lower()
|
||||
if identifier in name_to_fn:
|
||||
return tf.keras.activations.get(name_to_fn[identifier])
|
||||
return tf.keras.activations.get(identifier)
|
||||
|
||||
|
||||
def get_shape_list(tensor, expected_rank=None, name=None):
|
||||
"""Returns a list of the shape of tensor, preferring static dimensions.
|
||||
|
||||
Args:
|
||||
tensor: A tf.Tensor object to find the shape of.
|
||||
expected_rank: (optional) int. The expected rank of `tensor`. If this is
|
||||
specified and the `tensor` has a different rank, and exception will be
|
||||
thrown.
|
||||
name: Optional name of the tensor for the error message.
|
||||
|
||||
Returns:
|
||||
A list of dimensions of the shape of tensor. All static dimensions will
|
||||
be returned as python integers, and dynamic dimensions will be returned
|
||||
as tf.Tensor scalars.
|
||||
"""
|
||||
if expected_rank is not None:
|
||||
assert_rank(tensor, expected_rank, name)
|
||||
|
||||
shape = tensor.shape.as_list()
|
||||
|
||||
non_static_indexes = []
|
||||
for (index, dim) in enumerate(shape):
|
||||
if dim is None:
|
||||
non_static_indexes.append(index)
|
||||
|
||||
if not non_static_indexes:
|
||||
return shape
|
||||
|
||||
dyn_shape = tf.shape(tensor)
|
||||
for index in non_static_indexes:
|
||||
shape[index] = dyn_shape[index]
|
||||
return shape
|
||||
|
||||
|
||||
def assert_rank(tensor, expected_rank, name=None):
|
||||
"""Raises an exception if the tensor rank is not of the expected rank.
|
||||
|
||||
Args:
|
||||
tensor: A tf.Tensor to check the rank of.
|
||||
expected_rank: Python integer or list of integers, expected rank.
|
||||
name: Optional name of the tensor for the error message.
|
||||
|
||||
Raises:
|
||||
ValueError: If the expected shape doesn't match the actual shape.
|
||||
"""
|
||||
expected_rank_dict = {}
|
||||
if isinstance(expected_rank, six.integer_types):
|
||||
expected_rank_dict[expected_rank] = True
|
||||
else:
|
||||
for x in expected_rank:
|
||||
expected_rank_dict[x] = True
|
||||
|
||||
actual_rank = tensor.shape.ndims
|
||||
if actual_rank not in expected_rank_dict:
|
||||
raise ValueError(
|
||||
"For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
|
||||
"equal to the expected tensor rank `%s`" %
|
||||
(name, actual_rank, str(tensor.shape), str(expected_rank)))
|
||||
+759
@@ -0,0 +1,759 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Custom training loop for running TensorFlow 2.0 models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
|
||||
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
|
||||
from official.modeling.hyperparams import params_dict
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
from official.utils import hyperparams_flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
strategy_flags_dict = hyperparams_flags.strategy_flags_dict
|
||||
hparam_flags_dict = hyperparams_flags.hparam_flags_dict
|
||||
|
||||
|
||||
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
|
||||
"""Saves model to model_dir with provided checkpoint prefix."""
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
|
||||
saved_path = checkpoint.save(checkpoint_path)
|
||||
logging.info('Saving model as TF checkpoint: %s', saved_path)
|
||||
|
||||
|
||||
def _steps_to_run(current_step, total_steps, steps_per_loop):
|
||||
"""Calculates steps to run on device."""
|
||||
if steps_per_loop <= 0:
|
||||
raise ValueError('steps_per_loop should be positive integer.')
|
||||
return min(total_steps - current_step, steps_per_loop)
|
||||
|
||||
|
||||
def _no_metric():
|
||||
return None
|
||||
|
||||
|
||||
class SummaryWriter(object):
|
||||
"""Simple SummaryWriter for writing dictionary of metrics.
|
||||
|
||||
Attributes:
|
||||
writer: The tf.SummaryWriter.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir: Text, name: Text):
|
||||
"""Inits SummaryWriter with paths.
|
||||
|
||||
Arguments:
|
||||
model_dir: the model folder path.
|
||||
name: the summary subfolder name.
|
||||
"""
|
||||
self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
|
||||
|
||||
def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
|
||||
"""Write metrics to summary with the given writer.
|
||||
|
||||
Args:
|
||||
metrics: a dictionary of metrics values. Prefer dictionary.
|
||||
step: integer. The training step.
|
||||
"""
|
||||
if not isinstance(metrics, dict):
|
||||
# Support scalar metric without name.
|
||||
logging.warning('Warning: summary writer prefer metrics as dictionary.')
|
||||
metrics = {'metric': metrics}
|
||||
|
||||
with self.writer.as_default():
|
||||
for k, v in metrics.items():
|
||||
tf.summary.scalar(k, v, step=step)
|
||||
self.writer.flush()
|
||||
|
||||
|
||||
class DistributedExecutor(object):
|
||||
"""Interface to train and eval models with tf.distribute.Strategy.
|
||||
|
||||
Arguments:
|
||||
strategy: an instance of tf.distribute.Strategy.
|
||||
params: Model configuration needed to run distribution strategy.
|
||||
model_fn: Keras model function. Signature:
|
||||
(params: ParamsDict) -> tf.keras.models.Model.
|
||||
loss_fn: loss function. Signature:
|
||||
(y_true: Tensor, y_pred: Tensor) -> Tensor
|
||||
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
|
||||
is_multi_host: Set to True when using multi hosts for training, like multi
|
||||
worker GPU or TPU pod (slice). Otherwise, False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
strategy,
|
||||
params,
|
||||
model_fn,
|
||||
loss_fn,
|
||||
is_multi_host=False):
|
||||
|
||||
self._params = params
|
||||
self._model_fn = model_fn
|
||||
self._loss_fn = loss_fn
|
||||
self._strategy = strategy
|
||||
self._checkpoint_name = 'ctl_step_{step}.ckpt'
|
||||
self._is_multi_host = is_multi_host
|
||||
self.train_summary_writer = None
|
||||
self.eval_summary_writer = None
|
||||
self.global_train_step = None
|
||||
|
||||
@property
|
||||
def checkpoint_name(self):
|
||||
"""Returns default checkpoint name."""
|
||||
return self._checkpoint_name
|
||||
|
||||
@checkpoint_name.setter
|
||||
def checkpoint_name(self, name):
|
||||
"""Sets default summary writer for the current thread."""
|
||||
self._checkpoint_name = name
|
||||
|
||||
def loss_fn(self):
|
||||
return self._loss_fn()
|
||||
|
||||
def model_fn(self, params):
|
||||
return self._model_fn(params)
|
||||
|
||||
def _save_config(self, model_dir):
|
||||
"""Save parameters to config files if model_dir is defined."""
|
||||
|
||||
logging.info('Save config to model_dir %s.', model_dir)
|
||||
if model_dir:
|
||||
if not tf.io.gfile.exists(model_dir):
|
||||
tf.io.gfile.makedirs(model_dir)
|
||||
self._params.lock()
|
||||
params_dict.save_params_dict_to_yaml(self._params,
|
||||
model_dir + '/params.yaml')
|
||||
else:
|
||||
logging.warning('model_dir is empty, so skip the save config.')
|
||||
|
||||
def _get_input_iterator(
|
||||
self, input_fn: Callable[..., tf.data.Dataset],
|
||||
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
|
||||
"""Returns distributed dataset iterator.
|
||||
|
||||
Args:
|
||||
input_fn: (params: dict) -> tf.data.Dataset.
|
||||
strategy: an instance of tf.distribute.Strategy.
|
||||
|
||||
Returns:
|
||||
An iterator that yields input tensors.
|
||||
"""
|
||||
|
||||
if input_fn is None:
|
||||
return None
|
||||
# When training with multiple TPU workers, datasets needs to be cloned
|
||||
# across workers. Since Dataset instance cannot be cloned in eager mode,
|
||||
# we instead pass callable that returns a dataset.
|
||||
if self._is_multi_host:
|
||||
return iter(
|
||||
strategy.experimental_distribute_datasets_from_function(input_fn))
|
||||
else:
|
||||
input_data = input_fn()
|
||||
return iter(strategy.experimental_distribute_dataset(input_data))
|
||||
|
||||
def _create_replicated_step(self,
|
||||
strategy,
|
||||
model,
|
||||
loss_fn,
|
||||
optimizer,
|
||||
metric=None):
|
||||
|
||||
def _replicated_step(inputs):
|
||||
"""Replicated training step."""
|
||||
inputs, labels = inputs
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
outputs = model(inputs, training=True)
|
||||
prediction_loss = loss_fn(labels, outputs)
|
||||
loss = tf.reduce_mean(prediction_loss)
|
||||
loss = loss / strategy.num_replicas_in_sync
|
||||
if isinstance(metric, tf.keras.metrics.Metric):
|
||||
metric.update_state(labels, outputs)
|
||||
else:
|
||||
logging.error('train metric is not an instance of '
|
||||
'tf.keras.metrics.Metric.')
|
||||
|
||||
grads = tape.gradient(loss, model.trainable_variables)
|
||||
optimizer.apply_gradients(zip(grads, model.trainable_variables))
|
||||
return loss
|
||||
|
||||
return _replicated_step
|
||||
|
||||
def _create_train_step(self,
|
||||
strategy,
|
||||
model,
|
||||
loss_fn,
|
||||
optimizer,
|
||||
metric=None):
|
||||
"""Creates a distributed training step.
|
||||
|
||||
Args:
|
||||
strategy: an instance of tf.distribute.Strategy.
|
||||
model: (Tensor, bool) -> Tensor. model function.
|
||||
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
|
||||
optimizer: tf.keras.optimizers.Optimizer.
|
||||
iterator: an iterator that yields input tensors.
|
||||
metric: tf.keras.metrics.Metric subclass.
|
||||
|
||||
Returns:
|
||||
The training step callable.
|
||||
"""
|
||||
_replicated_step = self._create_replicated_step(strategy, model, loss_fn,
|
||||
optimizer, metric)
|
||||
|
||||
@tf.function
|
||||
def train_step(iterator, num_steps):
|
||||
"""Performs a distributed training step.
|
||||
|
||||
Args:
|
||||
iterator: an iterator that yields input tensors.
|
||||
|
||||
Returns:
|
||||
The loss tensor.
|
||||
"""
|
||||
if not isinstance(num_steps, tf.Tensor):
|
||||
raise ValueError('steps should be an Tensor. Python object may cause '
|
||||
'retracing.')
|
||||
|
||||
per_replica_losses = strategy.run(
|
||||
_replicated_step, args=(next(iterator),))
|
||||
for _ in tf.range(num_steps - 1):
|
||||
per_replica_losses = strategy.run(
|
||||
_replicated_step, args=(next(iterator),))
|
||||
|
||||
# For reporting, we returns the mean of losses.
|
||||
losses = tf.nest.map_structure(
|
||||
lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
|
||||
per_replica_losses)
|
||||
return losses
|
||||
|
||||
return train_step
|
||||
|
||||
def _create_test_step(self, strategy, model, metric):
|
||||
"""Creates a distributed test step."""
|
||||
|
||||
@tf.function
|
||||
def test_step(iterator):
|
||||
"""Calculates evaluation metrics on distributed devices."""
|
||||
if not metric:
|
||||
logging.info('Skip test_step because metric is None (%s)', metric)
|
||||
return None, None
|
||||
if not isinstance(metric, tf.keras.metrics.Metric):
|
||||
raise ValueError(
|
||||
'Metric must be an instance of tf.keras.metrics.Metric '
|
||||
'for running in test_step. Actual {}'.format(metric))
|
||||
|
||||
def _test_step_fn(inputs):
|
||||
"""Replicated accuracy calculation."""
|
||||
inputs, labels = inputs
|
||||
model_outputs = model(inputs, training=False)
|
||||
metric.update_state(labels, model_outputs)
|
||||
return labels, model_outputs
|
||||
|
||||
return strategy.run(_test_step_fn, args=(next(iterator),))
|
||||
|
||||
return test_step
|
||||
|
||||
def train(self,
|
||||
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
|
||||
eval_input_fn: Callable[[params_dict.ParamsDict],
|
||||
tf.data.Dataset] = None,
|
||||
model_dir: Text = None,
|
||||
total_steps: int = 1,
|
||||
iterations_per_loop: int = 1,
|
||||
train_metric_fn: Callable[[], Any] = None,
|
||||
eval_metric_fn: Callable[[], Any] = None,
|
||||
summary_writer_fn: Callable[[Text, Text],
|
||||
SummaryWriter] = SummaryWriter,
|
||||
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
|
||||
custom_callbacks: List[tf.keras.callbacks.Callback] = None,
|
||||
save_config: bool = True):
|
||||
"""Runs distributed training.
|
||||
|
||||
Args:
|
||||
train_input_fn: (params: dict) -> tf.data.Dataset training data input
|
||||
function.
|
||||
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
|
||||
trigger evaluting metric on eval data. If None, will not run eval step.
|
||||
model_dir: the folder path for model checkpoints.
|
||||
total_steps: total training steps.
|
||||
iterations_per_loop: train steps per loop. After each loop, this job will
|
||||
update metrics like loss and save checkpoint.
|
||||
train_metric_fn: metric_fn for evaluation in train_step.
|
||||
eval_metric_fn: metric_fn for evaluation in test_step.
|
||||
summary_writer_fn: function to create summary writer.
|
||||
init_checkpoint: function to load checkpoint.
|
||||
custom_callbacks: A list of Keras Callbacks objects to run during
|
||||
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
|
||||
methods are invoked during training.
|
||||
save_config: bool. Whether to save params to model_dir.
|
||||
|
||||
Returns:
|
||||
The training loss and eval metrics.
|
||||
"""
|
||||
assert train_input_fn is not None
|
||||
if train_metric_fn and not callable(train_metric_fn):
|
||||
raise ValueError('if `train_metric_fn` is specified, '
|
||||
'train_metric_fn must be a callable.')
|
||||
if eval_metric_fn and not callable(eval_metric_fn):
|
||||
raise ValueError('if `eval_metric_fn` is specified, '
|
||||
'eval_metric_fn must be a callable.')
|
||||
train_metric_fn = train_metric_fn or _no_metric
|
||||
eval_metric_fn = eval_metric_fn or _no_metric
|
||||
|
||||
if custom_callbacks and iterations_per_loop != 1:
|
||||
logging.warning(
|
||||
'It is sematically wrong to run callbacks when '
|
||||
'iterations_per_loop is not one (%s)', iterations_per_loop)
|
||||
|
||||
custom_callbacks = custom_callbacks or []
|
||||
|
||||
def _run_callbacks_on_batch_begin(batch):
|
||||
"""Runs custom callbacks at the start of every step."""
|
||||
if not custom_callbacks:
|
||||
return
|
||||
for callback in custom_callbacks:
|
||||
if callback:
|
||||
callback.on_batch_begin(batch)
|
||||
|
||||
def _run_callbacks_on_batch_end(batch):
|
||||
"""Runs custom callbacks at the end of every step."""
|
||||
if not custom_callbacks:
|
||||
return
|
||||
for callback in custom_callbacks:
|
||||
if callback:
|
||||
callback.on_batch_end(batch)
|
||||
|
||||
if save_config:
|
||||
self._save_config(model_dir)
|
||||
|
||||
if FLAGS.save_checkpoint_freq:
|
||||
save_freq = FLAGS.save_checkpoint_freq
|
||||
else:
|
||||
save_freq = iterations_per_loop
|
||||
|
||||
params = self._params
|
||||
strategy = self._strategy
|
||||
# To reduce unnecessary send/receive input pipeline operation, we place
|
||||
# input pipeline ops in worker task.
|
||||
train_iterator = self._get_input_iterator(train_input_fn, strategy)
|
||||
train_loss = None
|
||||
eval_metric_result = None
|
||||
with strategy.scope():
|
||||
# To correctly place the model weights on accelerators,
|
||||
# model and optimizer should be created in scope.
|
||||
model = self.model_fn(params.as_dict())
|
||||
if not hasattr(model, 'optimizer'):
|
||||
raise ValueError('User should set optimizer attribute to model '
|
||||
'inside `model_fn`.')
|
||||
optimizer = model.optimizer
|
||||
|
||||
# Training loop starts here.
|
||||
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
|
||||
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
|
||||
initial_step = 0
|
||||
if latest_checkpoint_file:
|
||||
logging.info(
|
||||
'Checkpoint file %s found and restoring from '
|
||||
'checkpoint', latest_checkpoint_file)
|
||||
checkpoint.restore(latest_checkpoint_file)
|
||||
initial_step = optimizer.iterations.numpy()
|
||||
logging.info('Loading from checkpoint file completed. Init step %d',
|
||||
initial_step)
|
||||
elif init_checkpoint:
|
||||
logging.info('Restoring from init checkpoint function')
|
||||
init_checkpoint(model)
|
||||
logging.info('Loading from init checkpoint file completed')
|
||||
|
||||
current_step = optimizer.iterations.numpy()
|
||||
checkpoint_name = self.checkpoint_name
|
||||
|
||||
eval_metric = eval_metric_fn()
|
||||
train_metric = train_metric_fn()
|
||||
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
|
||||
self.train_summary_writer = train_summary_writer.writer
|
||||
|
||||
test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
|
||||
self.eval_summary_writer = test_summary_writer.writer
|
||||
|
||||
# Use training summary writer in TimeHistory if it's in use
|
||||
for cb in custom_callbacks:
|
||||
if isinstance(cb, keras_utils.TimeHistory):
|
||||
cb.summary_writer = self.train_summary_writer
|
||||
|
||||
# Continue training loop.
|
||||
train_step = self._create_train_step(
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
loss_fn=self.loss_fn(),
|
||||
optimizer=optimizer,
|
||||
metric=train_metric)
|
||||
test_step = None
|
||||
if eval_input_fn and eval_metric:
|
||||
self.global_train_step = model.optimizer.iterations
|
||||
test_step = self._create_test_step(strategy, model, metric=eval_metric)
|
||||
|
||||
# Step-0 operations
|
||||
_save_checkpoint(
|
||||
checkpoint, model_dir, checkpoint_name.format(step=current_step))
|
||||
if test_step:
|
||||
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
|
||||
eval_metric_result = self._run_evaluation(
|
||||
test_step, current_step, eval_metric, eval_iterator)
|
||||
logging.info(
|
||||
'Step: %s evalation metric = %s.', current_step, eval_metric_result)
|
||||
test_summary_writer(
|
||||
metrics=eval_metric_result, step=optimizer.iterations)
|
||||
eval_metric.reset_states()
|
||||
|
||||
logging.info('Training started')
|
||||
last_save_checkpoint_step = current_step
|
||||
while current_step < total_steps:
|
||||
|
||||
num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
|
||||
_run_callbacks_on_batch_begin(current_step)
|
||||
train_loss = train_step(train_iterator,
|
||||
tf.convert_to_tensor(num_steps, dtype=tf.int32))
|
||||
current_step += num_steps
|
||||
|
||||
train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
|
||||
train_loss)
|
||||
|
||||
_run_callbacks_on_batch_end(current_step - 1)
|
||||
if not isinstance(train_loss, dict):
|
||||
train_loss = {'total_loss': train_loss}
|
||||
if np.isnan(train_loss['total_loss']):
|
||||
raise ValueError('total loss is NaN.')
|
||||
|
||||
if train_metric:
|
||||
train_metric_result = train_metric.result()
|
||||
if isinstance(train_metric, tf.keras.metrics.Metric):
|
||||
train_metric_result = tf.nest.map_structure(
|
||||
lambda x: x.numpy().astype(float), train_metric_result)
|
||||
if not isinstance(train_metric_result, dict):
|
||||
train_metric_result = {'metric': train_metric_result}
|
||||
train_metric_result.update(train_loss)
|
||||
else:
|
||||
train_metric_result = train_loss
|
||||
if callable(optimizer.lr):
|
||||
train_metric_result.update(
|
||||
{'learning_rate': optimizer.lr(current_step).numpy()})
|
||||
else:
|
||||
train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
|
||||
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
|
||||
current_step, total_steps, train_loss,
|
||||
train_metric_result)
|
||||
|
||||
train_summary_writer(
|
||||
metrics=train_metric_result, step=optimizer.iterations)
|
||||
|
||||
# Saves model checkpoints and run validation steps at every
|
||||
# iterations_per_loop steps.
|
||||
# To avoid repeated model saving, we do not save after the last
|
||||
# step of training.
|
||||
if save_freq > 0 and current_step < total_steps and (
|
||||
current_step - last_save_checkpoint_step) >= save_freq:
|
||||
_save_checkpoint(checkpoint, model_dir,
|
||||
checkpoint_name.format(step=current_step))
|
||||
last_save_checkpoint_step = current_step
|
||||
|
||||
if test_step:
|
||||
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
|
||||
eval_metric_result = self._run_evaluation(test_step, current_step,
|
||||
eval_metric, eval_iterator)
|
||||
logging.info('Step: %s evalation metric = %s.', current_step,
|
||||
eval_metric_result)
|
||||
test_summary_writer(
|
||||
metrics=eval_metric_result, step=optimizer.iterations)
|
||||
|
||||
# Re-initialize evaluation metric, except the last step.
|
||||
if eval_metric and current_step < total_steps:
|
||||
eval_metric.reset_states()
|
||||
if train_metric and current_step < total_steps:
|
||||
train_metric.reset_states()
|
||||
|
||||
# Reaches the end of training and saves the last checkpoint.
|
||||
if last_save_checkpoint_step < total_steps:
|
||||
_save_checkpoint(checkpoint, model_dir,
|
||||
checkpoint_name.format(step=current_step))
|
||||
|
||||
if test_step:
|
||||
logging.info('Running final evaluation after training is complete.')
|
||||
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
|
||||
eval_metric_result = self._run_evaluation(test_step, current_step,
|
||||
eval_metric, eval_iterator)
|
||||
logging.info('Final evaluation metric = %s.', eval_metric_result)
|
||||
test_summary_writer(
|
||||
metrics=eval_metric_result, step=optimizer.iterations)
|
||||
|
||||
self.train_summary_writer.close()
|
||||
self.eval_summary_writer.close()
|
||||
|
||||
return train_loss, eval_metric_result
|
||||
|
||||
def _run_evaluation(self, test_step, current_training_step, metric,
|
||||
test_iterator):
|
||||
"""Runs validation steps and aggregate metrics."""
|
||||
if not test_iterator or not metric:
|
||||
logging.warning(
|
||||
'Both test_iterator (%s) and metrics (%s) must not be None.',
|
||||
test_iterator, metric)
|
||||
return None
|
||||
logging.info('Running evaluation after step: %s.', current_training_step)
|
||||
while True:
|
||||
try:
|
||||
test_step(test_iterator)
|
||||
except (StopIteration, tf.errors.OutOfRangeError):
|
||||
break
|
||||
|
||||
metric_result = metric.result()
|
||||
if isinstance(metric, tf.keras.metrics.Metric):
|
||||
metric_result = metric_result.numpy().astype(float)
|
||||
logging.info('Step: [%d] Validation metric = %f', current_training_step,
|
||||
metric_result)
|
||||
return metric_result
|
||||
|
||||
def evaluate_from_model_dir(
|
||||
self,
|
||||
model_dir: Text,
|
||||
eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
|
||||
eval_metric_fn: Callable[[], Any],
|
||||
total_steps: int = -1,
|
||||
eval_timeout: int = None,
|
||||
min_eval_interval: int = 180,
|
||||
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
|
||||
"""Runs distributed evaluation on model folder.
|
||||
|
||||
Args:
|
||||
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
|
||||
trigger evaluting metric on eval data. If None, will not run eval step.
|
||||
eval_metric_fn: metric_fn for evaluation in test_step.
|
||||
model_dir: the folder for storing model checkpoints.
|
||||
total_steps: total training steps. If the current step reaches the
|
||||
total_steps, the evaluation loop will stop.
|
||||
eval_timeout: The maximum number of seconds to wait between checkpoints.
|
||||
If left as None, then the process will wait indefinitely. Used by
|
||||
tf.train.checkpoints_iterator.
|
||||
min_eval_interval: The minimum number of seconds between yielding
|
||||
checkpoints. Used by tf.train.checkpoints_iterator.
|
||||
summary_writer_fn: function to create summary writer.
|
||||
|
||||
Returns:
|
||||
Eval metrics dictionary of the last checkpoint.
|
||||
"""
|
||||
|
||||
if not model_dir:
|
||||
raise ValueError('model_dir must be set.')
|
||||
|
||||
def terminate_eval():
|
||||
tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
|
||||
eval_timeout)
|
||||
return True
|
||||
|
||||
summary_writer = summary_writer_fn(model_dir, 'eval')
|
||||
self.eval_summary_writer = summary_writer.writer
|
||||
|
||||
# Read checkpoints from the given model directory
|
||||
# until `eval_timeout` seconds elapses.
|
||||
for checkpoint_path in tf.train.checkpoints_iterator(
|
||||
model_dir,
|
||||
min_interval_secs=min_eval_interval,
|
||||
timeout=eval_timeout,
|
||||
timeout_fn=terminate_eval):
|
||||
eval_metric_result, current_step = self.evaluate_checkpoint(
|
||||
checkpoint_path=checkpoint_path,
|
||||
eval_input_fn=eval_input_fn,
|
||||
eval_metric_fn=eval_metric_fn,
|
||||
summary_writer=summary_writer)
|
||||
if total_steps > 0 and current_step >= total_steps:
|
||||
logging.info('Evaluation finished after training step %d', current_step)
|
||||
break
|
||||
return eval_metric_result
|
||||
|
||||
def evaluate_checkpoint(self,
|
||||
checkpoint_path: Text,
|
||||
eval_input_fn: Callable[[params_dict.ParamsDict],
|
||||
tf.data.Dataset],
|
||||
eval_metric_fn: Callable[[], Any],
|
||||
summary_writer: SummaryWriter = None):
|
||||
"""Runs distributed evaluation on the one checkpoint.
|
||||
|
||||
Args:
|
||||
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
|
||||
trigger evaluting metric on eval data. If None, will not run eval step.
|
||||
eval_metric_fn: metric_fn for evaluation in test_step.
|
||||
checkpoint_path: the checkpoint to evaluate.
|
||||
summary_writer_fn: function to create summary writer.
|
||||
|
||||
Returns:
|
||||
Eval metrics dictionary of the last checkpoint.
|
||||
"""
|
||||
if not callable(eval_metric_fn):
|
||||
raise ValueError('if `eval_metric_fn` is specified, '
|
||||
'eval_metric_fn must be a callable.')
|
||||
|
||||
params = self._params
|
||||
strategy = self._strategy
|
||||
# To reduce unnecessary send/receive input pipeline operation, we place
|
||||
# input pipeline ops in worker task.
|
||||
with strategy.scope():
|
||||
|
||||
# To correctly place the model weights on accelerators,
|
||||
# model and optimizer should be created in scope.
|
||||
model = self.model_fn(params.as_dict())
|
||||
checkpoint = tf.train.Checkpoint(model=model)
|
||||
|
||||
eval_metric = eval_metric_fn()
|
||||
assert eval_metric, 'eval_metric does not exist'
|
||||
test_step = self._create_test_step(strategy, model, metric=eval_metric)
|
||||
|
||||
logging.info('Starting to evaluate.')
|
||||
if not checkpoint_path:
|
||||
raise ValueError('checkpoint path is empty')
|
||||
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
|
||||
current_step = reader.get_tensor(
|
||||
'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
|
||||
logging.info(
|
||||
'Checkpoint file %s found and restoring from '
|
||||
'checkpoint', checkpoint_path)
|
||||
checkpoint.restore(checkpoint_path)
|
||||
|
||||
self.global_train_step = model.optimizer.iterations
|
||||
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
|
||||
eval_metric_result = self._run_evaluation(test_step, current_step,
|
||||
eval_metric, eval_iterator)
|
||||
logging.info('Step: %s evalation metric = %s.', current_step,
|
||||
eval_metric_result)
|
||||
summary_writer(metrics=eval_metric_result, step=current_step)
|
||||
eval_metric.reset_states()
|
||||
|
||||
return eval_metric_result, current_step
|
||||
|
||||
def predict(self):
|
||||
return NotImplementedError('Unimplmented function.')
|
||||
|
||||
|
||||
class ExecutorBuilder(object):
|
||||
"""Builder of DistributedExecutor.
|
||||
|
||||
Example 1: Builds an executor with supported Strategy.
|
||||
builder = ExecutorBuilder(
|
||||
strategy_type='tpu',
|
||||
strategy_config={'tpu': '/bns/xxx'})
|
||||
dist_executor = builder.build_executor(
|
||||
params=params,
|
||||
model_fn=my_model_fn,
|
||||
loss_fn=my_loss_fn,
|
||||
metric_fn=my_metric_fn)
|
||||
|
||||
Example 2: Builds an executor with customized Strategy.
|
||||
builder = ExecutorBuilder()
|
||||
builder.strategy = <some customized Strategy>
|
||||
dist_executor = builder.build_executor(
|
||||
params=params,
|
||||
model_fn=my_model_fn,
|
||||
loss_fn=my_loss_fn,
|
||||
metric_fn=my_metric_fn)
|
||||
|
||||
Example 3: Builds a customized executor with customized Strategy.
|
||||
class MyDistributedExecutor(DistributedExecutor):
|
||||
# implementation ...
|
||||
|
||||
builder = ExecutorBuilder()
|
||||
builder.strategy = <some customized Strategy>
|
||||
dist_executor = builder.build_executor(
|
||||
class_ctor=MyDistributedExecutor,
|
||||
params=params,
|
||||
model_fn=my_model_fn,
|
||||
loss_fn=my_loss_fn,
|
||||
metric_fn=my_metric_fn)
|
||||
|
||||
Args:
|
||||
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
|
||||
None. User is responsible to set the strategy before calling
|
||||
build_executor(...).
|
||||
strategy_config: necessary config for constructing the proper Strategy.
|
||||
Check strategy_flags_dict() for examples of the structure.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy_type=None, strategy_config=None):
|
||||
_ = distribution_utils.configure_cluster(
|
||||
strategy_config.worker_hosts, strategy_config.task_index)
|
||||
self._strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=strategy_type,
|
||||
num_gpus=strategy_config.num_gpus,
|
||||
all_reduce_alg=strategy_config.all_reduce_alg,
|
||||
num_packs=strategy_config.num_packs,
|
||||
tpu_address=strategy_config.tpu)
|
||||
|
||||
@property
|
||||
def strategy(self):
|
||||
"""Returns default checkpoint name."""
|
||||
return self._strategy
|
||||
|
||||
@strategy.setter
|
||||
def strategy(self, new_strategy):
|
||||
"""Sets default summary writer for the current thread."""
|
||||
self._strategy = new_strategy
|
||||
|
||||
|
||||
def build_executor(self,
|
||||
class_ctor=DistributedExecutor,
|
||||
params=None,
|
||||
model_fn=None,
|
||||
loss_fn=None,
|
||||
**kwargs):
|
||||
"""Creates an executor according to strategy type.
|
||||
|
||||
See doc string of the DistributedExecutor.__init__ for more information of
|
||||
the
|
||||
input arguments.
|
||||
|
||||
Args:
|
||||
class_ctor: A constructor of executor (default: DistributedExecutor).
|
||||
params: ParamsDict, all the model parameters and runtime parameters.
|
||||
model_fn: Keras model function.
|
||||
loss_fn: loss function.
|
||||
**kwargs: other arguments to the executor constructor.
|
||||
|
||||
Returns:
|
||||
An instance of DistributedExecutor or its subclass.
|
||||
"""
|
||||
if self._strategy is None:
|
||||
raise ValueError('`strategy` should not be None. You need to specify '
|
||||
'`strategy_type` in the builder contructor or directly '
|
||||
'set the `strategy` property of the builder.')
|
||||
return class_ctor(
|
||||
strategy=self._strategy,
|
||||
params=params,
|
||||
model_fn=model_fn,
|
||||
loss_fn=loss_fn,
|
||||
**kwargs)
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
# TensorFlow Natural Language Processing Modelling Toolkit
|
||||
|
||||
tensorflow/models/official/nlp provides a [modeling library](modeling) for constructing
|
||||
NLP model achitectures, as well as TF2 reference implementations for
|
||||
state-of-the-art models.
|
||||
|
||||
The repository contains the following models, with implementations, pre-trained
|
||||
model weights, usage scripts and conversion utilities:
|
||||
|
||||
* [Albert](albert)
|
||||
* [Bert](bert)
|
||||
* [NHNet](nhnet)
|
||||
* [XLNet](xlnet)
|
||||
* [Transformer for translation](transformer)
|
||||
|
||||
Addtional features:
|
||||
|
||||
* Distributed trainable on both multi-GPU and TPU
|
||||
* e2e training for custom models, including both pretraining and finetuning.
|
||||
+332
@@ -0,0 +1,332 @@
|
||||
# ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
|
||||
|
||||
The academic paper which describes ALBERT in detail and provides full results on
|
||||
a number of tasks can be found here: https://arxiv.org/abs/1909.11942.
|
||||
|
||||
This repository contains TensorFlow 2.x implementation for ALBERT.
|
||||
|
||||
## Contents
|
||||
* [Contents](#contents)
|
||||
* [Pre-trained Models](#pre-trained-models)
|
||||
* [Restoring from Checkpoints](#restoring-from-checkpoints)
|
||||
* [Set Up](#set-up)
|
||||
* [Process Datasets](#process-datasets)
|
||||
* [Fine-tuning with BERT](#fine-tuning-with-bert)
|
||||
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
|
||||
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
|
||||
* [SQuAD 1.1](#squad-1.1)
|
||||
|
||||
|
||||
## Pre-trained Models
|
||||
|
||||
We released both checkpoints and tf.hub modules as the pretrained models for
|
||||
fine-tuning. They are TF 2.x compatible and are converted from the ALBERT v2
|
||||
checkpoints released in TF 1.x official ALBERT repository
|
||||
[google-research/albert](https://github.com/google-research/albert)
|
||||
in order to keep consistent with ALBERT paper.
|
||||
|
||||
Our current released checkpoints are exactly the same as TF 1.x official ALBERT
|
||||
repository.
|
||||
|
||||
### Access to Pretrained Checkpoints
|
||||
|
||||
Pretrained checkpoints can be found in the following links:
|
||||
|
||||
**Note: We implemented ALBERT using Keras functional-style networks in [nlp/modeling](../modeling).
|
||||
ALBERT V2 models compatible with TF 2.x checkpoints are:**
|
||||
|
||||
* **[`ALBERT V2 Base`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base.tar.gz)**:
|
||||
12-layer, 768-hidden, 12-heads, 12M parameters
|
||||
* **[`ALBERT V2 Large`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_large.tar.gz)**:
|
||||
24-layer, 1024-hidden, 16-heads, 18M parameters
|
||||
* **[`ALBERT V2 XLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xlarge.tar.gz)**:
|
||||
24-layer, 2048-hidden, 32-heads, 60M parameters
|
||||
* **[`ALBERT V2 XXLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xxlarge.tar.gz)**:
|
||||
12-layer, 4096-hidden, 64-heads, 235M parameters
|
||||
|
||||
We recommend to host checkpoints on Google Cloud storage buckets when you use
|
||||
Cloud GPU/TPU.
|
||||
|
||||
### Restoring from Checkpoints
|
||||
|
||||
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
|
||||
weights from provided pre-trained checkpoints, you can use the following code:
|
||||
|
||||
```python
|
||||
init_checkpoint='the pretrained model checkpoint path.'
|
||||
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
|
||||
checkpoint = tf.train.Checkpoint(model=model)
|
||||
checkpoint.restore(init_checkpoint)
|
||||
```
|
||||
|
||||
Checkpoints featuring native serialized Keras models
|
||||
(i.e. model.load()/load_weights()) will be available soon.
|
||||
|
||||
### Access to Pretrained hub modules.
|
||||
|
||||
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
|
||||
following links:
|
||||
|
||||
* **[`ALBERT V2 Base`](https://tfhub.dev/tensorflow/albert_en_base/1)**:
|
||||
12-layer, 768-hidden, 12-heads, 12M parameters
|
||||
* **[`ALBERT V2 Large`](https://tfhub.dev/tensorflow/albert_en_large/1)**:
|
||||
24-layer, 1024-hidden, 16-heads, 18M parameters
|
||||
* **[`ALBERT V2 XLarge`](https://tfhub.dev/tensorflow/albert_en_xlarge/1)**:
|
||||
24-layer, 2048-hidden, 32-heads, 60M parameters
|
||||
* **[`ALBERT V2 XXLarge`](https://tfhub.dev/tensorflow/albert_en_xxlarge/1)**:
|
||||
12-layer, 4096-hidden, 64-heads, 235M parameters
|
||||
|
||||
## Set Up
|
||||
|
||||
```shell
|
||||
export PYTHONPATH="$PYTHONPATH:/path/to/models"
|
||||
```
|
||||
|
||||
Install `tf-nightly` to get latest updates:
|
||||
|
||||
```shell
|
||||
pip install tf-nightly-gpu
|
||||
```
|
||||
|
||||
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
|
||||
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
|
||||
|
||||
```shell
|
||||
ctpu up -name <instance name> --tf-version=”nightly”
|
||||
```
|
||||
|
||||
Second, you need to install TF 2 `tf-nightly` on your VM:
|
||||
|
||||
```shell
|
||||
pip install tf-nightly
|
||||
```
|
||||
|
||||
Warning: More details TPU-specific set-up instructions and tutorial should come
|
||||
along with official TF 2.x release for TPU. Note that this repo is not
|
||||
officially supported by Google Cloud TPU team yet until TF 2.1 released.
|
||||
|
||||
## Process Datasets
|
||||
|
||||
### Pre-training
|
||||
|
||||
Pre-train ALBERT using TF2.x will come soon.
|
||||
For now, please use [ALBERT research repo](https://github.com/google-research/ALBERT)
|
||||
to pretrain the model and convert the checkpoint to TF2.x compatible ones using
|
||||
[tf2_albert_encoder_checkpoint_converter.py](tf2_albert_encoder_checkpoint_converter.py).
|
||||
|
||||
|
||||
|
||||
### Fine-tuning
|
||||
|
||||
To prepare the fine-tuning data for final model training, use the
|
||||
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
|
||||
Note that different from BERT models that use word piece tokenzer,
|
||||
ALBERT models employ sentence piece tokenizer. So the FLAG tokenizer_impl has
|
||||
to be set to 'sentence_piece'.
|
||||
Resulting datasets in `tf_record` format and training meta data should be later
|
||||
passed to training or evaluation scripts. The task-specific arguments are
|
||||
described in following sections:
|
||||
|
||||
* GLUE
|
||||
|
||||
Users can download the
|
||||
[GLUE data](https://gluebenchmark.com/tasks) by running
|
||||
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
||||
and unpack it to some directory `$GLUE_DIR`.
|
||||
|
||||
```shell
|
||||
export GLUE_DIR=~/glue
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
|
||||
export TASK_NAME=MNLI
|
||||
export OUTPUT_DIR=gs://some_bucket/datasets
|
||||
python ../data/create_finetuning_data.py \
|
||||
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
|
||||
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
|
||||
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
|
||||
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
|
||||
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
|
||||
--fine_tuning_task_type=classification --max_seq_length=128 \
|
||||
--classification_task_name=${TASK_NAME} \
|
||||
--tokenizer_impl=sentence_piece
|
||||
```
|
||||
|
||||
* SQUAD
|
||||
|
||||
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
|
||||
detailed information about the SQuAD datasets and evaluation.
|
||||
|
||||
The necessary files can be found here:
|
||||
|
||||
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
||||
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
|
||||
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
|
||||
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
|
||||
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
|
||||
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
|
||||
|
||||
```shell
|
||||
export SQUAD_DIR=~/squad
|
||||
export SQUAD_VERSION=v1.1
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export OUTPUT_DIR=gs://some_bucket/datasets
|
||||
|
||||
python ../data/create_finetuning_data.py \
|
||||
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
|
||||
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
|
||||
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--fine_tuning_task_type=squad --max_seq_length=384 \
|
||||
--tokenizer_impl=sentence_piece
|
||||
```
|
||||
|
||||
## Fine-tuning with ALBERT
|
||||
|
||||
### Cloud GPUs and TPUs
|
||||
|
||||
* Cloud Storage
|
||||
|
||||
The unzipped pre-trained model files can also be found in the Google Cloud
|
||||
Storage folder `gs://cloud-tpu-checkpoints/albert/checkpoints`. For example:
|
||||
|
||||
```shell
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
```
|
||||
|
||||
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
|
||||
script should run with `tf-nightly`.
|
||||
|
||||
* GPU -> TPU
|
||||
|
||||
Just add the following flags to `run_classifier.py` or `run_squad.py`:
|
||||
|
||||
```shell
|
||||
--distribution_strategy=tpu
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
### Sentence and Sentence-pair Classification Tasks
|
||||
|
||||
This example code fine-tunes `albert_v2_base` on the Microsoft Research
|
||||
Paraphrase Corpus (MRPC) corpus, which only contains 3,600 examples and can
|
||||
fine-tune in a few minutes on most GPUs.
|
||||
|
||||
We use the `albert_v2_base` as an example throughout the
|
||||
workflow.
|
||||
|
||||
|
||||
```shell
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export GLUE_DIR=gs://some_bucket/datasets
|
||||
export TASK=MRPC
|
||||
|
||||
python run_classifier.py \
|
||||
--mode='train_and_eval' \
|
||||
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
||||
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
||||
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
||||
--bert_config_file=${ALBERT_DIR}/albert_config.json \
|
||||
--init_checkpoint=${ALBERT_DIR}/bert_model.ckpt \
|
||||
--train_batch_size=4 \
|
||||
--eval_batch_size=4 \
|
||||
--steps_per_loop=1 \
|
||||
--learning_rate=2e-5 \
|
||||
--num_train_epochs=3 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=mirrored
|
||||
```
|
||||
|
||||
Alternatively, instead of specifying `init_checkpoint`, you can specify
|
||||
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
|
||||
` --hub_module_url=https://tfhub.dev/tensorflow/albert_en_base/1`.
|
||||
|
||||
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
|
||||
information and use remote storage for model checkpoints.
|
||||
|
||||
```shell
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export TPU_IP_ADDRESS='???'
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export GLUE_DIR=gs://some_bucket/datasets
|
||||
|
||||
python run_classifier.py \
|
||||
--mode='train_and_eval' \
|
||||
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
||||
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
||||
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
||||
--bert_config_file=$ALBERT_DIR/albert_config.json \
|
||||
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=32 \
|
||||
--learning_rate=2e-5 \
|
||||
--num_train_epochs=3 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=tpu \
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
### SQuAD 1.1
|
||||
|
||||
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
|
||||
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
|
||||
|
||||
We use the `albert_v2_base` as an example throughout the
|
||||
workflow.
|
||||
|
||||
```shell
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export SQUAD_DIR=gs://some_bucket/datasets
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export SQUAD_VERSION=v1.1
|
||||
|
||||
python run_squad.py \
|
||||
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
||||
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
|
||||
--bert_config_file=$ALBERT_DIR/albert_config.json \
|
||||
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
|
||||
--train_batch_size=4 \
|
||||
--predict_batch_size=4 \
|
||||
--learning_rate=8e-5 \
|
||||
--num_train_epochs=2 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=mirrored
|
||||
```
|
||||
|
||||
Similarily, you can replace `init_checkpoint` FLAGS with `hub_module_url` to
|
||||
specify a hub module path.
|
||||
|
||||
To use TPU, you need switch distribution strategy type to `tpu` with TPU
|
||||
information.
|
||||
|
||||
```shell
|
||||
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
|
||||
export TPU_IP_ADDRESS='???'
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export SQUAD_DIR=gs://some_bucket/datasets
|
||||
export SQUAD_VERSION=v1.1
|
||||
|
||||
python run_squad.py \
|
||||
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
||||
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
|
||||
--bert_config_file=$ALBERT_DIR/albert_config.json \
|
||||
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
|
||||
--train_batch_size=32 \
|
||||
--learning_rate=8e-5 \
|
||||
--num_train_epochs=2 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=tpu \
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
The dev set predictions will be saved into a file called predictions.json in the
|
||||
model_dir:
|
||||
|
||||
```shell
|
||||
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
|
||||
```
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The ALBERT configurations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from official.nlp.bert import configs
|
||||
|
||||
|
||||
class AlbertConfig(configs.BertConfig):
|
||||
"""Configuration for `ALBERT`."""
|
||||
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
num_hidden_groups=1,
|
||||
inner_group_num=1,
|
||||
**kwargs):
|
||||
"""Constructs AlbertConfig.
|
||||
|
||||
Args:
|
||||
embedding_size: Size of the factorized word embeddings.
|
||||
num_hidden_groups: Number of group for the hidden layers, parameters in
|
||||
the same group are shared. Note that this value and also the following
|
||||
'inner_group_num' has to be 1 for now, because all released ALBERT
|
||||
models set them to 1. We may support arbitary valid values in future.
|
||||
inner_group_num: Number of inner repetition of attention and ffn.
|
||||
**kwargs: The remaining arguments are the same as above 'BertConfig'.
|
||||
"""
|
||||
super(AlbertConfig, self).__init__(**kwargs)
|
||||
self.embedding_size = embedding_size
|
||||
|
||||
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
|
||||
# in the released ALBERT. Support other values in AlbertTransformerEncoder
|
||||
# if needed.
|
||||
if inner_group_num != 1 or num_hidden_groups != 1:
|
||||
raise ValueError("We only support 'inner_group_num' and "
|
||||
"'num_hidden_groups' as 1.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
|
||||
config = AlbertConfig(embedding_size=None, vocab_size=None)
|
||||
for (key, value) in six.iteritems(json_object):
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
+88
@@ -0,0 +1,88 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A script to export the ALBERT core model as a TF-Hub SavedModel."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
from typing import Text
|
||||
|
||||
from official.nlp.albert import configs
|
||||
from official.nlp.bert import bert_models
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("albert_config_file", None,
|
||||
"Albert configuration file to define core albert layers.")
|
||||
flags.DEFINE_string("model_checkpoint_path", None,
|
||||
"File path to TF model checkpoint.")
|
||||
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
|
||||
flags.DEFINE_string(
|
||||
"sp_model_file", None,
|
||||
"The sentence piece model file that the ALBERT model was trained on.")
|
||||
|
||||
|
||||
def create_albert_model(
|
||||
albert_config: configs.AlbertConfig) -> tf.keras.Model:
|
||||
"""Creates an ALBERT keras core model from ALBERT configuration.
|
||||
|
||||
Args:
|
||||
albert_config: An `AlbertConfig` to create the core model.
|
||||
|
||||
Returns:
|
||||
A keras model.
|
||||
"""
|
||||
# Adds input layers just as placeholders.
|
||||
input_word_ids = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_word_ids")
|
||||
input_mask = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_mask")
|
||||
input_type_ids = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_type_ids")
|
||||
transformer_encoder = bert_models.get_transformer_encoder(
|
||||
albert_config, sequence_length=None)
|
||||
sequence_output, pooled_output = transformer_encoder(
|
||||
[input_word_ids, input_mask, input_type_ids])
|
||||
# To keep consistent with legacy hub modules, the outputs are
|
||||
# "pooled_output" and "sequence_output".
|
||||
return tf.keras.Model(
|
||||
inputs=[input_word_ids, input_mask, input_type_ids],
|
||||
outputs=[pooled_output, sequence_output]), transformer_encoder
|
||||
|
||||
|
||||
def export_albert_tfhub(albert_config: configs.AlbertConfig,
|
||||
model_checkpoint_path: Text, hub_destination: Text,
|
||||
sp_model_file: Text):
|
||||
"""Restores a tf.keras.Model and saves for TF-Hub."""
|
||||
core_model, encoder = create_albert_model(albert_config)
|
||||
checkpoint = tf.train.Checkpoint(model=encoder)
|
||||
checkpoint.restore(model_checkpoint_path).assert_consumed()
|
||||
core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
|
||||
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
|
||||
|
||||
|
||||
def main(_):
|
||||
albert_config = configs.AlbertConfig.from_json_file(
|
||||
FLAGS.albert_config_file)
|
||||
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
|
||||
FLAGS.export_path, FLAGS.sp_model_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
+89
@@ -0,0 +1,89 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests official.nlp.albert.export_albert_tfhub."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from official.nlp.albert import configs
|
||||
from official.nlp.albert import export_albert_tfhub
|
||||
|
||||
|
||||
class ExportAlbertTfhubTest(tf.test.TestCase):
|
||||
|
||||
def test_export_albert_tfhub(self):
|
||||
# Exports a savedmodel for TF-Hub
|
||||
albert_config = configs.AlbertConfig(
|
||||
vocab_size=100,
|
||||
embedding_size=8,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
max_position_embeddings=128,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=1)
|
||||
bert_model, encoder = export_albert_tfhub.create_albert_model(albert_config)
|
||||
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||
checkpoint = tf.train.Checkpoint(model=encoder)
|
||||
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
|
||||
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
|
||||
|
||||
sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
|
||||
with tf.io.gfile.GFile(sp_model_file, "w") as f:
|
||||
f.write("dummy content")
|
||||
|
||||
hub_destination = os.path.join(self.get_temp_dir(), "hub")
|
||||
export_albert_tfhub.export_albert_tfhub(
|
||||
albert_config,
|
||||
model_checkpoint_path,
|
||||
hub_destination,
|
||||
sp_model_file=sp_model_file)
|
||||
|
||||
# Restores a hub KerasLayer.
|
||||
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
|
||||
|
||||
if hasattr(hub_layer, "resolved_object"):
|
||||
with tf.io.gfile.GFile(
|
||||
hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
|
||||
self.assertEqual("dummy content", f.read())
|
||||
# Checks the hub KerasLayer.
|
||||
for source_weight, hub_weight in zip(bert_model.trainable_weights,
|
||||
hub_layer.trainable_weights):
|
||||
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
|
||||
|
||||
dummy_ids = np.zeros((2, 10), dtype=np.int32)
|
||||
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
|
||||
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
|
||||
|
||||
# The outputs of hub module are "pooled_output" and "sequence_output",
|
||||
# while the outputs of encoder is in reversed order, i.e.,
|
||||
# "sequence_output" and "pooled_output".
|
||||
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
|
||||
self.assertEqual(hub_outputs[0].shape, (2, 16))
|
||||
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
|
||||
for source_output, hub_output, encoder_output in zip(
|
||||
source_outputs, hub_outputs, encoder_outputs):
|
||||
self.assertAllClose(source_output.numpy(), hub_output.numpy())
|
||||
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+69
@@ -0,0 +1,69 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ALBERT classification finetuning runner in tf2.x."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.albert import configs as albert_configs
|
||||
from official.nlp.bert import run_classifier as run_classifier_bert
|
||||
from official.utils.misc import distribution_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(_):
|
||||
# Users should always run this script under TF 2.x
|
||||
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
if not FLAGS.model_dir:
|
||||
FLAGS.model_dir = '/tmp/bert20/'
|
||||
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=FLAGS.distribution_strategy,
|
||||
num_gpus=FLAGS.num_gpus,
|
||||
tpu_address=FLAGS.tpu)
|
||||
max_seq_length = input_meta_data['max_seq_length']
|
||||
train_input_fn = run_classifier_bert.get_dataset_fn(
|
||||
FLAGS.train_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.train_batch_size,
|
||||
is_training=True)
|
||||
eval_input_fn = run_classifier_bert.get_dataset_fn(
|
||||
FLAGS.eval_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.eval_batch_size,
|
||||
is_training=False)
|
||||
|
||||
albert_config = albert_configs.AlbertConfig.from_json_file(
|
||||
FLAGS.bert_config_file)
|
||||
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
|
||||
train_input_fn, eval_input_fn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('bert_config_file')
|
||||
flags.mark_flag_as_required('input_meta_data_path')
|
||||
flags.mark_flag_as_required('model_dir')
|
||||
app.run(main)
|
||||
+139
@@ -0,0 +1,139 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.albert import configs as albert_configs
|
||||
from official.nlp.bert import run_squad_helper
|
||||
from official.nlp.bert import tokenization
|
||||
from official.nlp.data import squad_lib_sp
|
||||
from official.utils.misc import distribution_utils
|
||||
|
||||
flags.DEFINE_string(
|
||||
'sp_model_file', None,
|
||||
'The path to the sentence piece model. Used by sentence piece tokenizer '
|
||||
'employed by ALBERT.')
|
||||
|
||||
# More flags can be found in run_squad_helper.
|
||||
run_squad_helper.define_common_squad_flags()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def train_squad(strategy,
|
||||
input_meta_data,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=False):
|
||||
"""Runs bert squad training."""
|
||||
bert_config = albert_configs.AlbertConfig.from_json_file(
|
||||
FLAGS.bert_config_file)
|
||||
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
|
||||
custom_callbacks, run_eagerly)
|
||||
|
||||
|
||||
def predict_squad(strategy, input_meta_data):
|
||||
"""Makes predictions for the squad dataset."""
|
||||
bert_config = albert_configs.AlbertConfig.from_json_file(
|
||||
FLAGS.bert_config_file)
|
||||
tokenizer = tokenization.FullSentencePieceTokenizer(
|
||||
sp_model_file=FLAGS.sp_model_file)
|
||||
|
||||
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
|
||||
bert_config, squad_lib_sp)
|
||||
|
||||
|
||||
def eval_squad(strategy, input_meta_data):
|
||||
"""Evaluate on the squad dataset."""
|
||||
bert_config = albert_configs.AlbertConfig.from_json_file(
|
||||
FLAGS.bert_config_file)
|
||||
tokenizer = tokenization.FullSentencePieceTokenizer(
|
||||
sp_model_file=FLAGS.sp_model_file)
|
||||
|
||||
eval_metrics = run_squad_helper.eval_squad(
|
||||
strategy, input_meta_data, tokenizer, bert_config, squad_lib_sp)
|
||||
return eval_metrics
|
||||
|
||||
|
||||
def export_squad(model_export_path, input_meta_data):
|
||||
"""Exports a trained model as a `SavedModel` for inference.
|
||||
|
||||
Args:
|
||||
model_export_path: a string specifying the path to the SavedModel directory.
|
||||
input_meta_data: dictionary containing meta data about input and model.
|
||||
|
||||
Raises:
|
||||
Export path is not specified, got an empty string or None.
|
||||
"""
|
||||
bert_config = albert_configs.AlbertConfig.from_json_file(
|
||||
FLAGS.bert_config_file)
|
||||
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Users should always run this script under TF 2.x
|
||||
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
if FLAGS.mode == 'export_only':
|
||||
export_squad(FLAGS.model_export_path, input_meta_data)
|
||||
return
|
||||
|
||||
# Configures cluster spec for multi-worker distribution strategy.
|
||||
if FLAGS.num_gpus > 0:
|
||||
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
|
||||
FLAGS.task_index)
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=FLAGS.distribution_strategy,
|
||||
num_gpus=FLAGS.num_gpus,
|
||||
all_reduce_alg=FLAGS.all_reduce_alg,
|
||||
tpu_address=FLAGS.tpu)
|
||||
|
||||
if 'train' in FLAGS.mode:
|
||||
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
|
||||
if 'predict' in FLAGS.mode:
|
||||
predict_squad(strategy, input_meta_data)
|
||||
if 'eval' in FLAGS.mode:
|
||||
eval_metrics = eval_squad(strategy, input_meta_data)
|
||||
f1_score = eval_metrics['final_f1']
|
||||
logging.info('SQuAD eval F1-score: %f', f1_score)
|
||||
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
|
||||
summary_writer = tf.summary.create_file_writer(summary_dir)
|
||||
with summary_writer.as_default():
|
||||
# TODO(lehou): write to the correct step number.
|
||||
tf.summary.scalar('F1-score', f1_score, step=0)
|
||||
summary_writer.flush()
|
||||
# Also write eval_metrics to json file.
|
||||
squad_lib_sp.write_to_json_files(
|
||||
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('bert_config_file')
|
||||
flags.mark_flag_as_required('model_dir')
|
||||
app.run(main)
|
||||
+132
@@ -0,0 +1,132 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.
|
||||
|
||||
The conversion will yield an object-oriented checkpoint that can be used
|
||||
to restore a AlbertTransformerEncoder object.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow as tf
|
||||
from official.modeling import activations
|
||||
from official.nlp.albert import configs
|
||||
from official.nlp.bert import tf1_checkpoint_converter_lib
|
||||
from official.nlp.modeling import networks
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("albert_config_file", None,
|
||||
"Albert configuration file to define core bert layers.")
|
||||
flags.DEFINE_string(
|
||||
"checkpoint_to_convert", None,
|
||||
"Initial checkpoint from a pretrained BERT model core (that is, only the "
|
||||
"BertModel, with no task heads.)")
|
||||
flags.DEFINE_string("converted_checkpoint_path", None,
|
||||
"Name for the created object-based V2 checkpoint.")
|
||||
|
||||
|
||||
ALBERT_NAME_REPLACEMENTS = (
|
||||
("bert/encoder/", ""),
|
||||
("bert/", ""),
|
||||
("embeddings/word_embeddings", "word_embeddings/embeddings"),
|
||||
("embeddings/position_embeddings", "position_embedding/embeddings"),
|
||||
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
|
||||
("embeddings/LayerNorm", "embeddings/layer_norm"),
|
||||
("embedding_hidden_mapping_in", "embedding_projection"),
|
||||
("group_0/inner_group_0/", ""),
|
||||
("attention_1/self", "self_attention"),
|
||||
("attention_1/output/dense", "self_attention_output"),
|
||||
("LayerNorm/", "self_attention_layer_norm/"),
|
||||
("ffn_1/intermediate/dense", "intermediate"),
|
||||
("ffn_1/intermediate/output/dense", "output"),
|
||||
("LayerNorm_1/", "output_layer_norm/"),
|
||||
("pooler/dense", "pooler_transform"),
|
||||
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
|
||||
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
|
||||
("cls/seq_relationship/output_weights",
|
||||
"predictions/transform/logits/kernel"),
|
||||
)
|
||||
|
||||
|
||||
def _create_albert_model(cfg):
|
||||
"""Creates a BERT keras core model from BERT configuration.
|
||||
|
||||
Args:
|
||||
cfg: A `BertConfig` to create the core model.
|
||||
|
||||
Returns:
|
||||
A keras model.
|
||||
"""
|
||||
albert_encoder = networks.AlbertTransformerEncoder(
|
||||
vocab_size=cfg.vocab_size,
|
||||
hidden_size=cfg.hidden_size,
|
||||
embedding_width=cfg.embedding_size,
|
||||
num_layers=cfg.num_hidden_layers,
|
||||
num_attention_heads=cfg.num_attention_heads,
|
||||
intermediate_size=cfg.intermediate_size,
|
||||
activation=activations.gelu,
|
||||
dropout_rate=cfg.hidden_dropout_prob,
|
||||
attention_dropout_rate=cfg.attention_probs_dropout_prob,
|
||||
sequence_length=cfg.max_position_embeddings,
|
||||
type_vocab_size=cfg.type_vocab_size,
|
||||
initializer=tf.keras.initializers.TruncatedNormal(
|
||||
stddev=cfg.initializer_range))
|
||||
return albert_encoder
|
||||
|
||||
|
||||
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
|
||||
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
|
||||
output_dir, _ = os.path.split(output_path)
|
||||
|
||||
# Create a temporary V1 name-converted checkpoint in the output directory.
|
||||
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
|
||||
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
|
||||
tf1_checkpoint_converter_lib.convert(
|
||||
checkpoint_from_path=v1_checkpoint,
|
||||
checkpoint_to_path=temporary_checkpoint,
|
||||
num_heads=bert_config.num_attention_heads,
|
||||
name_replacements=ALBERT_NAME_REPLACEMENTS,
|
||||
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
|
||||
exclude_patterns=["adam", "Adam"])
|
||||
|
||||
# Create a V2 checkpoint from the temporary checkpoint.
|
||||
model = _create_albert_model(bert_config)
|
||||
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
|
||||
output_path)
|
||||
|
||||
# Clean up the temporary checkpoint, if it exists.
|
||||
try:
|
||||
tf.io.gfile.rmtree(temporary_checkpoint_dir)
|
||||
except tf.errors.OpError:
|
||||
# If it doesn't exist, we don't need to clean it up; continue.
|
||||
pass
|
||||
|
||||
|
||||
def main(_):
|
||||
output_path = FLAGS.converted_checkpoint_path
|
||||
v1_checkpoint = FLAGS.checkpoint_to_convert
|
||||
albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
|
||||
convert_checkpoint(albert_config, output_path, v1_checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
+350
@@ -0,0 +1,350 @@
|
||||
# BERT (Bidirectional Encoder Representations from Transformers)
|
||||
|
||||
The academic paper which describes BERT in detail and provides full results on a
|
||||
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
|
||||
|
||||
This repository contains TensorFlow 2.x implementation for BERT.
|
||||
|
||||
## Contents
|
||||
* [Contents](#contents)
|
||||
* [Pre-trained Models](#pre-trained-models)
|
||||
* [Restoring from Checkpoints](#restoring-from-checkpoints)
|
||||
* [Set Up](#set-up)
|
||||
* [Process Datasets](#process-datasets)
|
||||
* [Fine-tuning with BERT](#fine-tuning-with-bert)
|
||||
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
|
||||
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
|
||||
* [SQuAD 1.1](#squad-1.1)
|
||||
|
||||
|
||||
## Pre-trained Models
|
||||
|
||||
We released both checkpoints and tf.hub modules as the pretrained models for
|
||||
fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
|
||||
released in TF 1.x official BERT repository
|
||||
[google-research/bert](https://github.com/google-research/bert)
|
||||
in order to keep consistent with BERT paper.
|
||||
|
||||
|
||||
### Access to Pretrained Checkpoints
|
||||
|
||||
Pretrained checkpoints can be found in the following links:
|
||||
|
||||
**Note: We have switched BERT implementation
|
||||
to use Keras functional-style networks in [nlp/modeling](../modeling).
|
||||
The new checkpoints are:**
|
||||
|
||||
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
|
||||
12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
|
||||
12-layer, 768-hidden, 12-heads , 110M parameters
|
||||
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
|
||||
We recommend to host checkpoints on Google Cloud storage buckets when you use
|
||||
Cloud GPU/TPU.
|
||||
|
||||
### Restoring from Checkpoints
|
||||
|
||||
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
|
||||
weights from provided pre-trained checkpoints, you can use the following code:
|
||||
|
||||
```python
|
||||
init_checkpoint='the pretrained model checkpoint path.'
|
||||
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
|
||||
checkpoint = tf.train.Checkpoint(model=model)
|
||||
checkpoint.restore(init_checkpoint)
|
||||
```
|
||||
|
||||
Checkpoints featuring native serialized Keras models
|
||||
(i.e. model.load()/load_weights()) will be available soon.
|
||||
|
||||
### Access to Pretrained hub modules.
|
||||
|
||||
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
|
||||
following links:
|
||||
|
||||
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/1)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/1)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1)**:
|
||||
12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/1)**:
|
||||
12-layer, 768-hidden, 12-heads , 110M parameters
|
||||
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/1)**:
|
||||
24-layer, 1024-hidden, 16-heads, 340M parameters
|
||||
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1)**:
|
||||
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
|
||||
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/1)**:
|
||||
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
|
||||
110M parameters
|
||||
|
||||
## Set Up
|
||||
|
||||
```shell
|
||||
export PYTHONPATH="$PYTHONPATH:/path/to/models"
|
||||
```
|
||||
|
||||
Install `tf-nightly` to get latest updates:
|
||||
|
||||
```shell
|
||||
pip install tf-nightly-gpu
|
||||
```
|
||||
|
||||
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
|
||||
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
|
||||
|
||||
```shell
|
||||
ctpu up -name <instance name> --tf-version=”nightly”
|
||||
```
|
||||
|
||||
Second, you need to install TF 2 `tf-nightly` on your VM:
|
||||
|
||||
```shell
|
||||
pip install tf-nightly
|
||||
```
|
||||
|
||||
Warning: More details TPU-specific set-up instructions and tutorial should come
|
||||
along with official TF 2.x release for TPU. Note that this repo is not
|
||||
officially supported by Google Cloud TPU team yet until TF 2.1 released.
|
||||
|
||||
## Process Datasets
|
||||
|
||||
### Pre-training
|
||||
|
||||
There is no change to generate pre-training data. Please use the script
|
||||
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
|
||||
which is essentially branched from [BERT research repo](https://github.com/google-research/bert)
|
||||
to get processed pre-training data and it adapts to TF2 symbols and python3
|
||||
compatibility.
|
||||
|
||||
|
||||
### Fine-tuning
|
||||
|
||||
To prepare the fine-tuning data for final model training, use the
|
||||
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
|
||||
Resulting datasets in `tf_record` format and training meta data should be later
|
||||
passed to training or evaluation scripts. The task-specific arguments are
|
||||
described in following sections:
|
||||
|
||||
* GLUE
|
||||
|
||||
Users can download the
|
||||
[GLUE data](https://gluebenchmark.com/tasks) by running
|
||||
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
||||
and unpack it to some directory `$GLUE_DIR`.
|
||||
|
||||
```shell
|
||||
export GLUE_DIR=~/glue
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
|
||||
export TASK_NAME=MNLI
|
||||
export OUTPUT_DIR=gs://some_bucket/datasets
|
||||
python ../data/create_finetuning_data.py \
|
||||
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
|
||||
--vocab_file=${BERT_DIR}/vocab.txt \
|
||||
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
|
||||
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
|
||||
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
|
||||
--fine_tuning_task_type=classification --max_seq_length=128 \
|
||||
--classification_task_name=${TASK_NAME}
|
||||
```
|
||||
|
||||
* SQUAD
|
||||
|
||||
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
|
||||
detailed information about the SQuAD datasets and evaluation.
|
||||
|
||||
The necessary files can be found here:
|
||||
|
||||
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
||||
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
|
||||
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
|
||||
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
|
||||
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
|
||||
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
|
||||
|
||||
```shell
|
||||
export SQUAD_DIR=~/squad
|
||||
export SQUAD_VERSION=v1.1
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export OUTPUT_DIR=gs://some_bucket/datasets
|
||||
|
||||
python ../data/create_finetuning_data.py \
|
||||
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
|
||||
--vocab_file=${BERT_DIR}/vocab.txt \
|
||||
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--fine_tuning_task_type=squad --max_seq_length=384
|
||||
```
|
||||
|
||||
## Fine-tuning with BERT
|
||||
|
||||
### Cloud GPUs and TPUs
|
||||
|
||||
* Cloud Storage
|
||||
|
||||
The unzipped pre-trained model files can also be found in the Google Cloud
|
||||
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
|
||||
|
||||
```shell
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
```
|
||||
|
||||
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
|
||||
script should run with `tf-nightly`.
|
||||
|
||||
* GPU -> TPU
|
||||
|
||||
Just add the following flags to `run_classifier.py` or `run_squad.py`:
|
||||
|
||||
```shell
|
||||
--distribution_strategy=tpu
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
### Sentence and Sentence-pair Classification Tasks
|
||||
|
||||
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
|
||||
Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
|
||||
few minutes on most GPUs.
|
||||
|
||||
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
|
||||
workflow.
|
||||
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
|
||||
(uncased_L-12_H-768_A-12).
|
||||
|
||||
```shell
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export GLUE_DIR=gs://some_bucket/datasets
|
||||
export TASK=MRPC
|
||||
|
||||
python run_classifier.py \
|
||||
--mode='train_and_eval' \
|
||||
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
||||
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
||||
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
||||
--bert_config_file=${BERT_DIR}/bert_config.json \
|
||||
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
||||
--train_batch_size=4 \
|
||||
--eval_batch_size=4 \
|
||||
--steps_per_loop=1 \
|
||||
--learning_rate=2e-5 \
|
||||
--num_train_epochs=3 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=mirrored
|
||||
```
|
||||
|
||||
Alternatively, instead of specifying `init_checkpoint`, you can specify
|
||||
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
|
||||
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
|
||||
|
||||
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
|
||||
information and use remote storage for model checkpoints.
|
||||
|
||||
```shell
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export TPU_IP_ADDRESS='???'
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export GLUE_DIR=gs://some_bucket/datasets
|
||||
export TASK=MRPC
|
||||
|
||||
python run_classifier.py \
|
||||
--mode='train_and_eval' \
|
||||
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
||||
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
||||
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
||||
--bert_config_file=${BERT_DIR}/bert_config.json \
|
||||
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=32 \
|
||||
--steps_per_loop=1000 \
|
||||
--learning_rate=2e-5 \
|
||||
--num_train_epochs=3 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=tpu \
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
|
||||
training steps inside a `tf.function` can significantly increase TPU utilization
|
||||
and callbacks will not be called inside the loop.
|
||||
|
||||
### SQuAD 1.1
|
||||
|
||||
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
|
||||
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
|
||||
|
||||
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
|
||||
workflow.
|
||||
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
|
||||
(uncased_L-12_H-768_A-12).
|
||||
|
||||
```shell
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export SQUAD_DIR=gs://some_bucket/datasets
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export SQUAD_VERSION=v1.1
|
||||
|
||||
python run_squad.py \
|
||||
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
||||
--vocab_file=${BERT_DIR}/vocab.txt \
|
||||
--bert_config_file=${BERT_DIR}/bert_config.json \
|
||||
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
||||
--train_batch_size=4 \
|
||||
--predict_batch_size=4 \
|
||||
--learning_rate=8e-5 \
|
||||
--num_train_epochs=2 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=mirrored
|
||||
```
|
||||
|
||||
Similarily, you can replace `init_checkpoint` FLAG with `hub_module_url` to
|
||||
specify a hub module path.
|
||||
|
||||
To use TPU, you need switch distribution strategy type to `tpu` with TPU
|
||||
information.
|
||||
|
||||
```shell
|
||||
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
export TPU_IP_ADDRESS='???'
|
||||
export MODEL_DIR=gs://some_bucket/my_output_dir
|
||||
export SQUAD_DIR=gs://some_bucket/datasets
|
||||
export SQUAD_VERSION=v1.1
|
||||
|
||||
python run_squad.py \
|
||||
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
||||
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
||||
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
||||
--vocab_file=${BERT_DIR}/vocab.txt \
|
||||
--bert_config_file=${BERT_DIR}/bert_config.json \
|
||||
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
||||
--train_batch_size=32 \
|
||||
--learning_rate=8e-5 \
|
||||
--num_train_epochs=2 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=tpu \
|
||||
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
||||
```
|
||||
|
||||
The dev set predictions will be saved into a file called predictions.json in the
|
||||
model_dir:
|
||||
|
||||
```shell
|
||||
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
|
||||
```
|
||||
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
|
||||
+110
@@ -0,0 +1,110 @@
|
||||
# BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
|
||||
This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
|
||||
|
||||
|
||||
## Set up Cloud Storage and Compute Engine VM
|
||||
1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
|
||||
2. Create a variable for the project's name:
|
||||
```
|
||||
export PROJECT_NAME=your-project_name
|
||||
```
|
||||
3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
|
||||
```
|
||||
gcloud config set project ${PROJECT_NAME}
|
||||
```
|
||||
4. Create a Cloud Storage bucket using the following command:
|
||||
```
|
||||
gsutil mb -p ${PROJECT_NAME} -c standard -l europe-west4 -b on gs://your-bucket-name
|
||||
```
|
||||
This Cloud Storage bucket stores the data you use to train your model and the training results.
|
||||
5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
|
||||
```
|
||||
ctpu up --tpu-size=v3-8 \
|
||||
--machine-type=n1-standard-8 \
|
||||
--zone=europe-west4-a \
|
||||
--tf-version=2.1 [optional flags: --project, --name]
|
||||
```
|
||||
6. The configuration you specified appears. Enter y to approve or n to cancel.
|
||||
7. When the ctpu up command has finished executing, verify that your shell prompt has changed from username@project to username@tpuname. This change shows that you are now logged into your Compute Engine VM.
|
||||
```
|
||||
gcloud compute ssh vm-name --zone=europe-west4-a
|
||||
(vm)$ export TPU_NAME=vm-name
|
||||
```
|
||||
As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
|
||||
|
||||
## Prepare the Dataset
|
||||
1. From your Compute Engine virtual machine (VM), install requirements.txt.
|
||||
```
|
||||
(vm)$ cd /usr/share/models
|
||||
(vm)$ sudo pip3 install -r official/requirements.txt
|
||||
```
|
||||
2. Optional: download download_glue_data.py
|
||||
|
||||
This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
|
||||
|
||||
## Define parameter values
|
||||
Next, define several parameter values that are required when you train and evaluate your model:
|
||||
|
||||
```
|
||||
(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
|
||||
(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
|
||||
(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
||||
(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
|
||||
(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
|
||||
(vm)$ export TASK=mnli
|
||||
```
|
||||
|
||||
## Train the model
|
||||
From your Compute Engine VM, run the following command.
|
||||
|
||||
```
|
||||
(vm)$ python3 official/nlp/bert/run_classifier.py \
|
||||
--mode='train_and_eval' \
|
||||
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
||||
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
||||
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
||||
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
|
||||
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=32 \
|
||||
--learning_rate=2e-5 \
|
||||
--num_train_epochs=3 \
|
||||
--model_dir=${MODEL_DIR} \
|
||||
--distribution_strategy=tpu \
|
||||
--tpu=${TPU_NAME}
|
||||
```
|
||||
|
||||
## Verify your results
|
||||
The training takes approximately 1 hour on a v3-8 TPU. When script completes, you should see results similar to the following:
|
||||
```
|
||||
Training Summary:
|
||||
{'train_loss': 0.28142181038856506,
|
||||
'last_train_metrics': 0.9467429518699646,
|
||||
'eval_metrics': 0.8599063158035278,
|
||||
'total_training_steps': 36813}
|
||||
```
|
||||
|
||||
## Clean up
|
||||
To avoid incurring charges to your GCP account for the resources used in this topic:
|
||||
1. Disconnect from the Compute Engine VM:
|
||||
```
|
||||
(vm)$ exit
|
||||
```
|
||||
2. In your Cloud Shell, run ctpu delete with the --zone flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
|
||||
```
|
||||
$ ctpu delete --zone=your-zone
|
||||
```
|
||||
3. Run ctpu status specifying your zone to make sure you have no instances allocated to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
|
||||
```
|
||||
$ ctpu status --zone=your-zone
|
||||
```
|
||||
4. Run gsutil as shown, replacing your-bucket with the name of the Cloud Storage bucket you created for this tutorial:
|
||||
```
|
||||
$ gsutil rm -r gs://your-bucket
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+349
@@ -0,0 +1,349 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""BERT models that are compatible with TF 2.0."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gin
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from official.modeling import tf_utils
|
||||
from official.nlp.albert import configs as albert_configs
|
||||
from official.nlp.bert import configs
|
||||
from official.nlp.modeling import losses
|
||||
from official.nlp.modeling import models
|
||||
from official.nlp.modeling import networks
|
||||
|
||||
|
||||
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
|
||||
"""Returns layer that computes custom loss and metrics for pretraining."""
|
||||
|
||||
def __init__(self, vocab_size, **kwargs):
|
||||
super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
|
||||
self._vocab_size = vocab_size
|
||||
self.config = {
|
||||
'vocab_size': vocab_size,
|
||||
}
|
||||
|
||||
def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
|
||||
lm_example_loss, sentence_output, sentence_labels,
|
||||
next_sentence_loss):
|
||||
"""Adds metrics."""
|
||||
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
|
||||
lm_labels, lm_output)
|
||||
numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
|
||||
denominator = tf.reduce_sum(lm_label_weights) + 1e-5
|
||||
masked_lm_accuracy = numerator / denominator
|
||||
self.add_metric(
|
||||
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
|
||||
|
||||
self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
|
||||
|
||||
if sentence_labels is not None:
|
||||
next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
|
||||
sentence_labels, sentence_output)
|
||||
self.add_metric(
|
||||
next_sentence_accuracy,
|
||||
name='next_sentence_accuracy',
|
||||
aggregation='mean')
|
||||
|
||||
if next_sentence_loss is not None:
|
||||
self.add_metric(
|
||||
next_sentence_loss, name='next_sentence_loss', aggregation='mean')
|
||||
|
||||
def call(self,
|
||||
lm_output,
|
||||
sentence_output,
|
||||
lm_label_ids,
|
||||
lm_label_weights,
|
||||
sentence_labels=None):
|
||||
"""Implements call() for the layer."""
|
||||
lm_label_weights = tf.cast(lm_label_weights, tf.float32)
|
||||
lm_output = tf.cast(lm_output, tf.float32)
|
||||
|
||||
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
|
||||
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
|
||||
|
||||
if sentence_labels is not None:
|
||||
sentence_output = tf.cast(sentence_output, tf.float32)
|
||||
sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
|
||||
labels=sentence_labels, predictions=sentence_output)
|
||||
loss = mask_label_loss + sentence_loss
|
||||
else:
|
||||
sentence_loss = None
|
||||
loss = mask_label_loss
|
||||
|
||||
batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
|
||||
# TODO(hongkuny): Avoids the hack and switches add_loss.
|
||||
final_loss = tf.fill(batch_shape, loss)
|
||||
|
||||
self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
|
||||
mask_label_loss, sentence_output, sentence_labels,
|
||||
sentence_loss)
|
||||
return final_loss
|
||||
|
||||
|
||||
@gin.configurable
|
||||
def get_transformer_encoder(bert_config,
|
||||
sequence_length,
|
||||
transformer_encoder_cls=None):
|
||||
"""Gets a 'TransformerEncoder' object.
|
||||
|
||||
Args:
|
||||
bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
|
||||
sequence_length: Maximum sequence length of the training data.
|
||||
transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
|
||||
default BERT encoder implementation.
|
||||
|
||||
Returns:
|
||||
A networks.TransformerEncoder object.
|
||||
"""
|
||||
if transformer_encoder_cls is not None:
|
||||
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
|
||||
embedding_cfg = dict(
|
||||
vocab_size=bert_config.vocab_size,
|
||||
type_vocab_size=bert_config.type_vocab_size,
|
||||
hidden_size=bert_config.hidden_size,
|
||||
seq_length=sequence_length,
|
||||
max_seq_length=bert_config.max_position_embeddings,
|
||||
initializer=tf.keras.initializers.TruncatedNormal(
|
||||
stddev=bert_config.initializer_range),
|
||||
dropout_rate=bert_config.hidden_dropout_prob,
|
||||
)
|
||||
hidden_cfg = dict(
|
||||
num_attention_heads=bert_config.num_attention_heads,
|
||||
intermediate_size=bert_config.intermediate_size,
|
||||
intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
|
||||
dropout_rate=bert_config.hidden_dropout_prob,
|
||||
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
|
||||
)
|
||||
kwargs = dict(
|
||||
embedding_cfg=embedding_cfg,
|
||||
hidden_cfg=hidden_cfg,
|
||||
num_hidden_instances=bert_config.num_hidden_layers,
|
||||
pooled_output_dim=bert_config.hidden_size,
|
||||
)
|
||||
|
||||
# Relies on gin configuration to define the Transformer encoder arguments.
|
||||
return transformer_encoder_cls(**kwargs)
|
||||
|
||||
kwargs = dict(
|
||||
vocab_size=bert_config.vocab_size,
|
||||
hidden_size=bert_config.hidden_size,
|
||||
num_layers=bert_config.num_hidden_layers,
|
||||
num_attention_heads=bert_config.num_attention_heads,
|
||||
intermediate_size=bert_config.intermediate_size,
|
||||
activation=tf_utils.get_activation(bert_config.hidden_act),
|
||||
dropout_rate=bert_config.hidden_dropout_prob,
|
||||
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
|
||||
sequence_length=sequence_length,
|
||||
max_sequence_length=bert_config.max_position_embeddings,
|
||||
type_vocab_size=bert_config.type_vocab_size,
|
||||
initializer=tf.keras.initializers.TruncatedNormal(
|
||||
stddev=bert_config.initializer_range))
|
||||
if isinstance(bert_config, albert_configs.AlbertConfig):
|
||||
kwargs['embedding_width'] = bert_config.embedding_size
|
||||
return networks.AlbertTransformerEncoder(**kwargs)
|
||||
else:
|
||||
assert isinstance(bert_config, configs.BertConfig)
|
||||
return networks.TransformerEncoder(**kwargs)
|
||||
|
||||
|
||||
def pretrain_model(bert_config,
|
||||
seq_length,
|
||||
max_predictions_per_seq,
|
||||
initializer=None,
|
||||
use_next_sentence_label=True):
|
||||
"""Returns model to be used for pre-training.
|
||||
|
||||
Args:
|
||||
bert_config: Configuration that defines the core BERT model.
|
||||
seq_length: Maximum sequence length of the training data.
|
||||
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
|
||||
and use for pretraining.
|
||||
initializer: Initializer for weights in BertPretrainer.
|
||||
use_next_sentence_label: Whether to use the next sentence label.
|
||||
|
||||
Returns:
|
||||
Pretraining model as well as core BERT submodel from which to save
|
||||
weights after pretraining.
|
||||
"""
|
||||
input_word_ids = tf.keras.layers.Input(
|
||||
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
|
||||
input_mask = tf.keras.layers.Input(
|
||||
shape=(seq_length,), name='input_mask', dtype=tf.int32)
|
||||
input_type_ids = tf.keras.layers.Input(
|
||||
shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
|
||||
masked_lm_positions = tf.keras.layers.Input(
|
||||
shape=(max_predictions_per_seq,),
|
||||
name='masked_lm_positions',
|
||||
dtype=tf.int32)
|
||||
masked_lm_ids = tf.keras.layers.Input(
|
||||
shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
|
||||
masked_lm_weights = tf.keras.layers.Input(
|
||||
shape=(max_predictions_per_seq,),
|
||||
name='masked_lm_weights',
|
||||
dtype=tf.int32)
|
||||
|
||||
if use_next_sentence_label:
|
||||
next_sentence_labels = tf.keras.layers.Input(
|
||||
shape=(1,), name='next_sentence_labels', dtype=tf.int32)
|
||||
else:
|
||||
next_sentence_labels = None
|
||||
|
||||
transformer_encoder = get_transformer_encoder(bert_config, seq_length)
|
||||
if initializer is None:
|
||||
initializer = tf.keras.initializers.TruncatedNormal(
|
||||
stddev=bert_config.initializer_range)
|
||||
pretrainer_model = models.BertPretrainer(
|
||||
network=transformer_encoder,
|
||||
num_classes=2, # The next sentence prediction label has two classes.
|
||||
num_token_predictions=max_predictions_per_seq,
|
||||
initializer=initializer,
|
||||
output='predictions')
|
||||
|
||||
lm_output, sentence_output = pretrainer_model(
|
||||
[input_word_ids, input_mask, input_type_ids, masked_lm_positions])
|
||||
|
||||
pretrain_loss_layer = BertPretrainLossAndMetricLayer(
|
||||
vocab_size=bert_config.vocab_size)
|
||||
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_labels)
|
||||
inputs = {
|
||||
'input_word_ids': input_word_ids,
|
||||
'input_mask': input_mask,
|
||||
'input_type_ids': input_type_ids,
|
||||
'masked_lm_positions': masked_lm_positions,
|
||||
'masked_lm_ids': masked_lm_ids,
|
||||
'masked_lm_weights': masked_lm_weights,
|
||||
}
|
||||
if use_next_sentence_label:
|
||||
inputs['next_sentence_labels'] = next_sentence_labels
|
||||
|
||||
keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
|
||||
return keras_model, transformer_encoder
|
||||
|
||||
|
||||
def squad_model(bert_config,
|
||||
max_seq_length,
|
||||
initializer=None,
|
||||
hub_module_url=None,
|
||||
hub_module_trainable=True):
|
||||
"""Returns BERT Squad model along with core BERT model to import weights.
|
||||
|
||||
Args:
|
||||
bert_config: BertConfig, the config defines the core Bert model.
|
||||
max_seq_length: integer, the maximum input sequence length.
|
||||
initializer: Initializer for the final dense layer in the span labeler.
|
||||
Defaulted to TruncatedNormal initializer.
|
||||
hub_module_url: TF-Hub path/url to Bert module.
|
||||
hub_module_trainable: True to finetune layers in the hub module.
|
||||
|
||||
Returns:
|
||||
A tuple of (1) keras model that outputs start logits and end logits and
|
||||
(2) the core BERT transformer encoder.
|
||||
"""
|
||||
if initializer is None:
|
||||
initializer = tf.keras.initializers.TruncatedNormal(
|
||||
stddev=bert_config.initializer_range)
|
||||
if not hub_module_url:
|
||||
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
|
||||
return models.BertSpanLabeler(
|
||||
network=bert_encoder, initializer=initializer), bert_encoder
|
||||
|
||||
input_word_ids = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
|
||||
input_mask = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
|
||||
input_type_ids = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
|
||||
core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
|
||||
pooled_output, sequence_output = core_model(
|
||||
[input_word_ids, input_mask, input_type_ids])
|
||||
bert_encoder = tf.keras.Model(
|
||||
inputs={
|
||||
'input_word_ids': input_word_ids,
|
||||
'input_mask': input_mask,
|
||||
'input_type_ids': input_type_ids,
|
||||
},
|
||||
outputs=[sequence_output, pooled_output],
|
||||
name='core_model')
|
||||
return models.BertSpanLabeler(
|
||||
network=bert_encoder, initializer=initializer), bert_encoder
|
||||
|
||||
|
||||
def classifier_model(bert_config,
|
||||
num_labels,
|
||||
max_seq_length,
|
||||
final_layer_initializer=None,
|
||||
hub_module_url=None,
|
||||
hub_module_trainable=True):
|
||||
"""BERT classifier model in functional API style.
|
||||
|
||||
Construct a Keras model for predicting `num_labels` outputs from an input with
|
||||
maximum sequence length `max_seq_length`.
|
||||
|
||||
Args:
|
||||
bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
|
||||
ALBERT model.
|
||||
num_labels: integer, the number of classes.
|
||||
max_seq_length: integer, the maximum input sequence length.
|
||||
final_layer_initializer: Initializer for final dense layer. Defaulted
|
||||
TruncatedNormal initializer.
|
||||
hub_module_url: TF-Hub path/url to Bert module.
|
||||
hub_module_trainable: True to finetune layers in the hub module.
|
||||
|
||||
Returns:
|
||||
Combined prediction model (words, mask, type) -> (one-hot labels)
|
||||
BERT sub-model (words, mask, type) -> (bert_outputs)
|
||||
"""
|
||||
if final_layer_initializer is not None:
|
||||
initializer = final_layer_initializer
|
||||
else:
|
||||
initializer = tf.keras.initializers.TruncatedNormal(
|
||||
stddev=bert_config.initializer_range)
|
||||
|
||||
if not hub_module_url:
|
||||
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
|
||||
return models.BertClassifier(
|
||||
bert_encoder,
|
||||
num_classes=num_labels,
|
||||
dropout_rate=bert_config.hidden_dropout_prob,
|
||||
initializer=initializer), bert_encoder
|
||||
|
||||
input_word_ids = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
|
||||
input_mask = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
|
||||
input_type_ids = tf.keras.layers.Input(
|
||||
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
|
||||
bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
|
||||
pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
|
||||
output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
|
||||
pooled_output)
|
||||
|
||||
output = tf.keras.layers.Dense(
|
||||
num_labels, kernel_initializer=initializer, name='output')(
|
||||
output)
|
||||
return tf.keras.Model(
|
||||
inputs={
|
||||
'input_word_ids': input_word_ids,
|
||||
'input_mask': input_mask,
|
||||
'input_type_ids': input_type_ids
|
||||
},
|
||||
outputs=output), bert_model
|
||||
+118
@@ -0,0 +1,118 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Defining common flags used across all BERT models/applications."""
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
|
||||
from official.utils.flags import core as flags_core
|
||||
|
||||
|
||||
def define_gin_flags():
|
||||
"""Define common gin configurable flags."""
|
||||
flags.DEFINE_multi_string('gin_file', None,
|
||||
'List of paths to the config files.')
|
||||
flags.DEFINE_multi_string(
|
||||
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
|
||||
|
||||
|
||||
def define_common_bert_flags():
|
||||
"""Define common flags for BERT tasks."""
|
||||
flags_core.define_base(
|
||||
data_dir=False,
|
||||
model_dir=True,
|
||||
clean=False,
|
||||
train_epochs=False,
|
||||
epochs_between_evals=False,
|
||||
stop_threshold=False,
|
||||
batch_size=False,
|
||||
num_gpu=True,
|
||||
hooks=False,
|
||||
export_dir=False,
|
||||
distribution_strategy=True,
|
||||
run_eagerly=True)
|
||||
flags_core.define_distribution()
|
||||
flags.DEFINE_string('bert_config_file', None,
|
||||
'Bert configuration file to define core bert layers.')
|
||||
flags.DEFINE_string(
|
||||
'model_export_path', None,
|
||||
'Path to the directory, where trainined model will be '
|
||||
'exported.')
|
||||
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
|
||||
flags.DEFINE_string(
|
||||
'init_checkpoint', None,
|
||||
'Initial checkpoint (usually from a pre-trained BERT model).')
|
||||
flags.DEFINE_integer('num_train_epochs', 3,
|
||||
'Total number of training epochs to perform.')
|
||||
flags.DEFINE_integer(
|
||||
'steps_per_loop', 1,
|
||||
'Number of steps per graph-mode loop. Only training step '
|
||||
'happens inside the loop. Callbacks will not be called '
|
||||
'inside.')
|
||||
flags.DEFINE_float('learning_rate', 5e-5,
|
||||
'The initial learning rate for Adam.')
|
||||
flags.DEFINE_float('end_lr', 0.0,
|
||||
'The end learning rate for learning rate decay.')
|
||||
flags.DEFINE_string('optimizer_type', 'adamw',
|
||||
'The type of optimizer to use for training (adamw|lamb)')
|
||||
flags.DEFINE_boolean(
|
||||
'scale_loss', False,
|
||||
'Whether to divide the loss by number of replica inside the per-replica '
|
||||
'loss function.')
|
||||
flags.DEFINE_boolean(
|
||||
'use_keras_compile_fit', False,
|
||||
'If True, uses Keras compile/fit() API for training logic. Otherwise '
|
||||
'use custom training loop.')
|
||||
flags.DEFINE_string(
|
||||
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
|
||||
'If specified, init_checkpoint flag should not be used.')
|
||||
flags.DEFINE_bool('hub_module_trainable', True,
|
||||
'True to make keras layers in the hub module trainable.')
|
||||
|
||||
flags_core.define_log_steps()
|
||||
|
||||
# Adds flags for mixed precision and multi-worker training.
|
||||
flags_core.define_performance(
|
||||
num_parallel_calls=False,
|
||||
inter_op=False,
|
||||
intra_op=False,
|
||||
synthetic_data=False,
|
||||
max_train_steps=False,
|
||||
dtype=True,
|
||||
dynamic_loss_scale=True,
|
||||
loss_scale=True,
|
||||
all_reduce_alg=True,
|
||||
num_packs=False,
|
||||
tf_gpu_thread_mode=True,
|
||||
datasets_num_private_threads=True,
|
||||
enable_xla=True,
|
||||
fp16_implementation=True,
|
||||
)
|
||||
|
||||
|
||||
def dtype():
|
||||
return flags_core.get_tf_dtype(flags.FLAGS)
|
||||
|
||||
|
||||
def use_float16():
|
||||
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
|
||||
|
||||
|
||||
def use_graph_rewrite():
|
||||
return flags.FLAGS.fp16_implementation == 'graph_rewrite'
|
||||
|
||||
|
||||
def get_loss_scale():
|
||||
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The main BERT model and related functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
"""Configuration for `BertModel`."""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
backward_compatible=True):
|
||||
"""Constructs BertConfig.
|
||||
|
||||
Args:
|
||||
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler.
|
||||
hidden_dropout_prob: The dropout probability for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`BertModel`.
|
||||
initializer_range: The stdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
backward_compatible: Boolean, whether the variables shape are compatible
|
||||
with checkpoints converted from TF 1.x BERT.
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.backward_compatible = backward_compatible
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
||||
config = BertConfig(vocab_size=None)
|
||||
for (key, value) in six.iteritems(json_object):
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `BertConfig` from a json file of parameters."""
|
||||
with tf.io.gfile.GFile(json_file, "r") as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
+86
@@ -0,0 +1,86 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A script to export the BERT core model as a TF-Hub SavedModel."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
from typing import Text
|
||||
from official.nlp.bert import bert_models
|
||||
from official.nlp.bert import configs
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("bert_config_file", None,
|
||||
"Bert configuration file to define core bert layers.")
|
||||
flags.DEFINE_string("model_checkpoint_path", None,
|
||||
"File path to TF model checkpoint.")
|
||||
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
|
||||
flags.DEFINE_string("vocab_file", None,
|
||||
"The vocabulary file that the BERT model was trained on.")
|
||||
|
||||
|
||||
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
|
||||
"""Creates a BERT keras core model from BERT configuration.
|
||||
|
||||
Args:
|
||||
bert_config: A `BertConfig` to create the core model.
|
||||
|
||||
Returns:
|
||||
A keras model.
|
||||
"""
|
||||
# Adds input layers just as placeholders.
|
||||
input_word_ids = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_word_ids")
|
||||
input_mask = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_mask")
|
||||
input_type_ids = tf.keras.layers.Input(
|
||||
shape=(None,), dtype=tf.int32, name="input_type_ids")
|
||||
transformer_encoder = bert_models.get_transformer_encoder(
|
||||
bert_config, sequence_length=None)
|
||||
sequence_output, pooled_output = transformer_encoder(
|
||||
[input_word_ids, input_mask, input_type_ids])
|
||||
# To keep consistent with legacy hub modules, the outputs are
|
||||
# "pooled_output" and "sequence_output".
|
||||
return tf.keras.Model(
|
||||
inputs=[input_word_ids, input_mask, input_type_ids],
|
||||
outputs=[pooled_output, sequence_output]), transformer_encoder
|
||||
|
||||
|
||||
def export_bert_tfhub(bert_config: configs.BertConfig,
|
||||
model_checkpoint_path: Text, hub_destination: Text,
|
||||
vocab_file: Text):
|
||||
"""Restores a tf.keras.Model and saves for TF-Hub."""
|
||||
core_model, encoder = create_bert_model(bert_config)
|
||||
checkpoint = tf.train.Checkpoint(model=encoder)
|
||||
checkpoint.restore(model_checkpoint_path).assert_consumed()
|
||||
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
|
||||
core_model.do_lower_case = tf.Variable(
|
||||
"uncased" in vocab_file, trainable=False)
|
||||
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
|
||||
|
||||
|
||||
def main(_):
|
||||
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
|
||||
FLAGS.vocab_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
+87
@@ -0,0 +1,87 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests official.nlp.bert.export_tfhub."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
from official.nlp.bert import configs
|
||||
from official.nlp.bert import export_tfhub
|
||||
|
||||
|
||||
class ExportTfhubTest(tf.test.TestCase):
|
||||
|
||||
def test_export_tfhub(self):
|
||||
# Exports a savedmodel for TF-Hub
|
||||
bert_config = configs.BertConfig(
|
||||
vocab_size=100,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
max_position_embeddings=128,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=1)
|
||||
bert_model, encoder = export_tfhub.create_bert_model(bert_config)
|
||||
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||
checkpoint = tf.train.Checkpoint(model=encoder)
|
||||
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
|
||||
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
|
||||
|
||||
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
|
||||
with tf.io.gfile.GFile(vocab_file, "w") as f:
|
||||
f.write("dummy content")
|
||||
|
||||
hub_destination = os.path.join(self.get_temp_dir(), "hub")
|
||||
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
|
||||
hub_destination, vocab_file)
|
||||
|
||||
# Restores a hub KerasLayer.
|
||||
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
|
||||
|
||||
if hasattr(hub_layer, "resolved_object"):
|
||||
# Checks meta attributes.
|
||||
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
|
||||
with tf.io.gfile.GFile(
|
||||
hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
|
||||
self.assertEqual("dummy content", f.read())
|
||||
# Checks the hub KerasLayer.
|
||||
for source_weight, hub_weight in zip(bert_model.trainable_weights,
|
||||
hub_layer.trainable_weights):
|
||||
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
|
||||
|
||||
dummy_ids = np.zeros((2, 10), dtype=np.int32)
|
||||
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
|
||||
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
|
||||
|
||||
# The outputs of hub module are "pooled_output" and "sequence_output",
|
||||
# while the outputs of encoder is in reversed order, i.e.,
|
||||
# "sequence_output" and "pooled_output".
|
||||
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
|
||||
self.assertEqual(hub_outputs[0].shape, (2, 16))
|
||||
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
|
||||
for source_output, hub_output, encoder_output in zip(
|
||||
source_outputs, hub_outputs, encoder_outputs):
|
||||
self.assertAllClose(source_output.numpy(), hub_output.numpy())
|
||||
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+231
@@ -0,0 +1,231 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""BERT model input pipelines."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def decode_record(record, name_to_features):
|
||||
"""Decodes a record to a TensorFlow example."""
|
||||
example = tf.io.parse_single_example(record, name_to_features)
|
||||
|
||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||
# So cast all int64 to int32.
|
||||
for name in list(example.keys()):
|
||||
t = example[name]
|
||||
if t.dtype == tf.int64:
|
||||
t = tf.cast(t, tf.int32)
|
||||
example[name] = t
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def single_file_dataset(input_file, name_to_features):
|
||||
"""Creates a single-file dataset to be passed for BERT custom training."""
|
||||
# For training, we want a lot of parallel reading and shuffling.
|
||||
# For eval, we want no shuffling and parallel reading doesn't matter.
|
||||
d = tf.data.TFRecordDataset(input_file)
|
||||
d = d.map(lambda record: decode_record(record, name_to_features))
|
||||
|
||||
# When `input_file` is a path to a single file or a list
|
||||
# containing a single path, disable auto sharding so that
|
||||
# same input file is sent to all workers.
|
||||
if isinstance(input_file, str) or len(input_file) == 1:
|
||||
options = tf.data.Options()
|
||||
options.experimental_distribute.auto_shard_policy = (
|
||||
tf.data.experimental.AutoShardPolicy.OFF)
|
||||
d = d.with_options(options)
|
||||
return d
|
||||
|
||||
|
||||
def create_pretrain_dataset(input_patterns,
|
||||
seq_length,
|
||||
max_predictions_per_seq,
|
||||
batch_size,
|
||||
is_training=True,
|
||||
input_pipeline_context=None,
|
||||
use_next_sentence_label=True):
|
||||
"""Creates input dataset from (tf)records files for pretraining."""
|
||||
name_to_features = {
|
||||
'input_ids':
|
||||
tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'input_mask':
|
||||
tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'segment_ids':
|
||||
tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'masked_lm_positions':
|
||||
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
'masked_lm_ids':
|
||||
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||
'masked_lm_weights':
|
||||
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||
}
|
||||
if use_next_sentence_label:
|
||||
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
|
||||
tf.int64)
|
||||
|
||||
for input_pattern in input_patterns:
|
||||
if not tf.io.gfile.glob(input_pattern):
|
||||
raise ValueError('%s does not match any files.' % input_pattern)
|
||||
|
||||
dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)
|
||||
|
||||
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
|
||||
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
|
||||
input_pipeline_context.input_pipeline_id)
|
||||
if is_training:
|
||||
dataset = dataset.repeat()
|
||||
|
||||
# We set shuffle buffer to exactly match total number of
|
||||
# training files to ensure that training data is well shuffled.
|
||||
input_files = []
|
||||
for input_pattern in input_patterns:
|
||||
input_files.extend(tf.io.gfile.glob(input_pattern))
|
||||
dataset = dataset.shuffle(len(input_files))
|
||||
|
||||
# In parallel, create tf record dataset for each train files.
|
||||
# cycle_length = 8 means that up to 8 files will be read and deserialized in
|
||||
# parallel. You may want to increase this number if you have a large number of
|
||||
# CPU cores.
|
||||
dataset = dataset.interleave(
|
||||
tf.data.TFRecordDataset, cycle_length=8,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
decode_fn = lambda record: decode_record(record, name_to_features)
|
||||
dataset = dataset.map(
|
||||
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
def _select_data_from_record(record):
|
||||
"""Filter out features to use for pretraining."""
|
||||
x = {
|
||||
'input_word_ids': record['input_ids'],
|
||||
'input_mask': record['input_mask'],
|
||||
'input_type_ids': record['segment_ids'],
|
||||
'masked_lm_positions': record['masked_lm_positions'],
|
||||
'masked_lm_ids': record['masked_lm_ids'],
|
||||
'masked_lm_weights': record['masked_lm_weights'],
|
||||
}
|
||||
if use_next_sentence_label:
|
||||
x['next_sentence_labels'] = record['next_sentence_labels']
|
||||
|
||||
y = record['masked_lm_weights']
|
||||
|
||||
return (x, y)
|
||||
|
||||
dataset = dataset.map(
|
||||
_select_data_from_record,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
if is_training:
|
||||
dataset = dataset.shuffle(100)
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder=is_training)
|
||||
dataset = dataset.prefetch(1024)
|
||||
return dataset
|
||||
|
||||
|
||||
def create_classifier_dataset(file_path,
|
||||
seq_length,
|
||||
batch_size,
|
||||
is_training=True,
|
||||
input_pipeline_context=None):
|
||||
"""Creates input dataset from (tf)records files for train/eval."""
|
||||
name_to_features = {
|
||||
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'label_ids': tf.io.FixedLenFeature([], tf.int64),
|
||||
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
|
||||
}
|
||||
dataset = single_file_dataset(file_path, name_to_features)
|
||||
|
||||
# The dataset is always sharded by number of hosts.
|
||||
# num_input_pipelines is the number of hosts rather than number of cores.
|
||||
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
|
||||
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
|
||||
input_pipeline_context.input_pipeline_id)
|
||||
|
||||
def _select_data_from_record(record):
|
||||
x = {
|
||||
'input_word_ids': record['input_ids'],
|
||||
'input_mask': record['input_mask'],
|
||||
'input_type_ids': record['segment_ids']
|
||||
}
|
||||
y = record['label_ids']
|
||||
return (x, y)
|
||||
|
||||
dataset = dataset.map(_select_data_from_record)
|
||||
|
||||
if is_training:
|
||||
dataset = dataset.shuffle(100)
|
||||
dataset = dataset.repeat()
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder=is_training)
|
||||
dataset = dataset.prefetch(1024)
|
||||
return dataset
|
||||
|
||||
|
||||
def create_squad_dataset(file_path,
|
||||
seq_length,
|
||||
batch_size,
|
||||
is_training=True,
|
||||
input_pipeline_context=None):
|
||||
"""Creates input dataset from (tf)records files for train/eval."""
|
||||
name_to_features = {
|
||||
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
|
||||
}
|
||||
if is_training:
|
||||
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
|
||||
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
|
||||
else:
|
||||
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
|
||||
|
||||
dataset = single_file_dataset(file_path, name_to_features)
|
||||
|
||||
# The dataset is always sharded by number of hosts.
|
||||
# num_input_pipelines is the number of hosts rather than number of cores.
|
||||
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
|
||||
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
|
||||
input_pipeline_context.input_pipeline_id)
|
||||
|
||||
def _select_data_from_record(record):
|
||||
"""Dispatches record to features and labels."""
|
||||
x, y = {}, {}
|
||||
for name, tensor in record.items():
|
||||
if name in ('start_positions', 'end_positions'):
|
||||
y[name] = tensor
|
||||
elif name == 'input_ids':
|
||||
x['input_word_ids'] = tensor
|
||||
elif name == 'segment_ids':
|
||||
x['input_type_ids'] = tensor
|
||||
else:
|
||||
x[name] = tensor
|
||||
return (x, y)
|
||||
|
||||
dataset = dataset.map(_select_data_from_record)
|
||||
|
||||
if is_training:
|
||||
dataset = dataset.shuffle(100)
|
||||
dataset = dataset.repeat()
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
dataset = dataset.prefetch(1024)
|
||||
return dataset
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities to save models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
import typing
|
||||
|
||||
|
||||
def export_bert_model(model_export_path: typing.Text,
|
||||
model: tf.keras.Model,
|
||||
checkpoint_dir: typing.Optional[typing.Text] = None,
|
||||
restore_model_using_load_weights: bool = False) -> None:
|
||||
"""Export BERT model for serving which does not include the optimizer.
|
||||
|
||||
Arguments:
|
||||
model_export_path: Path to which exported model will be saved.
|
||||
model: Keras model object to export.
|
||||
checkpoint_dir: Path from which model weights will be loaded, if
|
||||
specified.
|
||||
restore_model_using_load_weights: Whether to use checkpoint.restore() API
|
||||
for custom checkpoint or to use model.load_weights() API.
|
||||
There are 2 different ways to save checkpoints. One is using
|
||||
tf.train.Checkpoint and another is using Keras model.save_weights().
|
||||
Custom training loop implementation uses tf.train.Checkpoint API
|
||||
and Keras ModelCheckpoint callback internally uses model.save_weights()
|
||||
API. Since these two API's cannot be used toghether, model loading logic
|
||||
must be take into account how model checkpoint was saved.
|
||||
|
||||
Raises:
|
||||
ValueError when either model_export_path or model is not specified.
|
||||
"""
|
||||
if not model_export_path:
|
||||
raise ValueError('model_export_path must be specified.')
|
||||
if not isinstance(model, tf.keras.Model):
|
||||
raise ValueError('model must be a tf.keras.Model object.')
|
||||
|
||||
if checkpoint_dir:
|
||||
# Keras compile/fit() was used to save checkpoint using
|
||||
# model.save_weights().
|
||||
if restore_model_using_load_weights:
|
||||
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
|
||||
assert tf.io.gfile.exists(model_weight_path)
|
||||
model.load_weights(model_weight_path)
|
||||
|
||||
# tf.train.Checkpoint API was used via custom training loop logic.
|
||||
else:
|
||||
checkpoint = tf.train.Checkpoint(model=model)
|
||||
|
||||
# Restores the model from latest checkpoint.
|
||||
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
|
||||
assert latest_checkpoint_file
|
||||
logging.info('Checkpoint file %s found and restoring from '
|
||||
'checkpoint', latest_checkpoint_file)
|
||||
checkpoint.restore(
|
||||
latest_checkpoint_file).assert_existing_objects_matched()
|
||||
|
||||
model.save(model_export_path, include_optimizer=False, save_format='tf')
|
||||
|
||||
|
||||
class BertModelCheckpoint(tf.keras.callbacks.Callback):
|
||||
"""Keras callback that saves model at the end of every epoch."""
|
||||
|
||||
def __init__(self, checkpoint_dir, checkpoint):
|
||||
"""Initializes BertModelCheckpoint.
|
||||
|
||||
Arguments:
|
||||
checkpoint_dir: Directory of the to be saved checkpoint file.
|
||||
checkpoint: tf.train.Checkpoint object.
|
||||
"""
|
||||
super(BertModelCheckpoint, self).__init__()
|
||||
self.checkpoint_file_name = os.path.join(
|
||||
checkpoint_dir, 'bert_training_checkpoint_step_{global_step}.ckpt')
|
||||
assert isinstance(checkpoint, tf.train.Checkpoint)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
global_step = tf.keras.backend.get_value(self.model.optimizer.iterations)
|
||||
formatted_file_name = self.checkpoint_file_name.format(
|
||||
global_step=global_step)
|
||||
saved_path = self.checkpoint.save(formatted_file_name)
|
||||
logging.info('Saving model TF checkpoint to : %s', saved_path)
|
||||
+491
@@ -0,0 +1,491 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A light weight utilities to train NLP models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
from official.staging.training import grad_utils
|
||||
from official.utils.misc import distribution_utils
|
||||
|
||||
_SUMMARY_TXT = 'training_summary.txt'
|
||||
_MIN_SUMMARY_STEPS = 10
|
||||
|
||||
|
||||
def _should_export_checkpoint(strategy):
|
||||
return (not strategy) or strategy.extended.should_checkpoint
|
||||
|
||||
|
||||
def _should_export_summary(strategy):
|
||||
return (not strategy) or strategy.extended.should_save_summary
|
||||
|
||||
|
||||
def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
|
||||
"""Saves model to with provided checkpoint prefix."""
|
||||
|
||||
if _should_export_checkpoint(strategy):
|
||||
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
|
||||
saved_path = checkpoint.save(checkpoint_path)
|
||||
logging.info('Saving model as TF checkpoint: %s', saved_path)
|
||||
else:
|
||||
# In multi worker training we need every worker to save checkpoint, because
|
||||
# variables can trigger synchronization on read and synchronization needs
|
||||
# all workers to participate. To avoid workers overriding each other we save
|
||||
# to a temporary directory on non-chief workers.
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
|
||||
tf.io.gfile.rmtree(tmp_dir)
|
||||
return
|
||||
|
||||
|
||||
def _get_input_iterator(input_fn, strategy):
|
||||
"""Returns distributed dataset iterator."""
|
||||
# When training with TPU pods, datasets needs to be cloned across
|
||||
# workers. Since Dataset instance cannot be cloned in eager mode, we instead
|
||||
# pass callable that returns a dataset.
|
||||
if not callable(input_fn):
|
||||
raise ValueError('`input_fn` should be a closure that returns a dataset.')
|
||||
iterator = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(input_fn))
|
||||
return iterator
|
||||
|
||||
|
||||
def _float_metric_value(metric):
|
||||
"""Gets the value of a float-value keras metric."""
|
||||
return metric.result().numpy().astype(float)
|
||||
|
||||
|
||||
def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
|
||||
"""Calculates steps to run on device."""
|
||||
if steps_per_loop <= 0:
|
||||
raise ValueError('steps_per_loop should be positive integer.')
|
||||
if steps_per_loop == 1:
|
||||
return steps_per_loop
|
||||
remainder_in_epoch = current_step % steps_per_epoch
|
||||
if remainder_in_epoch != 0:
|
||||
return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
|
||||
else:
|
||||
return steps_per_loop
|
||||
|
||||
|
||||
def write_txt_summary(training_summary, summary_dir):
|
||||
"""Writes a summary text file to record stats."""
|
||||
summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
|
||||
with tf.io.gfile.GFile(summary_path, 'wb') as f:
|
||||
logging.info('Training Summary: \n%s', str(training_summary))
|
||||
f.write(json.dumps(training_summary, indent=4))
|
||||
|
||||
|
||||
def run_customized_training_loop(
|
||||
# pylint: disable=invalid-name
|
||||
_sentinel=None,
|
||||
# pylint: enable=invalid-name
|
||||
strategy=None,
|
||||
model_fn=None,
|
||||
loss_fn=None,
|
||||
scale_loss=True,
|
||||
model_dir=None,
|
||||
train_input_fn=None,
|
||||
steps_per_epoch=None,
|
||||
steps_per_loop=1,
|
||||
epochs=1,
|
||||
eval_input_fn=None,
|
||||
eval_steps=None,
|
||||
metric_fn=None,
|
||||
init_checkpoint=None,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=False,
|
||||
sub_model_export_name=None,
|
||||
explicit_allreduce=False,
|
||||
pre_allreduce_callbacks=None,
|
||||
post_allreduce_callbacks=None):
|
||||
"""Run BERT pretrain model training using low-level API.
|
||||
|
||||
Arguments:
|
||||
_sentinel: Used to prevent positional parameters. Internal, do not use.
|
||||
strategy: Distribution strategy on which to run low level training loop.
|
||||
model_fn: Function that returns a tuple (model, sub_model). Caller of this
|
||||
function should add optimizer to the `model` via calling
|
||||
`model.compile()` API or manually setting `model.optimizer` attribute.
|
||||
Second element of the returned tuple(sub_model) is an optional sub model
|
||||
to be used for initial checkpoint -- if provided.
|
||||
loss_fn: Function with signature func(labels, logits) and returns a loss
|
||||
tensor.
|
||||
scale_loss: Whether to divide the raw loss by number of replicas before
|
||||
gradients calculation.
|
||||
model_dir: Model directory used during training for restoring/saving model
|
||||
weights.
|
||||
train_input_fn: Function that returns a tf.data.Dataset used for training.
|
||||
steps_per_epoch: Number of steps to run per epoch. At the end of each
|
||||
epoch, model checkpoint will be saved and evaluation will be conducted
|
||||
if evaluation dataset is provided.
|
||||
steps_per_loop: Number of steps per graph-mode loop. In order to reduce
|
||||
communication in eager context, training logs are printed every
|
||||
steps_per_loop.
|
||||
epochs: Number of epochs to train.
|
||||
eval_input_fn: Function that returns evaluation dataset. If none,
|
||||
evaluation is skipped.
|
||||
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
|
||||
is not none.
|
||||
metric_fn: A metrics function that returns a Keras Metric object to record
|
||||
evaluation result using evaluation dataset or with training dataset
|
||||
after every epoch.
|
||||
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
|
||||
`model_fn`.
|
||||
custom_callbacks: A list of Keras Callbacks objects to run during
|
||||
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
|
||||
methods are invoked during training.
|
||||
run_eagerly: Whether to run model training in pure eager execution. This
|
||||
should be disable for TPUStrategy.
|
||||
sub_model_export_name: If not None, will export `sub_model` returned by
|
||||
`model_fn` into checkpoint files. The name of intermediate checkpoint
|
||||
file is {sub_model_export_name}_step_{step}.ckpt and the last
|
||||
checkpint's name is {sub_model_export_name}.ckpt;
|
||||
if None, `sub_model` will not be exported as checkpoint.
|
||||
explicit_allreduce: Whether to explicitly perform gradient allreduce,
|
||||
instead of relying on implicit allreduce in optimizer.apply_gradients().
|
||||
default is False. For now, if training using FP16 mixed precision,
|
||||
explicit allreduce will aggregate gradients in FP16 format. For TPU and
|
||||
GPU training using FP32, explicit allreduce will aggregate gradients in
|
||||
FP32 format.
|
||||
pre_allreduce_callbacks: A list of callback functions that takes gradients
|
||||
and model variables pairs as input, manipulate them, and returns a new
|
||||
gradients and model variables paris. The callback functions will be
|
||||
invoked in the list order and before gradients are allreduced.
|
||||
With mixed precision training, the pre_allreduce_allbacks will be
|
||||
applied on scaled_gradients. Default is no callbacks.
|
||||
Only used when explicit_allreduce=True.
|
||||
post_allreduce_callbacks: A list of callback functions that takes
|
||||
gradients and model variables pairs as input, manipulate them, and
|
||||
returns a new gradients and model variables paris. The callback
|
||||
functions will be invoked in the list order and right before gradients
|
||||
are applied to variables for updates. Default is no callbacks. Only used
|
||||
when explicit_allreduce=True.
|
||||
|
||||
Returns:
|
||||
Trained model.
|
||||
|
||||
Raises:
|
||||
ValueError: (1) When model returned by `model_fn` does not have optimizer
|
||||
attribute or when required parameters are set to none. (2) eval args are
|
||||
not specified correctly. (3) metric_fn must be a callable if specified.
|
||||
(4) sub_model_checkpoint_name is specified, but `sub_model` returned
|
||||
by `model_fn` is None.
|
||||
"""
|
||||
|
||||
if _sentinel is not None:
|
||||
raise ValueError('only call `run_customized_training_loop()` '
|
||||
'with named arguments.')
|
||||
|
||||
required_arguments = [
|
||||
strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
|
||||
]
|
||||
if [arg for arg in required_arguments if arg is None]:
|
||||
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
|
||||
'`steps_per_loop` and `steps_per_epoch` are required '
|
||||
'parameters.')
|
||||
if steps_per_loop > steps_per_epoch:
|
||||
logging.error(
|
||||
'steps_per_loop: %d is specified to be greater than '
|
||||
' steps_per_epoch: %d, we will use steps_per_epoch as'
|
||||
' steps_per_loop.', steps_per_loop, steps_per_epoch)
|
||||
steps_per_loop = steps_per_epoch
|
||||
assert tf.executing_eagerly()
|
||||
|
||||
if run_eagerly:
|
||||
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
|
||||
raise ValueError(
|
||||
'TPUStrategy should not run eagerly as it heavily relies on graph'
|
||||
' optimization for the distributed system.')
|
||||
|
||||
if eval_input_fn and (eval_steps is None or metric_fn is None):
|
||||
raise ValueError(
|
||||
'`eval_step` and `metric_fn` are required when `eval_input_fn ` '
|
||||
'is not none.')
|
||||
if metric_fn and not callable(metric_fn):
|
||||
raise ValueError(
|
||||
'if `metric_fn` is specified, metric_fn must be a callable.')
|
||||
|
||||
total_training_steps = steps_per_epoch * epochs
|
||||
train_iterator = _get_input_iterator(train_input_fn, strategy)
|
||||
|
||||
with distribution_utils.get_strategy_scope(strategy):
|
||||
# To correctly place the model weights on accelerators,
|
||||
# model and optimizer should be created in scope.
|
||||
model, sub_model = model_fn()
|
||||
if not hasattr(model, 'optimizer'):
|
||||
raise ValueError('User should set optimizer attribute to model '
|
||||
'inside `model_fn`.')
|
||||
if sub_model_export_name and sub_model is None:
|
||||
raise ValueError('sub_model_export_name is specified as %s, but '
|
||||
'sub_model is None.' % sub_model_export_name)
|
||||
|
||||
optimizer = model.optimizer
|
||||
|
||||
if init_checkpoint:
|
||||
logging.info(
|
||||
'Checkpoint file %s found and restoring from '
|
||||
'initial checkpoint for core model.', init_checkpoint)
|
||||
checkpoint = tf.train.Checkpoint(model=sub_model)
|
||||
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
|
||||
logging.info('Loading from checkpoint file completed')
|
||||
|
||||
train_loss_metric = tf.keras.metrics.Mean(
|
||||
'training_loss', dtype=tf.float32)
|
||||
eval_metrics = [metric_fn()] if metric_fn else []
|
||||
# If evaluation is required, make a copy of metric as it will be used by
|
||||
# both train and evaluation.
|
||||
train_metrics = [
|
||||
metric.__class__.from_config(metric.get_config())
|
||||
for metric in eval_metrics
|
||||
]
|
||||
|
||||
# Create summary writers
|
||||
if _should_export_summary(strategy):
|
||||
summary_dir = os.path.join(model_dir, 'summaries')
|
||||
else:
|
||||
# In multi worker training we need every worker to write summary, because
|
||||
# variables can trigger synchronization on read and synchronization needs
|
||||
# all workers to participate.
|
||||
summary_dir = tempfile.mkdtemp()
|
||||
eval_summary_writer = tf.summary.create_file_writer(
|
||||
os.path.join(summary_dir, 'eval'))
|
||||
if steps_per_loop >= _MIN_SUMMARY_STEPS:
|
||||
# Only writes summary when the stats are collected sufficiently over
|
||||
# enough steps.
|
||||
train_summary_writer = tf.summary.create_file_writer(
|
||||
os.path.join(summary_dir, 'train'))
|
||||
else:
|
||||
train_summary_writer = None
|
||||
|
||||
# Collects training variables.
|
||||
training_vars = model.trainable_variables
|
||||
|
||||
def _replicated_step(inputs):
|
||||
"""Replicated training step."""
|
||||
|
||||
inputs, labels = inputs
|
||||
with tf.GradientTape() as tape:
|
||||
model_outputs = model(inputs, training=True)
|
||||
loss = loss_fn(labels, model_outputs)
|
||||
# Raw loss is used for reporting in metrics/logs.
|
||||
raw_loss = loss
|
||||
if scale_loss:
|
||||
# Scales down the loss for gradients to be invariant from replicas.
|
||||
loss = loss / strategy.num_replicas_in_sync
|
||||
|
||||
if explicit_allreduce:
|
||||
grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
|
||||
training_vars,
|
||||
pre_allreduce_callbacks,
|
||||
post_allreduce_callbacks)
|
||||
else:
|
||||
if isinstance(optimizer,
|
||||
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
|
||||
with tape:
|
||||
scaled_loss = optimizer.get_scaled_loss(loss)
|
||||
scaled_grads = tape.gradient(scaled_loss, training_vars)
|
||||
grads = optimizer.get_unscaled_gradients(scaled_grads)
|
||||
else:
|
||||
grads = tape.gradient(loss, training_vars)
|
||||
optimizer.apply_gradients(zip(grads, training_vars))
|
||||
# For reporting, the metric takes the mean of losses.
|
||||
train_loss_metric.update_state(raw_loss)
|
||||
for metric in train_metrics:
|
||||
metric.update_state(labels, model_outputs)
|
||||
|
||||
@tf.function
|
||||
def train_steps(iterator, steps):
|
||||
"""Performs distributed training steps in a loop.
|
||||
|
||||
Args:
|
||||
iterator: the distributed iterator of training datasets.
|
||||
steps: an tf.int32 integer tensor to specify number of steps to run
|
||||
inside host training loop.
|
||||
|
||||
Raises:
|
||||
ValueError: Any of the arguments or tensor shapes are invalid.
|
||||
"""
|
||||
if not isinstance(steps, tf.Tensor):
|
||||
raise ValueError('steps should be an Tensor. Python object may cause '
|
||||
'retracing.')
|
||||
|
||||
for _ in tf.range(steps):
|
||||
strategy.run(_replicated_step, args=(next(iterator),))
|
||||
|
||||
def train_single_step(iterator):
|
||||
"""Performs a distributed training step.
|
||||
|
||||
Args:
|
||||
iterator: the distributed iterator of training datasets.
|
||||
|
||||
Raises:
|
||||
ValueError: Any of the arguments or tensor shapes are invalid.
|
||||
"""
|
||||
strategy.run(_replicated_step, args=(next(iterator),))
|
||||
|
||||
def test_step(iterator):
|
||||
"""Calculates evaluation metrics on distributed devices."""
|
||||
|
||||
def _test_step_fn(inputs):
|
||||
"""Replicated accuracy calculation."""
|
||||
|
||||
inputs, labels = inputs
|
||||
model_outputs = model(inputs, training=False)
|
||||
for metric in eval_metrics:
|
||||
metric.update_state(labels, model_outputs)
|
||||
|
||||
strategy.run(_test_step_fn, args=(next(iterator),))
|
||||
|
||||
if not run_eagerly:
|
||||
train_single_step = tf.function(train_single_step)
|
||||
test_step = tf.function(test_step)
|
||||
|
||||
def _run_evaluation(current_training_step, test_iterator):
|
||||
"""Runs validation steps and aggregate metrics."""
|
||||
for _ in range(eval_steps):
|
||||
test_step(test_iterator)
|
||||
|
||||
with eval_summary_writer.as_default():
|
||||
for metric in eval_metrics + model.metrics:
|
||||
metric_value = _float_metric_value(metric)
|
||||
logging.info('Step: [%d] Validation %s = %f', current_training_step,
|
||||
metric.name, metric_value)
|
||||
tf.summary.scalar(
|
||||
metric.name, metric_value, step=current_training_step)
|
||||
eval_summary_writer.flush()
|
||||
|
||||
def _run_callbacks_on_batch_begin(batch):
|
||||
"""Runs custom callbacks at the start of every step."""
|
||||
if not custom_callbacks:
|
||||
return
|
||||
for callback in custom_callbacks:
|
||||
callback.on_batch_begin(batch)
|
||||
|
||||
def _run_callbacks_on_batch_end(batch, logs):
|
||||
"""Runs custom callbacks at the end of every step."""
|
||||
if not custom_callbacks:
|
||||
return
|
||||
for callback in custom_callbacks:
|
||||
callback.on_batch_end(batch, logs)
|
||||
|
||||
# Training loop starts here.
|
||||
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
|
||||
sub_model_checkpoint = tf.train.Checkpoint(
|
||||
model=sub_model) if sub_model_export_name else None
|
||||
|
||||
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
|
||||
if latest_checkpoint_file:
|
||||
logging.info(
|
||||
'Checkpoint file %s found and restoring from '
|
||||
'checkpoint', latest_checkpoint_file)
|
||||
checkpoint.restore(latest_checkpoint_file)
|
||||
logging.info('Loading from checkpoint file completed')
|
||||
|
||||
current_step = optimizer.iterations.numpy()
|
||||
checkpoint_name = 'ctl_step_{step}.ckpt'
|
||||
|
||||
while current_step < total_training_steps:
|
||||
# Training loss/metric are taking average over steps inside micro
|
||||
# training loop. We reset the their values before each round.
|
||||
train_loss_metric.reset_states()
|
||||
for metric in train_metrics + model.metrics:
|
||||
metric.reset_states()
|
||||
|
||||
_run_callbacks_on_batch_begin(current_step)
|
||||
# Runs several steps in the host while loop.
|
||||
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
|
||||
|
||||
if tf.config.list_physical_devices('GPU'):
|
||||
# TODO(zongweiz): merge with train_steps once tf.while_loop
|
||||
# GPU performance bugs are fixed.
|
||||
for _ in range(steps):
|
||||
train_single_step(train_iterator)
|
||||
else:
|
||||
# Converts steps to a Tensor to avoid tf.function retracing.
|
||||
train_steps(train_iterator,
|
||||
tf.convert_to_tensor(steps, dtype=tf.int32))
|
||||
train_loss = _float_metric_value(train_loss_metric)
|
||||
current_step += steps
|
||||
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
|
||||
|
||||
# Updates training logging.
|
||||
training_status = 'Train Step: %d/%d / loss = %s' % (
|
||||
current_step, total_training_steps, train_loss)
|
||||
|
||||
if train_summary_writer:
|
||||
with train_summary_writer.as_default():
|
||||
tf.summary.scalar(
|
||||
train_loss_metric.name, train_loss, step=current_step)
|
||||
for metric in train_metrics + model.metrics:
|
||||
metric_value = _float_metric_value(metric)
|
||||
training_status += ' %s = %f' % (metric.name, metric_value)
|
||||
tf.summary.scalar(metric.name, metric_value, step=current_step)
|
||||
train_summary_writer.flush()
|
||||
logging.info(training_status)
|
||||
|
||||
# Saves model checkpoints and run validation steps at every epoch end.
|
||||
if current_step % steps_per_epoch == 0:
|
||||
# To avoid repeated model saving, we do not save after the last
|
||||
# step of training.
|
||||
if current_step < total_training_steps:
|
||||
_save_checkpoint(strategy, checkpoint, model_dir,
|
||||
checkpoint_name.format(step=current_step))
|
||||
if sub_model_export_name:
|
||||
_save_checkpoint(
|
||||
strategy, sub_model_checkpoint, model_dir,
|
||||
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
|
||||
if eval_input_fn:
|
||||
logging.info('Running evaluation after step: %s.', current_step)
|
||||
_run_evaluation(current_step,
|
||||
_get_input_iterator(eval_input_fn, strategy))
|
||||
# Re-initialize evaluation metric.
|
||||
for metric in eval_metrics + model.metrics:
|
||||
metric.reset_states()
|
||||
|
||||
_save_checkpoint(strategy, checkpoint, model_dir,
|
||||
checkpoint_name.format(step=current_step))
|
||||
if sub_model_export_name:
|
||||
_save_checkpoint(strategy, sub_model_checkpoint, model_dir,
|
||||
'%s.ckpt' % sub_model_export_name)
|
||||
|
||||
if eval_input_fn:
|
||||
logging.info('Running final evaluation after training is complete.')
|
||||
_run_evaluation(current_step,
|
||||
_get_input_iterator(eval_input_fn, strategy))
|
||||
|
||||
training_summary = {
|
||||
'total_training_steps': total_training_steps,
|
||||
'train_loss': _float_metric_value(train_loss_metric),
|
||||
}
|
||||
if eval_metrics:
|
||||
# TODO(hongkuny): Cleans up summary reporting in text.
|
||||
training_summary['last_train_metrics'] = _float_metric_value(
|
||||
train_metrics[0])
|
||||
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
|
||||
|
||||
write_txt_summary(training_summary, summary_dir)
|
||||
|
||||
if not _should_export_summary(strategy):
|
||||
tf.io.gfile.rmtree(summary_dir)
|
||||
|
||||
return model
|
||||
+236
@@ -0,0 +1,236 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for official.modeling.training.model_training_utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import parameterized
|
||||
from absl.testing.absltest import mock
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from official.nlp.bert import model_training_utils
|
||||
|
||||
|
||||
def eager_strategy_combinations():
|
||||
return combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.default_strategy,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
],
|
||||
mode='eager',
|
||||
)
|
||||
|
||||
|
||||
def eager_gpu_strategy_combinations():
|
||||
return combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.default_strategy,
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
],
|
||||
mode='eager',
|
||||
)
|
||||
|
||||
|
||||
def create_fake_data_input_fn(batch_size, features_shape, num_classes):
|
||||
"""Creates a dummy input function with the given feature and label shapes.
|
||||
|
||||
Args:
|
||||
batch_size: integer.
|
||||
features_shape: list[int]. Feature shape for an individual example.
|
||||
num_classes: integer. Number of labels.
|
||||
|
||||
Returns:
|
||||
An input function that is usable in the executor.
|
||||
"""
|
||||
|
||||
def _dataset_fn(input_context=None):
|
||||
"""An input function for generating fake data."""
|
||||
local_batch_size = input_context.get_per_replica_batch_size(batch_size)
|
||||
features = np.random.rand(64, *features_shape)
|
||||
labels = np.random.randint(2, size=[64, num_classes])
|
||||
# Convert the inputs to a Dataset.
|
||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
||||
dataset = dataset.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
|
||||
def _assign_dtype(features, labels):
|
||||
features = tf.cast(features, tf.float32)
|
||||
labels = tf.cast(labels, tf.float32)
|
||||
return features, labels
|
||||
|
||||
# Shuffle, repeat, and batch the examples.
|
||||
dataset = dataset.map(_assign_dtype)
|
||||
dataset = dataset.shuffle(64).repeat()
|
||||
dataset = dataset.batch(local_batch_size, drop_remainder=True)
|
||||
dataset = dataset.prefetch(buffer_size=64)
|
||||
return dataset
|
||||
|
||||
return _dataset_fn
|
||||
|
||||
|
||||
def create_model_fn(input_shape, num_classes, use_float16=False):
|
||||
|
||||
def _model_fn():
|
||||
"""A one-layer softmax model suitable for testing."""
|
||||
input_layer = tf.keras.layers.Input(shape=input_shape)
|
||||
x = tf.keras.layers.Dense(num_classes, activation='relu')(input_layer)
|
||||
output_layer = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
|
||||
sub_model = tf.keras.models.Model(input_layer, x, name='sub_model')
|
||||
model = tf.keras.models.Model(input_layer, output_layer, name='model')
|
||||
model.add_metric(
|
||||
tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
|
||||
model.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
|
||||
if use_float16:
|
||||
model.optimizer = (
|
||||
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
|
||||
model.optimizer, loss_scale='dynamic'))
|
||||
return model, sub_model
|
||||
|
||||
return _model_fn
|
||||
|
||||
|
||||
def metric_fn():
|
||||
"""Gets a tf.keras metric object."""
|
||||
return tf.keras.metrics.CategoricalAccuracy(name='accuracy', dtype=tf.float32)
|
||||
|
||||
|
||||
def summaries_with_matching_keyword(keyword, summary_dir):
|
||||
"""Yields summary protos matching given keyword from event file."""
|
||||
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, 'events*'))
|
||||
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
|
||||
if event.summary is not None:
|
||||
for value in event.summary.value:
|
||||
if keyword in value.tag:
|
||||
logging.error(event)
|
||||
yield event.summary
|
||||
|
||||
|
||||
def check_eventfile_for_keyword(keyword, summary_dir):
|
||||
"""Checks event files for the keyword."""
|
||||
return any(summaries_with_matching_keyword(keyword, summary_dir))
|
||||
|
||||
|
||||
class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ModelTrainingUtilsTest, self).setUp()
|
||||
self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
|
||||
|
||||
def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
|
||||
input_fn = create_fake_data_input_fn(
|
||||
batch_size=8, features_shape=[128], num_classes=3)
|
||||
model_training_utils.run_customized_training_loop(
|
||||
strategy=strategy,
|
||||
model_fn=self._model_fn,
|
||||
loss_fn=tf.keras.losses.categorical_crossentropy,
|
||||
model_dir=model_dir,
|
||||
steps_per_epoch=20,
|
||||
steps_per_loop=steps_per_loop,
|
||||
epochs=2,
|
||||
train_input_fn=input_fn,
|
||||
eval_input_fn=input_fn,
|
||||
eval_steps=10,
|
||||
init_checkpoint=None,
|
||||
metric_fn=metric_fn,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=run_eagerly)
|
||||
|
||||
@combinations.generate(eager_strategy_combinations())
|
||||
def test_train_eager_single_step(self, distribution):
|
||||
model_dir = self.get_temp_dir()
|
||||
if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_training(
|
||||
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
|
||||
else:
|
||||
self.run_training(
|
||||
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
|
||||
|
||||
@combinations.generate(eager_gpu_strategy_combinations())
|
||||
def test_train_eager_mixed_precision(self, distribution):
|
||||
model_dir = self.get_temp_dir()
|
||||
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
|
||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||
self._model_fn = create_model_fn(
|
||||
input_shape=[128], num_classes=3, use_float16=True)
|
||||
self.run_training(
|
||||
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
|
||||
|
||||
@combinations.generate(eager_strategy_combinations())
|
||||
def test_train_check_artifacts(self, distribution):
|
||||
model_dir = self.get_temp_dir()
|
||||
self.run_training(
|
||||
distribution, model_dir, steps_per_loop=10, run_eagerly=False)
|
||||
|
||||
# Two checkpoints should be saved after two epochs.
|
||||
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
|
||||
self.assertNotEmpty(
|
||||
tf.io.gfile.glob(
|
||||
os.path.join(model_dir, 'summaries/training_summary*')))
|
||||
|
||||
# Loss and accuracy values should be written into summaries.
|
||||
self.assertTrue(
|
||||
check_eventfile_for_keyword('loss',
|
||||
os.path.join(model_dir, 'summaries/train')))
|
||||
self.assertTrue(
|
||||
check_eventfile_for_keyword('accuracy',
|
||||
os.path.join(model_dir, 'summaries/train')))
|
||||
self.assertTrue(
|
||||
check_eventfile_for_keyword('mean_input',
|
||||
os.path.join(model_dir, 'summaries/train')))
|
||||
self.assertTrue(
|
||||
check_eventfile_for_keyword('accuracy',
|
||||
os.path.join(model_dir, 'summaries/eval')))
|
||||
self.assertTrue(
|
||||
check_eventfile_for_keyword('mean_input',
|
||||
os.path.join(model_dir, 'summaries/eval')))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
],
|
||||
mode='eager',
|
||||
))
|
||||
def test_train_check_artifacts_non_chief(self, distribution):
|
||||
# We shouldn't export artifacts on non-chief workers. Since there's no easy
|
||||
# way to test with real MultiWorkerMirroredStrategy, we patch the strategy
|
||||
# to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
|
||||
extended = distribution.extended
|
||||
with mock.patch.object(extended.__class__, 'should_checkpoint',
|
||||
new_callable=mock.PropertyMock, return_value=False), \
|
||||
mock.patch.object(extended.__class__, 'should_save_summary',
|
||||
new_callable=mock.PropertyMock, return_value=False):
|
||||
model_dir = self.get_temp_dir()
|
||||
self.run_training(
|
||||
distribution, model_dir, steps_per_loop=10, run_eagerly=False)
|
||||
self.assertEmpty(tf.io.gfile.listdir(model_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
+432
@@ -0,0 +1,432 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""BERT classification finetuning runner in TF 2.x."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
from official.modeling import performance
|
||||
from official.nlp import optimization
|
||||
from official.nlp.bert import bert_models
|
||||
from official.nlp.bert import common_flags
|
||||
from official.nlp.bert import configs as bert_configs
|
||||
from official.nlp.bert import input_pipeline
|
||||
from official.nlp.bert import model_saving_utils
|
||||
from official.nlp.bert import model_training_utils
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
|
||||
|
||||
flags.DEFINE_enum(
|
||||
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
|
||||
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
|
||||
'trains the model and evaluates in the meantime. '
|
||||
'`export_only`: will take the latest checkpoint inside '
|
||||
'model_dir and export a `SavedModel`.')
|
||||
flags.DEFINE_string('train_data_path', None,
|
||||
'Path to training data for BERT classifier.')
|
||||
flags.DEFINE_string('eval_data_path', None,
|
||||
'Path to evaluation data for BERT classifier.')
|
||||
# Model training specific flags.
|
||||
flags.DEFINE_string(
|
||||
'input_meta_data_path', None,
|
||||
'Path to file that contains meta data about input '
|
||||
'to be used for training and evaluation.')
|
||||
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
|
||||
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
|
||||
|
||||
common_flags.define_common_bert_flags()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def get_loss_fn(num_classes):
|
||||
"""Gets the classification loss function."""
|
||||
|
||||
def classification_loss_fn(labels, logits):
|
||||
"""Classification loss."""
|
||||
labels = tf.squeeze(labels)
|
||||
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||
one_hot_labels = tf.one_hot(
|
||||
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
|
||||
per_example_loss = -tf.reduce_sum(
|
||||
tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
|
||||
return tf.reduce_mean(per_example_loss)
|
||||
|
||||
return classification_loss_fn
|
||||
|
||||
|
||||
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
|
||||
is_training):
|
||||
"""Gets a closure to create a dataset."""
|
||||
|
||||
def _dataset_fn(ctx=None):
|
||||
"""Returns tf.data.Dataset for distributed BERT pretraining."""
|
||||
batch_size = ctx.get_per_replica_batch_size(
|
||||
global_batch_size) if ctx else global_batch_size
|
||||
dataset = input_pipeline.create_classifier_dataset(
|
||||
input_file_pattern,
|
||||
max_seq_length,
|
||||
batch_size,
|
||||
is_training=is_training,
|
||||
input_pipeline_context=ctx)
|
||||
return dataset
|
||||
|
||||
return _dataset_fn
|
||||
|
||||
|
||||
def run_bert_classifier(strategy,
|
||||
bert_config,
|
||||
input_meta_data,
|
||||
model_dir,
|
||||
epochs,
|
||||
steps_per_epoch,
|
||||
steps_per_loop,
|
||||
eval_steps,
|
||||
warmup_steps,
|
||||
initial_lr,
|
||||
init_checkpoint,
|
||||
train_input_fn,
|
||||
eval_input_fn,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=False,
|
||||
use_keras_compile_fit=False):
|
||||
"""Run BERT classifier training using low-level API."""
|
||||
max_seq_length = input_meta_data['max_seq_length']
|
||||
num_classes = input_meta_data['num_labels']
|
||||
|
||||
def _get_classifier_model():
|
||||
"""Gets a classifier model."""
|
||||
classifier_model, core_model = (
|
||||
bert_models.classifier_model(
|
||||
bert_config,
|
||||
num_classes,
|
||||
max_seq_length,
|
||||
hub_module_url=FLAGS.hub_module_url,
|
||||
hub_module_trainable=FLAGS.hub_module_trainable))
|
||||
optimizer = optimization.create_optimizer(
|
||||
initial_lr, steps_per_epoch * epochs, warmup_steps,
|
||||
FLAGS.end_lr, FLAGS.optimizer_type)
|
||||
classifier_model.optimizer = performance.configure_optimizer(
|
||||
optimizer,
|
||||
use_float16=common_flags.use_float16(),
|
||||
use_graph_rewrite=common_flags.use_graph_rewrite())
|
||||
return classifier_model, core_model
|
||||
|
||||
loss_fn = get_loss_fn(num_classes)
|
||||
|
||||
# Defines evaluation metrics function, which will create metrics in the
|
||||
# correct device and strategy scope.
|
||||
def metric_fn():
|
||||
return tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
'test_accuracy', dtype=tf.float32)
|
||||
|
||||
if use_keras_compile_fit:
|
||||
# Start training using Keras compile/fit API.
|
||||
logging.info('Training using TF 2.0 Keras compile/fit API with '
|
||||
'distribution strategy.')
|
||||
return run_keras_compile_fit(
|
||||
model_dir,
|
||||
strategy,
|
||||
_get_classifier_model,
|
||||
train_input_fn,
|
||||
eval_input_fn,
|
||||
loss_fn,
|
||||
metric_fn,
|
||||
init_checkpoint,
|
||||
epochs,
|
||||
steps_per_epoch,
|
||||
steps_per_loop,
|
||||
eval_steps,
|
||||
custom_callbacks=custom_callbacks)
|
||||
|
||||
# Use user-defined loop to start training.
|
||||
logging.info('Training using customized training loop TF 2.0 with '
|
||||
'distribution strategy.')
|
||||
return model_training_utils.run_customized_training_loop(
|
||||
strategy=strategy,
|
||||
model_fn=_get_classifier_model,
|
||||
loss_fn=loss_fn,
|
||||
model_dir=model_dir,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
steps_per_loop=steps_per_loop,
|
||||
epochs=epochs,
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
eval_steps=eval_steps,
|
||||
init_checkpoint=init_checkpoint,
|
||||
metric_fn=metric_fn,
|
||||
custom_callbacks=custom_callbacks,
|
||||
run_eagerly=run_eagerly)
|
||||
|
||||
|
||||
def run_keras_compile_fit(model_dir,
|
||||
strategy,
|
||||
model_fn,
|
||||
train_input_fn,
|
||||
eval_input_fn,
|
||||
loss_fn,
|
||||
metric_fn,
|
||||
init_checkpoint,
|
||||
epochs,
|
||||
steps_per_epoch,
|
||||
steps_per_loop,
|
||||
eval_steps,
|
||||
custom_callbacks=None):
|
||||
"""Runs BERT classifier model using Keras compile/fit API."""
|
||||
|
||||
with strategy.scope():
|
||||
training_dataset = train_input_fn()
|
||||
evaluation_dataset = eval_input_fn()
|
||||
bert_model, sub_model = model_fn()
|
||||
optimizer = bert_model.optimizer
|
||||
|
||||
if init_checkpoint:
|
||||
checkpoint = tf.train.Checkpoint(model=sub_model)
|
||||
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
|
||||
|
||||
bert_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=loss_fn,
|
||||
metrics=[metric_fn()],
|
||||
experimental_steps_per_execution=steps_per_loop)
|
||||
|
||||
summary_dir = os.path.join(model_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
checkpoint_path = os.path.join(model_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
checkpoint_path, save_weights_only=True)
|
||||
|
||||
if custom_callbacks is not None:
|
||||
custom_callbacks += [summary_callback, checkpoint_callback]
|
||||
else:
|
||||
custom_callbacks = [summary_callback, checkpoint_callback]
|
||||
|
||||
bert_model.fit(
|
||||
x=training_dataset,
|
||||
validation_data=evaluation_dataset,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
epochs=epochs,
|
||||
validation_steps=eval_steps,
|
||||
callbacks=custom_callbacks)
|
||||
|
||||
return bert_model
|
||||
|
||||
|
||||
def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
|
||||
eval_steps):
|
||||
"""Obtains predictions of trained model on evaluation data.
|
||||
|
||||
Note that list of labels is returned along with the predictions because the
|
||||
order changes on distributing dataset over TPU pods.
|
||||
|
||||
Args:
|
||||
strategy: Distribution strategy.
|
||||
trained_model: Trained model with preloaded weights.
|
||||
eval_input_fn: Input function for evaluation data.
|
||||
eval_steps: Number of evaluation steps.
|
||||
|
||||
Returns:
|
||||
predictions: List of predictions.
|
||||
labels: List of gold labels corresponding to predictions.
|
||||
"""
|
||||
|
||||
@tf.function
|
||||
def test_step(iterator):
|
||||
"""Computes predictions on distributed devices."""
|
||||
|
||||
def _test_step_fn(inputs):
|
||||
"""Replicated predictions."""
|
||||
inputs, labels = inputs
|
||||
model_outputs = trained_model(inputs, training=False)
|
||||
return model_outputs, labels
|
||||
|
||||
outputs, labels = strategy.run(
|
||||
_test_step_fn, args=(next(iterator),))
|
||||
# outputs: current batch logits as a tuple of shard logits
|
||||
outputs = tf.nest.map_structure(strategy.experimental_local_results,
|
||||
outputs)
|
||||
labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
|
||||
return outputs, labels
|
||||
|
||||
def _run_evaluation(test_iterator):
|
||||
"""Runs evaluation steps."""
|
||||
preds, golds = list(), list()
|
||||
for _ in range(eval_steps):
|
||||
logits, labels = test_step(test_iterator)
|
||||
for cur_logits, cur_labels in zip(logits, labels):
|
||||
preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
|
||||
golds.extend(cur_labels.numpy().tolist())
|
||||
return preds, golds
|
||||
|
||||
test_iter = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(eval_input_fn))
|
||||
predictions, labels = _run_evaluation(test_iter)
|
||||
|
||||
return predictions, labels
|
||||
|
||||
|
||||
def export_classifier(model_export_path, input_meta_data,
|
||||
restore_model_using_load_weights, bert_config, model_dir):
|
||||
"""Exports a trained model as a `SavedModel` for inference.
|
||||
|
||||
Args:
|
||||
model_export_path: a string specifying the path to the SavedModel directory.
|
||||
input_meta_data: dictionary containing meta data about input and model.
|
||||
restore_model_using_load_weights: Whether to use checkpoint.restore() API
|
||||
for custom checkpoint or to use model.load_weights() API. There are 2
|
||||
different ways to save checkpoints. One is using tf.train.Checkpoint and
|
||||
another is using Keras model.save_weights(). Custom training loop
|
||||
implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
|
||||
callback internally uses model.save_weights() API. Since these two API's
|
||||
cannot be used together, model loading logic must be take into account how
|
||||
model checkpoint was saved.
|
||||
bert_config: Bert configuration file to define core bert layers.
|
||||
model_dir: The directory where the model weights and training/evaluation
|
||||
summaries are stored.
|
||||
|
||||
Raises:
|
||||
Export path is not specified, got an empty string or None.
|
||||
"""
|
||||
if not model_export_path:
|
||||
raise ValueError('Export path is not specified: %s' % model_export_path)
|
||||
if not model_dir:
|
||||
raise ValueError('Export path is not specified: %s' % model_dir)
|
||||
|
||||
# Export uses float32 for now, even if training uses mixed precision.
|
||||
tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
classifier_model = bert_models.classifier_model(
|
||||
bert_config, input_meta_data['num_labels'],
|
||||
input_meta_data['max_seq_length'])[0]
|
||||
|
||||
model_saving_utils.export_bert_model(
|
||||
model_export_path,
|
||||
model=classifier_model,
|
||||
checkpoint_dir=model_dir,
|
||||
restore_model_using_load_weights=restore_model_using_load_weights)
|
||||
|
||||
|
||||
def run_bert(strategy,
|
||||
input_meta_data,
|
||||
model_config,
|
||||
train_input_fn=None,
|
||||
eval_input_fn=None):
|
||||
"""Run BERT training."""
|
||||
if FLAGS.mode == 'export_only':
|
||||
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
|
||||
# internally uses model.save_weights() to save checkpoints, we must
|
||||
# use model.load_weights() when Keras compile/fit() is used.
|
||||
export_classifier(FLAGS.model_export_path, input_meta_data,
|
||||
FLAGS.use_keras_compile_fit,
|
||||
model_config, FLAGS.model_dir)
|
||||
return
|
||||
|
||||
if FLAGS.mode != 'train_and_eval':
|
||||
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
|
||||
# Enables XLA in Session Config. Should not be set for TPU.
|
||||
keras_utils.set_config_v2(FLAGS.enable_xla)
|
||||
performance.set_mixed_precision_policy(common_flags.dtype())
|
||||
|
||||
epochs = FLAGS.num_train_epochs
|
||||
train_data_size = input_meta_data['train_data_size']
|
||||
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
|
||||
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
|
||||
eval_steps = int(
|
||||
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
|
||||
|
||||
if not strategy:
|
||||
raise ValueError('Distribution strategy has not been specified.')
|
||||
|
||||
if FLAGS.log_steps:
|
||||
custom_callbacks = [keras_utils.TimeHistory(
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
logdir=FLAGS.model_dir,
|
||||
)]
|
||||
else:
|
||||
custom_callbacks = None
|
||||
|
||||
trained_model = run_bert_classifier(
|
||||
strategy,
|
||||
model_config,
|
||||
input_meta_data,
|
||||
FLAGS.model_dir,
|
||||
epochs,
|
||||
steps_per_epoch,
|
||||
FLAGS.steps_per_loop,
|
||||
eval_steps,
|
||||
warmup_steps,
|
||||
FLAGS.learning_rate,
|
||||
FLAGS.init_checkpoint,
|
||||
train_input_fn,
|
||||
eval_input_fn,
|
||||
run_eagerly=FLAGS.run_eagerly,
|
||||
use_keras_compile_fit=FLAGS.use_keras_compile_fit,
|
||||
custom_callbacks=custom_callbacks)
|
||||
|
||||
if FLAGS.model_export_path:
|
||||
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
|
||||
# internally uses model.save_weights() to save checkpoints, we must
|
||||
# use model.load_weights() when Keras compile/fit() is used.
|
||||
model_saving_utils.export_bert_model(
|
||||
FLAGS.model_export_path,
|
||||
model=trained_model,
|
||||
restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
|
||||
return trained_model
|
||||
|
||||
|
||||
def main(_):
|
||||
# Users should always run this script under TF 2.x
|
||||
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
if not FLAGS.model_dir:
|
||||
FLAGS.model_dir = '/tmp/bert20/'
|
||||
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=FLAGS.distribution_strategy,
|
||||
num_gpus=FLAGS.num_gpus,
|
||||
tpu_address=FLAGS.tpu)
|
||||
max_seq_length = input_meta_data['max_seq_length']
|
||||
train_input_fn = get_dataset_fn(
|
||||
FLAGS.train_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.train_batch_size,
|
||||
is_training=True)
|
||||
eval_input_fn = get_dataset_fn(
|
||||
FLAGS.eval_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.eval_batch_size,
|
||||
is_training=False)
|
||||
|
||||
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
run_bert(strategy, input_meta_data, bert_config, train_input_fn,
|
||||
eval_input_fn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('bert_config_file')
|
||||
flags.mark_flag_as_required('input_meta_data_path')
|
||||
flags.mark_flag_as_required('model_dir')
|
||||
app.run(main)
|
||||
+187
@@ -0,0 +1,187 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import gin
|
||||
import tensorflow as tf
|
||||
from official.modeling import performance
|
||||
from official.nlp import optimization
|
||||
from official.nlp.bert import bert_models
|
||||
from official.nlp.bert import common_flags
|
||||
from official.nlp.bert import configs
|
||||
from official.nlp.bert import input_pipeline
|
||||
from official.nlp.bert import model_training_utils
|
||||
from official.utils.misc import distribution_utils
|
||||
|
||||
|
||||
flags.DEFINE_string('input_files', None,
|
||||
'File path to retrieve training data for pre-training.')
|
||||
# Model training specific flags.
|
||||
flags.DEFINE_integer(
|
||||
'max_seq_length', 128,
|
||||
'The maximum total input sequence length after WordPiece tokenization. '
|
||||
'Sequences longer than this will be truncated, and sequences shorter '
|
||||
'than this will be padded.')
|
||||
flags.DEFINE_integer('max_predictions_per_seq', 20,
|
||||
'Maximum predictions per sequence_output.')
|
||||
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
|
||||
flags.DEFINE_integer('num_steps_per_epoch', 1000,
|
||||
'Total number of training steps to run per epoch.')
|
||||
flags.DEFINE_float('warmup_steps', 10000,
|
||||
'Warmup steps for Adam weight decay optimizer.')
|
||||
flags.DEFINE_bool('use_next_sentence_label', True,
|
||||
'Whether to use next sentence label to compute final loss.')
|
||||
|
||||
common_flags.define_common_bert_flags()
|
||||
common_flags.define_gin_flags()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
|
||||
max_predictions_per_seq, global_batch_size,
|
||||
use_next_sentence_label=True):
|
||||
"""Returns input dataset from input file string."""
|
||||
def _dataset_fn(ctx=None):
|
||||
"""Returns tf.data.Dataset for distributed BERT pretraining."""
|
||||
input_patterns = input_file_pattern.split(',')
|
||||
batch_size = ctx.get_per_replica_batch_size(global_batch_size)
|
||||
train_dataset = input_pipeline.create_pretrain_dataset(
|
||||
input_patterns,
|
||||
seq_length,
|
||||
max_predictions_per_seq,
|
||||
batch_size,
|
||||
is_training=True,
|
||||
input_pipeline_context=ctx,
|
||||
use_next_sentence_label=use_next_sentence_label)
|
||||
return train_dataset
|
||||
|
||||
return _dataset_fn
|
||||
|
||||
|
||||
def get_loss_fn():
|
||||
"""Returns loss function for BERT pretraining."""
|
||||
|
||||
def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
|
||||
return tf.reduce_mean(losses)
|
||||
|
||||
return _bert_pretrain_loss_fn
|
||||
|
||||
|
||||
def run_customized_training(strategy,
|
||||
bert_config,
|
||||
max_seq_length,
|
||||
max_predictions_per_seq,
|
||||
model_dir,
|
||||
steps_per_epoch,
|
||||
steps_per_loop,
|
||||
epochs,
|
||||
initial_lr,
|
||||
warmup_steps,
|
||||
end_lr,
|
||||
optimizer_type,
|
||||
input_files,
|
||||
train_batch_size,
|
||||
use_next_sentence_label=True):
|
||||
"""Run BERT pretrain model training using low-level API."""
|
||||
|
||||
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
|
||||
max_predictions_per_seq,
|
||||
train_batch_size,
|
||||
use_next_sentence_label)
|
||||
|
||||
def _get_pretrain_model():
|
||||
"""Gets a pretraining model."""
|
||||
pretrain_model, core_model = bert_models.pretrain_model(
|
||||
bert_config, max_seq_length, max_predictions_per_seq,
|
||||
use_next_sentence_label=use_next_sentence_label)
|
||||
optimizer = optimization.create_optimizer(
|
||||
initial_lr, steps_per_epoch * epochs, warmup_steps,
|
||||
end_lr, optimizer_type)
|
||||
pretrain_model.optimizer = performance.configure_optimizer(
|
||||
optimizer,
|
||||
use_float16=common_flags.use_float16(),
|
||||
use_graph_rewrite=common_flags.use_graph_rewrite())
|
||||
return pretrain_model, core_model
|
||||
|
||||
trained_model = model_training_utils.run_customized_training_loop(
|
||||
strategy=strategy,
|
||||
model_fn=_get_pretrain_model,
|
||||
loss_fn=get_loss_fn(),
|
||||
scale_loss=FLAGS.scale_loss,
|
||||
model_dir=model_dir,
|
||||
train_input_fn=train_input_fn,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
steps_per_loop=steps_per_loop,
|
||||
epochs=epochs,
|
||||
sub_model_export_name='pretrained/bert_model')
|
||||
|
||||
return trained_model
|
||||
|
||||
|
||||
def run_bert_pretrain(strategy):
|
||||
"""Runs BERT pre-training."""
|
||||
|
||||
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
if not strategy:
|
||||
raise ValueError('Distribution strategy is not specified.')
|
||||
|
||||
# Runs customized training loop.
|
||||
logging.info('Training using customized training loop TF 2.0 with distrubuted'
|
||||
'strategy.')
|
||||
|
||||
performance.set_mixed_precision_policy(common_flags.dtype())
|
||||
|
||||
return run_customized_training(
|
||||
strategy,
|
||||
bert_config,
|
||||
FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq,
|
||||
FLAGS.model_dir,
|
||||
FLAGS.num_steps_per_epoch,
|
||||
FLAGS.steps_per_loop,
|
||||
FLAGS.num_train_epochs,
|
||||
FLAGS.learning_rate,
|
||||
FLAGS.warmup_steps,
|
||||
FLAGS.end_lr,
|
||||
FLAGS.optimizer_type,
|
||||
FLAGS.input_files,
|
||||
FLAGS.train_batch_size,
|
||||
FLAGS.use_next_sentence_label)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Users should always run this script under TF 2.x
|
||||
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
|
||||
if not FLAGS.model_dir:
|
||||
FLAGS.model_dir = '/tmp/bert20/'
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=FLAGS.distribution_strategy,
|
||||
num_gpus=FLAGS.num_gpus,
|
||||
tpu_address=FLAGS.tpu)
|
||||
if strategy:
|
||||
print('***** Number of cores used : ', strategy.num_replicas_in_sync)
|
||||
|
||||
run_bert_pretrain(strategy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
+149
@@ -0,0 +1,149 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.bert import configs as bert_configs
|
||||
from official.nlp.bert import run_squad_helper
|
||||
from official.nlp.bert import tokenization
|
||||
from official.nlp.data import squad_lib as squad_lib_wp
|
||||
from official.utils.misc import distribution_utils
|
||||
from official.utils.misc import keras_utils
|
||||
|
||||
|
||||
flags.DEFINE_string('vocab_file', None,
|
||||
'The vocabulary file that the BERT model was trained on.')
|
||||
|
||||
# More flags can be found in run_squad_helper.
|
||||
run_squad_helper.define_common_squad_flags()
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def train_squad(strategy,
|
||||
input_meta_data,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=False,
|
||||
init_checkpoint=None):
|
||||
"""Run bert squad training."""
|
||||
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
|
||||
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
|
||||
custom_callbacks, run_eagerly, init_checkpoint)
|
||||
|
||||
|
||||
def predict_squad(strategy, input_meta_data):
|
||||
"""Makes predictions for the squad dataset."""
|
||||
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
run_squad_helper.predict_squad(
|
||||
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
|
||||
|
||||
|
||||
def eval_squad(strategy, input_meta_data):
|
||||
"""Evaluate on the squad dataset."""
|
||||
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
eval_metrics = run_squad_helper.eval_squad(
|
||||
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
|
||||
return eval_metrics
|
||||
|
||||
|
||||
def export_squad(model_export_path, input_meta_data):
|
||||
"""Exports a trained model as a `SavedModel` for inference.
|
||||
|
||||
Args:
|
||||
model_export_path: a string specifying the path to the SavedModel directory.
|
||||
input_meta_data: dictionary containing meta data about input and model.
|
||||
|
||||
Raises:
|
||||
Export path is not specified, got an empty string or None.
|
||||
"""
|
||||
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
|
||||
|
||||
|
||||
def main(_):
|
||||
# Users should always run this script under TF 2.x
|
||||
|
||||
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
||||
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
||||
|
||||
if FLAGS.mode == 'export_only':
|
||||
export_squad(FLAGS.model_export_path, input_meta_data)
|
||||
return
|
||||
|
||||
# Configures cluster spec for multi-worker distribution strategy.
|
||||
if FLAGS.num_gpus > 0:
|
||||
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
|
||||
FLAGS.task_index)
|
||||
strategy = distribution_utils.get_distribution_strategy(
|
||||
distribution_strategy=FLAGS.distribution_strategy,
|
||||
num_gpus=FLAGS.num_gpus,
|
||||
all_reduce_alg=FLAGS.all_reduce_alg,
|
||||
tpu_address=FLAGS.tpu)
|
||||
|
||||
if 'train' in FLAGS.mode:
|
||||
if FLAGS.log_steps:
|
||||
custom_callbacks = [keras_utils.TimeHistory(
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
log_steps=FLAGS.log_steps,
|
||||
logdir=FLAGS.model_dir,
|
||||
)]
|
||||
else:
|
||||
custom_callbacks = None
|
||||
|
||||
train_squad(
|
||||
strategy,
|
||||
input_meta_data,
|
||||
custom_callbacks=custom_callbacks,
|
||||
run_eagerly=FLAGS.run_eagerly,
|
||||
)
|
||||
if 'predict' in FLAGS.mode:
|
||||
predict_squad(strategy, input_meta_data)
|
||||
if 'eval' in FLAGS.mode:
|
||||
eval_metrics = eval_squad(strategy, input_meta_data)
|
||||
f1_score = eval_metrics['final_f1']
|
||||
logging.info('SQuAD eval F1-score: %f', f1_score)
|
||||
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
|
||||
summary_writer = tf.summary.create_file_writer(summary_dir)
|
||||
with summary_writer.as_default():
|
||||
# TODO(lehou): write to the correct step number.
|
||||
tf.summary.scalar('F1-score', f1_score, step=0)
|
||||
summary_writer.flush()
|
||||
# Also write eval_metrics to json file.
|
||||
squad_lib_wp.write_to_json_files(
|
||||
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('bert_config_file')
|
||||
flags.mark_flag_as_required('model_dir')
|
||||
app.run(main)
|
||||
+432
@@ -0,0 +1,432 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
from official.modeling import performance
|
||||
from official.nlp import optimization
|
||||
from official.nlp.bert import bert_models
|
||||
from official.nlp.bert import common_flags
|
||||
from official.nlp.bert import input_pipeline
|
||||
from official.nlp.bert import model_saving_utils
|
||||
from official.nlp.bert import model_training_utils
|
||||
from official.nlp.bert import squad_evaluate_v1_1
|
||||
from official.nlp.bert import squad_evaluate_v2_0
|
||||
from official.nlp.data import squad_lib_sp
|
||||
from official.utils.misc import keras_utils
|
||||
|
||||
|
||||
def define_common_squad_flags():
|
||||
"""Defines common flags used by SQuAD tasks."""
|
||||
flags.DEFINE_enum(
|
||||
'mode', 'train_and_eval',
|
||||
['train_and_eval', 'train_and_predict',
|
||||
'train', 'eval', 'predict', 'export_only'],
|
||||
'One of {"train_and_eval", "train_and_predict", '
|
||||
'"train", "eval", "predict", "export_only"}. '
|
||||
'`train_and_eval`: train & predict to json files & compute eval metrics. '
|
||||
'`train_and_predict`: train & predict to json files. '
|
||||
'`train`: only trains the model. '
|
||||
'`eval`: predict answers from squad json file & compute eval metrics. '
|
||||
'`predict`: predict answers from the squad json file. '
|
||||
'`export_only`: will take the latest checkpoint inside '
|
||||
'model_dir and export a `SavedModel`.')
|
||||
flags.DEFINE_string('train_data_path', '',
|
||||
'Training data path with train tfrecords.')
|
||||
flags.DEFINE_string(
|
||||
'input_meta_data_path', None,
|
||||
'Path to file that contains meta data about input '
|
||||
'to be used for training and evaluation.')
|
||||
# Model training specific flags.
|
||||
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
|
||||
# Predict processing related.
|
||||
flags.DEFINE_string('predict_file', None,
|
||||
'Prediction data path with train tfrecords.')
|
||||
flags.DEFINE_bool(
|
||||
'do_lower_case', True,
|
||||
'Whether to lower case the input text. Should be True for uncased '
|
||||
'models and False for cased models.')
|
||||
flags.DEFINE_float(
|
||||
'null_score_diff_threshold', 0.0,
|
||||
'If null_score - best_non_null is greater than the threshold, '
|
||||
'predict null. This is only used for SQuAD v2.')
|
||||
flags.DEFINE_bool(
|
||||
'verbose_logging', False,
|
||||
'If true, all of the warnings related to data processing will be '
|
||||
'printed. A number of warnings are expected for a normal SQuAD '
|
||||
'evaluation.')
|
||||
flags.DEFINE_integer('predict_batch_size', 8,
|
||||
'Total batch size for prediction.')
|
||||
flags.DEFINE_integer(
|
||||
'n_best_size', 20,
|
||||
'The total number of n-best predictions to generate in the '
|
||||
'nbest_predictions.json output file.')
|
||||
flags.DEFINE_integer(
|
||||
'max_answer_length', 30,
|
||||
'The maximum length of an answer that can be generated. This is needed '
|
||||
'because the start and end predictions are not conditioned on one '
|
||||
'another.')
|
||||
|
||||
common_flags.define_common_bert_flags()
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def squad_loss_fn(start_positions,
|
||||
end_positions,
|
||||
start_logits,
|
||||
end_logits):
|
||||
"""Returns sparse categorical crossentropy for start/end logits."""
|
||||
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
|
||||
start_positions, start_logits, from_logits=True)
|
||||
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
|
||||
end_positions, end_logits, from_logits=True)
|
||||
|
||||
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
|
||||
return total_loss
|
||||
|
||||
|
||||
def get_loss_fn():
|
||||
"""Gets a loss function for squad task."""
|
||||
|
||||
def _loss_fn(labels, model_outputs):
|
||||
start_positions = labels['start_positions']
|
||||
end_positions = labels['end_positions']
|
||||
start_logits, end_logits = model_outputs
|
||||
return squad_loss_fn(
|
||||
start_positions,
|
||||
end_positions,
|
||||
start_logits,
|
||||
end_logits)
|
||||
|
||||
return _loss_fn
|
||||
|
||||
|
||||
RawResult = collections.namedtuple('RawResult',
|
||||
['unique_id', 'start_logits', 'end_logits'])
|
||||
|
||||
|
||||
def get_raw_results(predictions):
|
||||
"""Converts multi-replica predictions to RawResult."""
|
||||
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
|
||||
predictions['start_logits'],
|
||||
predictions['end_logits']):
|
||||
for values in zip(unique_ids.numpy(), start_logits.numpy(),
|
||||
end_logits.numpy()):
|
||||
yield RawResult(
|
||||
unique_id=values[0],
|
||||
start_logits=values[1].tolist(),
|
||||
end_logits=values[2].tolist())
|
||||
|
||||
|
||||
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
|
||||
is_training):
|
||||
"""Gets a closure to create a dataset.."""
|
||||
|
||||
def _dataset_fn(ctx=None):
|
||||
"""Returns tf.data.Dataset for distributed BERT pretraining."""
|
||||
batch_size = ctx.get_per_replica_batch_size(
|
||||
global_batch_size) if ctx else global_batch_size
|
||||
dataset = input_pipeline.create_squad_dataset(
|
||||
input_file_pattern,
|
||||
max_seq_length,
|
||||
batch_size,
|
||||
is_training=is_training,
|
||||
input_pipeline_context=ctx)
|
||||
return dataset
|
||||
|
||||
return _dataset_fn
|
||||
|
||||
|
||||
def predict_squad_customized(strategy,
|
||||
input_meta_data,
|
||||
bert_config,
|
||||
checkpoint_path,
|
||||
predict_tfrecord_path,
|
||||
num_steps):
|
||||
"""Make predictions using a Bert-based squad model."""
|
||||
predict_dataset_fn = get_dataset_fn(
|
||||
predict_tfrecord_path,
|
||||
input_meta_data['max_seq_length'],
|
||||
FLAGS.predict_batch_size,
|
||||
is_training=False)
|
||||
predict_iterator = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(
|
||||
predict_dataset_fn))
|
||||
|
||||
with strategy.scope():
|
||||
# Prediction always uses float32, even if training uses mixed precision.
|
||||
tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
squad_model, _ = bert_models.squad_model(
|
||||
bert_config,
|
||||
input_meta_data['max_seq_length'],
|
||||
hub_module_url=FLAGS.hub_module_url)
|
||||
|
||||
if checkpoint_path is None:
|
||||
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
|
||||
logging.info('Restoring checkpoints from %s', checkpoint_path)
|
||||
checkpoint = tf.train.Checkpoint(model=squad_model)
|
||||
checkpoint.restore(checkpoint_path).expect_partial()
|
||||
|
||||
@tf.function
|
||||
def predict_step(iterator):
|
||||
"""Predicts on distributed devices."""
|
||||
|
||||
def _replicated_step(inputs):
|
||||
"""Replicated prediction calculation."""
|
||||
x, _ = inputs
|
||||
unique_ids = x.pop('unique_ids')
|
||||
start_logits, end_logits = squad_model(x, training=False)
|
||||
return dict(
|
||||
unique_ids=unique_ids,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits)
|
||||
|
||||
outputs = strategy.run(_replicated_step, args=(next(iterator),))
|
||||
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
|
||||
|
||||
all_results = []
|
||||
for _ in range(num_steps):
|
||||
predictions = predict_step(predict_iterator)
|
||||
for result in get_raw_results(predictions):
|
||||
all_results.append(result)
|
||||
if len(all_results) % 100 == 0:
|
||||
logging.info('Made predictions for %d records.', len(all_results))
|
||||
return all_results
|
||||
|
||||
|
||||
def train_squad(strategy,
|
||||
input_meta_data,
|
||||
bert_config,
|
||||
custom_callbacks=None,
|
||||
run_eagerly=False,
|
||||
init_checkpoint=None):
|
||||
"""Run bert squad training."""
|
||||
if strategy:
|
||||
logging.info('Training using customized training loop with distribution'
|
||||
' strategy.')
|
||||
# Enables XLA in Session Config. Should not be set for TPU.
|
||||
keras_utils.set_config_v2(FLAGS.enable_xla)
|
||||
performance.set_mixed_precision_policy(common_flags.dtype())
|
||||
|
||||
epochs = FLAGS.num_train_epochs
|
||||
num_train_examples = input_meta_data['train_data_size']
|
||||
max_seq_length = input_meta_data['max_seq_length']
|
||||
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
|
||||
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
|
||||
train_input_fn = get_dataset_fn(
|
||||
FLAGS.train_data_path,
|
||||
max_seq_length,
|
||||
FLAGS.train_batch_size,
|
||||
is_training=True)
|
||||
|
||||
def _get_squad_model():
|
||||
"""Get Squad model and optimizer."""
|
||||
squad_model, core_model = bert_models.squad_model(
|
||||
bert_config,
|
||||
max_seq_length,
|
||||
hub_module_url=FLAGS.hub_module_url,
|
||||
hub_module_trainable=FLAGS.hub_module_trainable)
|
||||
optimizer = optimization.create_optimizer(FLAGS.learning_rate,
|
||||
steps_per_epoch * epochs,
|
||||
warmup_steps,
|
||||
FLAGS.end_lr,
|
||||
FLAGS.optimizer_type)
|
||||
|
||||
squad_model.optimizer = performance.configure_optimizer(
|
||||
optimizer,
|
||||
use_float16=common_flags.use_float16(),
|
||||
use_graph_rewrite=common_flags.use_graph_rewrite())
|
||||
return squad_model, core_model
|
||||
|
||||
# If explicit_allreduce = True, apply_gradients() no longer implicitly
|
||||
# allreduce gradients, users manually allreduce gradient and pass the
|
||||
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
|
||||
# applied to allreduced gradients.
|
||||
def clip_by_global_norm_callback(grads_and_vars):
|
||||
grads, variables = zip(*grads_and_vars)
|
||||
(clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
||||
return zip(clipped_grads, variables)
|
||||
|
||||
model_training_utils.run_customized_training_loop(
|
||||
strategy=strategy,
|
||||
model_fn=_get_squad_model,
|
||||
loss_fn=get_loss_fn(),
|
||||
model_dir=FLAGS.model_dir,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
steps_per_loop=FLAGS.steps_per_loop,
|
||||
epochs=epochs,
|
||||
train_input_fn=train_input_fn,
|
||||
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
|
||||
run_eagerly=run_eagerly,
|
||||
custom_callbacks=custom_callbacks,
|
||||
explicit_allreduce=False,
|
||||
post_allreduce_callbacks=[clip_by_global_norm_callback])
|
||||
|
||||
|
||||
def prediction_output_squad(
|
||||
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
|
||||
"""Makes predictions for a squad dataset."""
|
||||
doc_stride = input_meta_data['doc_stride']
|
||||
max_query_length = input_meta_data['max_query_length']
|
||||
# Whether data should be in Ver 2.0 format.
|
||||
version_2_with_negative = input_meta_data.get('version_2_with_negative',
|
||||
False)
|
||||
eval_examples = squad_lib.read_squad_examples(
|
||||
input_file=FLAGS.predict_file,
|
||||
is_training=False,
|
||||
version_2_with_negative=version_2_with_negative)
|
||||
|
||||
eval_writer = squad_lib.FeatureWriter(
|
||||
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
|
||||
is_training=False)
|
||||
eval_features = []
|
||||
|
||||
def _append_feature(feature, is_padding):
|
||||
if not is_padding:
|
||||
eval_features.append(feature)
|
||||
eval_writer.process_feature(feature)
|
||||
|
||||
# TPU requires a fixed batch size for all batches, therefore the number
|
||||
# of examples must be a multiple of the batch size, or else examples
|
||||
# will get dropped. So we pad with fake examples which are ignored
|
||||
# later on.
|
||||
kwargs = dict(
|
||||
examples=eval_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=input_meta_data['max_seq_length'],
|
||||
doc_stride=doc_stride,
|
||||
max_query_length=max_query_length,
|
||||
is_training=False,
|
||||
output_fn=_append_feature,
|
||||
batch_size=FLAGS.predict_batch_size)
|
||||
|
||||
# squad_lib_sp requires one more argument 'do_lower_case'.
|
||||
if squad_lib == squad_lib_sp:
|
||||
kwargs['do_lower_case'] = FLAGS.do_lower_case
|
||||
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
|
||||
eval_writer.close()
|
||||
|
||||
logging.info('***** Running predictions *****')
|
||||
logging.info(' Num orig examples = %d', len(eval_examples))
|
||||
logging.info(' Num split examples = %d', len(eval_features))
|
||||
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
|
||||
|
||||
num_steps = int(dataset_size / FLAGS.predict_batch_size)
|
||||
all_results = predict_squad_customized(
|
||||
strategy, input_meta_data, bert_config,
|
||||
checkpoint, eval_writer.filename, num_steps)
|
||||
|
||||
all_predictions, all_nbest_json, scores_diff_json = (
|
||||
squad_lib.postprocess_output(
|
||||
eval_examples,
|
||||
eval_features,
|
||||
all_results,
|
||||
FLAGS.n_best_size,
|
||||
FLAGS.max_answer_length,
|
||||
FLAGS.do_lower_case,
|
||||
version_2_with_negative=version_2_with_negative,
|
||||
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
|
||||
verbose=FLAGS.verbose_logging))
|
||||
|
||||
return all_predictions, all_nbest_json, scores_diff_json
|
||||
|
||||
|
||||
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
|
||||
squad_lib, version_2_with_negative):
|
||||
"""Save output to json files."""
|
||||
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
|
||||
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
|
||||
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
|
||||
logging.info('Writing predictions to: %s', (output_prediction_file))
|
||||
logging.info('Writing nbest to: %s', (output_nbest_file))
|
||||
|
||||
squad_lib.write_to_json_files(all_predictions, output_prediction_file)
|
||||
squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
|
||||
if version_2_with_negative:
|
||||
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
|
||||
|
||||
|
||||
def predict_squad(strategy,
|
||||
input_meta_data,
|
||||
tokenizer,
|
||||
bert_config,
|
||||
squad_lib,
|
||||
init_checkpoint=None):
|
||||
"""Get prediction results and evaluate them to hard drive."""
|
||||
if init_checkpoint is None:
|
||||
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
|
||||
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
|
||||
strategy, input_meta_data, tokenizer,
|
||||
bert_config, squad_lib, init_checkpoint)
|
||||
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
|
||||
input_meta_data.get('version_2_with_negative', False))
|
||||
|
||||
|
||||
def eval_squad(strategy,
|
||||
input_meta_data,
|
||||
tokenizer,
|
||||
bert_config,
|
||||
squad_lib,
|
||||
init_checkpoint=None):
|
||||
"""Get prediction results and evaluate them against ground truth."""
|
||||
if init_checkpoint is None:
|
||||
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
|
||||
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
|
||||
strategy, input_meta_data, tokenizer,
|
||||
bert_config, squad_lib, init_checkpoint)
|
||||
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
|
||||
input_meta_data.get('version_2_with_negative', False))
|
||||
|
||||
with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
|
||||
dataset_json = json.load(reader)
|
||||
pred_dataset = dataset_json['data']
|
||||
if input_meta_data.get('version_2_with_negative', False):
|
||||
eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
|
||||
all_predictions,
|
||||
scores_diff_json)
|
||||
else:
|
||||
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
|
||||
return eval_metrics
|
||||
|
||||
|
||||
def export_squad(model_export_path, input_meta_data, bert_config):
|
||||
"""Exports a trained model as a `SavedModel` for inference.
|
||||
|
||||
Args:
|
||||
model_export_path: a string specifying the path to the SavedModel directory.
|
||||
input_meta_data: dictionary containing meta data about input and model.
|
||||
bert_config: Bert configuration file to define core bert layers.
|
||||
|
||||
Raises:
|
||||
Export path is not specified, got an empty string or None.
|
||||
"""
|
||||
if not model_export_path:
|
||||
raise ValueError('Export path is not specified: %s' % model_export_path)
|
||||
# Export uses float32 for now, even if training uses mixed precision.
|
||||
tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
squad_model, _ = bert_models.squad_model(bert_config,
|
||||
input_meta_data['max_seq_length'])
|
||||
model_saving_utils.export_bert_model(
|
||||
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Evaluation of SQuAD predictions (version 1.1).
|
||||
|
||||
The functions are copied from
|
||||
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
|
||||
|
||||
The SQuAD dataset is described in this paper:
|
||||
SQuAD: 100,000+ Questions for Machine Comprehension of Text
|
||||
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
|
||||
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import string
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from absl import logging
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
|
||||
def _normalize_answer(s):
|
||||
"""Lowers text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def _f1_score(prediction, ground_truth):
|
||||
"""Computes F1 score by comparing prediction to ground truth."""
|
||||
prediction_tokens = _normalize_answer(prediction).split()
|
||||
ground_truth_tokens = _normalize_answer(ground_truth).split()
|
||||
prediction_counter = collections.Counter(prediction_tokens)
|
||||
ground_truth_counter = collections.Counter(ground_truth_tokens)
|
||||
common = prediction_counter & ground_truth_counter
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def _exact_match_score(prediction, ground_truth):
|
||||
"""Checks if predicted answer exactly matches ground truth answer."""
|
||||
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
"""Computes the max over all metric scores."""
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
|
||||
def evaluate(dataset, predictions):
|
||||
"""Evaluates predictions for a dataset."""
|
||||
f1 = exact_match = total = 0
|
||||
for article in dataset:
|
||||
for paragraph in article["paragraphs"]:
|
||||
for qa in paragraph["qas"]:
|
||||
total += 1
|
||||
if qa["id"] not in predictions:
|
||||
message = "Unanswered question " + qa["id"] + " will receive score 0."
|
||||
logging.error(message)
|
||||
continue
|
||||
ground_truths = [entry["text"] for entry in qa["answers"]]
|
||||
prediction = predictions[qa["id"]]
|
||||
exact_match += _metric_max_over_ground_truths(_exact_match_score,
|
||||
prediction, ground_truths)
|
||||
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
|
||||
ground_truths)
|
||||
|
||||
exact_match = exact_match / total
|
||||
f1 = f1 / total
|
||||
|
||||
return {"exact_match": exact_match, "final_f1": f1}
|
||||
+252
@@ -0,0 +1,252 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Evaluation script for SQuAD version 2.0.
|
||||
|
||||
The functions are copied and modified from
|
||||
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
|
||||
|
||||
In addition to basic functionality, we also compute additional statistics and
|
||||
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||
This file is expected to map question ID's to the model's predicted probability
|
||||
that a question is unanswerable.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import string
|
||||
|
||||
from absl import logging
|
||||
|
||||
|
||||
def _make_qid_to_has_ans(dataset):
|
||||
qid_to_has_ans = {}
|
||||
for article in dataset:
|
||||
for p in article['paragraphs']:
|
||||
for qa in p['qas']:
|
||||
qid_to_has_ans[qa['id']] = bool(qa['answers'])
|
||||
return qid_to_has_ans
|
||||
|
||||
|
||||
def _normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
||||
return re.sub(regex, ' ', text)
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def _get_tokens(s):
|
||||
if not s: return []
|
||||
return _normalize_answer(s).split()
|
||||
|
||||
|
||||
def _compute_exact(a_gold, a_pred):
|
||||
return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
|
||||
|
||||
|
||||
def _compute_f1(a_gold, a_pred):
|
||||
"""Compute F1-score."""
|
||||
gold_toks = _get_tokens(a_gold)
|
||||
pred_toks = _get_tokens(a_pred)
|
||||
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
||||
num_same = sum(common.values())
|
||||
if not gold_toks or not pred_toks:
|
||||
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
||||
return int(gold_toks == pred_toks)
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(pred_toks)
|
||||
recall = 1.0 * num_same / len(gold_toks)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def _get_raw_scores(dataset, predictions):
|
||||
"""Compute raw scores."""
|
||||
exact_scores = {}
|
||||
f1_scores = {}
|
||||
for article in dataset:
|
||||
for p in article['paragraphs']:
|
||||
for qa in p['qas']:
|
||||
qid = qa['id']
|
||||
gold_answers = [a['text'] for a in qa['answers']
|
||||
if _normalize_answer(a['text'])]
|
||||
if not gold_answers:
|
||||
# For unanswerable questions, only correct answer is empty string
|
||||
gold_answers = ['']
|
||||
if qid not in predictions:
|
||||
logging.error('Missing prediction for %s', qid)
|
||||
continue
|
||||
a_pred = predictions[qid]
|
||||
# Take max over all gold answers
|
||||
exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
|
||||
f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
|
||||
return exact_scores, f1_scores
|
||||
|
||||
|
||||
def _apply_no_ans_threshold(
|
||||
scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
|
||||
new_scores = {}
|
||||
for qid, s in scores.items():
|
||||
pred_na = na_probs[qid] > na_prob_thresh
|
||||
if pred_na:
|
||||
new_scores[qid] = float(not qid_to_has_ans[qid])
|
||||
else:
|
||||
new_scores[qid] = s
|
||||
return new_scores
|
||||
|
||||
|
||||
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||
"""Make evaluation result dictionary."""
|
||||
if not qid_list:
|
||||
total = len(exact_scores)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores.values()) / total),
|
||||
('f1', 100.0 * sum(f1_scores.values()) / total),
|
||||
('total', total),
|
||||
])
|
||||
else:
|
||||
total = len(qid_list)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||
('total', total),
|
||||
])
|
||||
|
||||
|
||||
def _merge_eval(main_eval, new_eval, prefix):
|
||||
for k in new_eval:
|
||||
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
||||
|
||||
|
||||
def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
|
||||
"""Make evaluation dictionary containing average recision recall."""
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
true_pos = 0.0
|
||||
cur_p = 1.0
|
||||
cur_r = 0.0
|
||||
precisions = [1.0]
|
||||
recalls = [0.0]
|
||||
avg_prec = 0.0
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid_to_has_ans[qid]:
|
||||
true_pos += scores[qid]
|
||||
cur_p = true_pos / float(i+1)
|
||||
cur_r = true_pos / float(num_true_pos)
|
||||
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
|
||||
# i.e., if we can put a threshold after this point
|
||||
avg_prec += cur_p * (cur_r - recalls[-1])
|
||||
precisions.append(cur_p)
|
||||
recalls.append(cur_r)
|
||||
return {'ap': 100.0 * avg_prec}
|
||||
|
||||
|
||||
def _run_precision_recall_analysis(
|
||||
main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
"""Run precision recall analysis and return result dictionary."""
|
||||
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
|
||||
if num_true_pos == 0:
|
||||
return
|
||||
pr_exact = _make_precision_recall_eval(
|
||||
exact_raw, na_probs, num_true_pos, qid_to_has_ans)
|
||||
pr_f1 = _make_precision_recall_eval(
|
||||
f1_raw, na_probs, num_true_pos, qid_to_has_ans)
|
||||
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
|
||||
pr_oracle = _make_precision_recall_eval(
|
||||
oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
|
||||
_merge_eval(main_eval, pr_exact, 'pr_exact')
|
||||
_merge_eval(main_eval, pr_f1, 'pr_f1')
|
||||
_merge_eval(main_eval, pr_oracle, 'pr_oracle')
|
||||
|
||||
|
||||
def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
|
||||
"""Find the best threshold for no answer probability."""
|
||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||
cur_score = num_no_ans
|
||||
best_score = cur_score
|
||||
best_thresh = 0.0
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
for qid in qid_list:
|
||||
if qid not in scores: continue
|
||||
if qid_to_has_ans[qid]:
|
||||
diff = scores[qid]
|
||||
else:
|
||||
if predictions[qid]:
|
||||
diff = -1
|
||||
else:
|
||||
diff = 0
|
||||
cur_score += diff
|
||||
if cur_score > best_score:
|
||||
best_score = cur_score
|
||||
best_thresh = na_probs[qid]
|
||||
return 100.0 * best_score / len(scores), best_thresh
|
||||
|
||||
|
||||
def _find_all_best_thresh(
|
||||
main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh = _find_best_thresh(
|
||||
predictions, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh = _find_best_thresh(
|
||||
predictions, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval['final_exact'] = best_exact
|
||||
main_eval['final_exact_thresh'] = exact_thresh
|
||||
main_eval['final_f1'] = best_f1
|
||||
main_eval['final_f1_thresh'] = f1_thresh
|
||||
|
||||
|
||||
def evaluate(dataset, predictions, na_probs=None):
|
||||
"""Evaluate prediction results."""
|
||||
new_orig_data = []
|
||||
for article in dataset:
|
||||
for p in article['paragraphs']:
|
||||
for qa in p['qas']:
|
||||
if qa['id'] in predictions:
|
||||
new_para = {'qas': [qa]}
|
||||
new_article = {'paragraphs': [new_para]}
|
||||
new_orig_data.append(new_article)
|
||||
dataset = new_orig_data
|
||||
|
||||
if na_probs is None:
|
||||
na_probs = {k: 0.0 for k in predictions}
|
||||
qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
|
||||
exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
|
||||
f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
|
||||
out_eval = _make_eval_dict(exact_thresh, f1_thresh)
|
||||
if has_ans_qids:
|
||||
has_ans_eval = _make_eval_dict(
|
||||
exact_thresh, f1_thresh, qid_list=has_ans_qids)
|
||||
_merge_eval(out_eval, has_ans_eval, 'HasAns')
|
||||
if no_ans_qids:
|
||||
no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
|
||||
_merge_eval(out_eval, no_ans_eval, 'NoAns')
|
||||
|
||||
_find_all_best_thresh(
|
||||
out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
|
||||
_run_precision_recall_analysis(
|
||||
out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
|
||||
return out_eval
|
||||
+195
@@ -0,0 +1,195 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf # TF 1.x
|
||||
|
||||
# Mapping between old <=> new names. The source pattern in original variable
|
||||
# name will be replaced by destination pattern.
|
||||
BERT_NAME_REPLACEMENTS = (
|
||||
("bert", "bert_model"),
|
||||
("embeddings/word_embeddings", "word_embeddings/embeddings"),
|
||||
("embeddings/token_type_embeddings",
|
||||
"embedding_postprocessor/type_embeddings"),
|
||||
("embeddings/position_embeddings",
|
||||
"embedding_postprocessor/position_embeddings"),
|
||||
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
|
||||
("attention/self", "self_attention"),
|
||||
("attention/output/dense", "self_attention_output"),
|
||||
("attention/output/LayerNorm", "self_attention_layer_norm"),
|
||||
("intermediate/dense", "intermediate"),
|
||||
("output/dense", "output"),
|
||||
("output/LayerNorm", "output_layer_norm"),
|
||||
("pooler/dense", "pooler_transform"),
|
||||
)
|
||||
|
||||
BERT_V2_NAME_REPLACEMENTS = (
|
||||
("bert/", ""),
|
||||
("encoder", "transformer"),
|
||||
("embeddings/word_embeddings", "word_embeddings/embeddings"),
|
||||
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
|
||||
("embeddings/position_embeddings", "position_embedding/embeddings"),
|
||||
("embeddings/LayerNorm", "embeddings/layer_norm"),
|
||||
("attention/self", "self_attention"),
|
||||
("attention/output/dense", "self_attention_output"),
|
||||
("attention/output/LayerNorm", "self_attention_layer_norm"),
|
||||
("intermediate/dense", "intermediate"),
|
||||
("output/dense", "output"),
|
||||
("output/LayerNorm", "output_layer_norm"),
|
||||
("pooler/dense", "pooler_transform"),
|
||||
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
|
||||
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
|
||||
("cls/seq_relationship/output_weights",
|
||||
"predictions/transform/logits/kernel"),
|
||||
)
|
||||
|
||||
BERT_PERMUTATIONS = ()
|
||||
|
||||
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
|
||||
|
||||
|
||||
def _bert_name_replacement(var_name, name_replacements):
|
||||
"""Gets the variable name replacement."""
|
||||
for src_pattern, tgt_pattern in name_replacements:
|
||||
if src_pattern in var_name:
|
||||
old_var_name = var_name
|
||||
var_name = var_name.replace(src_pattern, tgt_pattern)
|
||||
tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
|
||||
return var_name
|
||||
|
||||
|
||||
def _has_exclude_patterns(name, exclude_patterns):
|
||||
"""Checks if a string contains substrings that match patterns to exclude."""
|
||||
for p in exclude_patterns:
|
||||
if p in name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_permutation(name, permutations):
|
||||
"""Checks whether a variable requires transposition by pattern matching."""
|
||||
for src_pattern, permutation in permutations:
|
||||
if src_pattern in name:
|
||||
tf.logging.info("Permuted: %s --> %s", name, permutation)
|
||||
return permutation
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_new_shape(name, shape, num_heads):
|
||||
"""Checks whether a variable requires reshape by pattern matching."""
|
||||
if "self_attention_output/kernel" in name:
|
||||
return tuple([num_heads, shape[0] // num_heads, shape[1]])
|
||||
if "self_attention_output/bias" in name:
|
||||
return shape
|
||||
|
||||
patterns = [
|
||||
"self_attention/query", "self_attention/value", "self_attention/key"
|
||||
]
|
||||
for pattern in patterns:
|
||||
if pattern in name:
|
||||
if "kernel" in name:
|
||||
return tuple([shape[0], num_heads, shape[1] // num_heads])
|
||||
if "bias" in name:
|
||||
return tuple([num_heads, shape[0] // num_heads])
|
||||
return None
|
||||
|
||||
|
||||
def create_v2_checkpoint(model, src_checkpoint, output_path):
|
||||
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
|
||||
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
|
||||
model.load_weights(src_checkpoint).assert_existing_objects_matched()
|
||||
checkpoint = tf.train.Checkpoint(model=model)
|
||||
checkpoint.save(output_path)
|
||||
|
||||
|
||||
def convert(checkpoint_from_path,
|
||||
checkpoint_to_path,
|
||||
num_heads,
|
||||
name_replacements,
|
||||
permutations,
|
||||
exclude_patterns=None):
|
||||
"""Migrates the names of variables within a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_from_path: Path to source checkpoint to be read in.
|
||||
checkpoint_to_path: Path to checkpoint to be written out.
|
||||
num_heads: The number of heads of the model.
|
||||
name_replacements: A list of tuples of the form (match_str, replace_str)
|
||||
describing variable names to adjust.
|
||||
permutations: A list of tuples of the form (match_str, permutation)
|
||||
describing permutations to apply to given variables. Note that match_str
|
||||
should match the original variable name, not the replaced one.
|
||||
exclude_patterns: A list of string patterns to exclude variables from
|
||||
checkpoint conversion.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps the new variable names to the Variable objects.
|
||||
A dictionary that maps the old variable names to the new variable names.
|
||||
"""
|
||||
with tf.Graph().as_default():
|
||||
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
|
||||
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
|
||||
name_shape_map = reader.get_variable_to_shape_map()
|
||||
new_variable_map = {}
|
||||
conversion_map = {}
|
||||
for var_name in name_shape_map:
|
||||
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
|
||||
continue
|
||||
# Get the original tensor data.
|
||||
tensor = reader.get_tensor(var_name)
|
||||
|
||||
# Look up the new variable name, if any.
|
||||
new_var_name = _bert_name_replacement(var_name, name_replacements)
|
||||
|
||||
# See if we need to reshape the underlying tensor.
|
||||
new_shape = None
|
||||
if num_heads > 0:
|
||||
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
|
||||
if new_shape:
|
||||
tf.logging.info("Veriable %s has a shape change from %s to %s",
|
||||
|
||||
var_name, tensor.shape, new_shape)
|
||||
tensor = np.reshape(tensor, new_shape)
|
||||
|
||||
# See if we need to permute the underlying tensor.
|
||||
permutation = _get_permutation(var_name, permutations)
|
||||
if permutation:
|
||||
tensor = np.transpose(tensor, permutation)
|
||||
|
||||
# Create a new variable with the possibly-reshaped or transposed tensor.
|
||||
var = tf.Variable(tensor, name=var_name)
|
||||
|
||||
# Save the variable into the new variable map.
|
||||
new_variable_map[new_var_name] = var
|
||||
|
||||
# Keep a list of converter variables for sanity checking.
|
||||
if new_var_name != var_name:
|
||||
conversion_map[var_name] = new_var_name
|
||||
|
||||
saver = tf.train.Saver(new_variable_map)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
|
||||
saver.save(sess, checkpoint_to_path)
|
||||
|
||||
tf.logging.info("Summary:")
|
||||
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
|
||||
tf.logging.info(" Converted: %s", str(conversion_map))
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
|
||||
|
||||
The conversion will yield an object-oriented checkpoint that can be used
|
||||
to restore a TransformerEncoder object.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow as tf
|
||||
from official.modeling import activations
|
||||
from official.nlp.bert import configs
|
||||
from official.nlp.bert import tf1_checkpoint_converter_lib
|
||||
from official.nlp.modeling import networks
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("bert_config_file", None,
|
||||
"Bert configuration file to define core bert layers.")
|
||||
flags.DEFINE_string(
|
||||
"checkpoint_to_convert", None,
|
||||
"Initial checkpoint from a pretrained BERT model core (that is, only the "
|
||||
"BertModel, with no task heads.)")
|
||||
flags.DEFINE_string("converted_checkpoint_path", None,
|
||||
"Name for the created object-based V2 checkpoint.")
|
||||
|
||||
|
||||
def _create_bert_model(cfg):
|
||||
"""Creates a BERT keras core model from BERT configuration.
|
||||
|
||||
Args:
|
||||
cfg: A `BertConfig` to create the core model.
|
||||
Returns:
|
||||
A TransformerEncoder netowork.
|
||||
"""
|
||||
bert_encoder = networks.TransformerEncoder(
|
||||
vocab_size=cfg.vocab_size,
|
||||
hidden_size=cfg.hidden_size,
|
||||
num_layers=cfg.num_hidden_layers,
|
||||
num_attention_heads=cfg.num_attention_heads,
|
||||
intermediate_size=cfg.intermediate_size,
|
||||
activation=activations.gelu,
|
||||
dropout_rate=cfg.hidden_dropout_prob,
|
||||
attention_dropout_rate=cfg.attention_probs_dropout_prob,
|
||||
sequence_length=cfg.max_position_embeddings,
|
||||
type_vocab_size=cfg.type_vocab_size,
|
||||
initializer=tf.keras.initializers.TruncatedNormal(
|
||||
stddev=cfg.initializer_range))
|
||||
|
||||
return bert_encoder
|
||||
|
||||
|
||||
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
|
||||
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
|
||||
output_dir, _ = os.path.split(output_path)
|
||||
|
||||
# Create a temporary V1 name-converted checkpoint in the output directory.
|
||||
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
|
||||
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
|
||||
tf1_checkpoint_converter_lib.convert(
|
||||
checkpoint_from_path=v1_checkpoint,
|
||||
checkpoint_to_path=temporary_checkpoint,
|
||||
num_heads=bert_config.num_attention_heads,
|
||||
name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
|
||||
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
|
||||
exclude_patterns=["adam", "Adam"])
|
||||
|
||||
# Create a V2 checkpoint from the temporary checkpoint.
|
||||
model = _create_bert_model(bert_config)
|
||||
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
|
||||
output_path)
|
||||
|
||||
# Clean up the temporary checkpoint, if it exists.
|
||||
try:
|
||||
tf.io.gfile.rmtree(temporary_checkpoint_dir)
|
||||
except tf.errors.OpError:
|
||||
# If it doesn't exist, we don't need to clean it up; continue.
|
||||
pass
|
||||
|
||||
|
||||
def main(_):
|
||||
output_path = FLAGS.converted_checkpoint_path
|
||||
v1_checkpoint = FLAGS.checkpoint_to_convert
|
||||
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||
convert_checkpoint(bert_config, output_path, v1_checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tokenization classes implementation.
|
||||
|
||||
The file is forked from:
|
||||
https://github.com/google-research/bert/blob/master/tokenization.py.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." %
|
||||
(actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with tf.io.gfile.GFile(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(
|
||||
do_lower_case=do_lower_case, split_on_punc=split_on_punc)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True, split_on_punc=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
split_on_punc: Whether to apply split on punctuations. By default BERT
|
||||
starts a new token for punctuations. This makes detokenization difficult
|
||||
for tasks like seq2seq decoding.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
self.split_on_punc = split_on_punc
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
if self.split_on_punc:
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
else:
|
||||
split_tokens.append(token)
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def preprocess_text(inputs, remove_space=True, lower=False):
|
||||
"""Preprocesses data by removing extra space and normalize data.
|
||||
|
||||
This method is used together with sentence piece tokenizer and is forked from:
|
||||
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
|
||||
|
||||
Args:
|
||||
inputs: The input text.
|
||||
remove_space: Whether to remove the extra space.
|
||||
lower: Whether to lowercase the text.
|
||||
|
||||
Returns:
|
||||
The preprocessed text.
|
||||
|
||||
"""
|
||||
outputs = inputs
|
||||
if remove_space:
|
||||
outputs = " ".join(inputs.strip().split())
|
||||
|
||||
if six.PY2 and isinstance(outputs, str):
|
||||
try:
|
||||
outputs = six.ensure_text(outputs, "utf-8")
|
||||
except UnicodeDecodeError:
|
||||
outputs = six.ensure_text(outputs, "latin-1")
|
||||
|
||||
outputs = unicodedata.normalize("NFKD", outputs)
|
||||
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
|
||||
if lower:
|
||||
outputs = outputs.lower()
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def encode_pieces(sp_model, text, sample=False):
|
||||
"""Segements text into pieces.
|
||||
|
||||
This method is used together with sentence piece tokenizer and is forked from:
|
||||
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
|
||||
|
||||
|
||||
Args:
|
||||
sp_model: A spm.SentencePieceProcessor object.
|
||||
text: The input text to be segemented.
|
||||
sample: Whether to randomly sample a segmentation output or return a
|
||||
deterministic one.
|
||||
|
||||
Returns:
|
||||
A list of token pieces.
|
||||
"""
|
||||
if six.PY2 and isinstance(text, six.text_type):
|
||||
text = six.ensure_binary(text, "utf-8")
|
||||
|
||||
if not sample:
|
||||
pieces = sp_model.EncodeAsPieces(text)
|
||||
else:
|
||||
pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
|
||||
new_pieces = []
|
||||
for piece in pieces:
|
||||
piece = printable_text(piece)
|
||||
if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
|
||||
cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
|
||||
SPIECE_UNDERLINE, ""))
|
||||
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
|
||||
if len(cur_pieces[0]) == 1:
|
||||
cur_pieces = cur_pieces[1:]
|
||||
else:
|
||||
cur_pieces[0] = cur_pieces[0][1:]
|
||||
cur_pieces.append(piece[-1])
|
||||
new_pieces.extend(cur_pieces)
|
||||
else:
|
||||
new_pieces.append(piece)
|
||||
|
||||
return new_pieces
|
||||
|
||||
|
||||
def encode_ids(sp_model, text, sample=False):
|
||||
"""Segments text and return token ids.
|
||||
|
||||
This method is used together with sentence piece tokenizer and is forked from:
|
||||
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
|
||||
|
||||
Args:
|
||||
sp_model: A spm.SentencePieceProcessor object.
|
||||
text: The input text to be segemented.
|
||||
sample: Whether to randomly sample a segmentation output or return a
|
||||
deterministic one.
|
||||
|
||||
Returns:
|
||||
A list of token ids.
|
||||
"""
|
||||
pieces = encode_pieces(sp_model, text, sample=sample)
|
||||
ids = [sp_model.PieceToId(piece) for piece in pieces]
|
||||
return ids
|
||||
|
||||
|
||||
class FullSentencePieceTokenizer(object):
|
||||
"""Runs end-to-end sentence piece tokenization.
|
||||
|
||||
The interface of this class is intended to keep the same as above
|
||||
`FullTokenizer` class for easier usage.
|
||||
"""
|
||||
|
||||
def __init__(self, sp_model_file):
|
||||
"""Inits FullSentencePieceTokenizer.
|
||||
|
||||
Args:
|
||||
sp_model_file: The path to the sentence piece model file.
|
||||
"""
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(sp_model_file)
|
||||
self.vocab = {
|
||||
self.sp_model.IdToPiece(i): i
|
||||
for i in six.moves.range(self.sp_model.GetPieceSize())
|
||||
}
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes text into pieces."""
|
||||
return encode_pieces(self.sp_model, text)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Converts a list of tokens to a list of ids."""
|
||||
return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""Converts a list of ids ot a list of tokens."""
|
||||
return [self.sp_model.IdToPiece(id_) for id_ in ids]
|
||||
+160
@@ -0,0 +1,160 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.bert import tokenization
|
||||
|
||||
|
||||
class TokenizationTest(tf.test.TestCase):
|
||||
"""Tokenization test.
|
||||
|
||||
The implementation is forked from
|
||||
https://github.com/google-research/bert/blob/master/tokenization_test.py."
|
||||
"""
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
||||
if six.PY2:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
else:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
|
||||
]).encode("utf-8"))
|
||||
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file)
|
||||
os.unlink(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = tokenization.BasicTokenizer()
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
|
||||
[u"ah", u"\u535A", u"\u63A8", u"zz"])
|
||||
|
||||
def test_basic_tokenizer_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["hello", "!", "how", "are", "you", "?"])
|
||||
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||
|
||||
def test_basic_tokenizer_no_lower(self):
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
|
||||
def test_basic_tokenizer_no_split_on_punc(self):
|
||||
tokenizer = tokenization.BasicTokenizer(
|
||||
do_lower_case=True, split_on_punc=False)
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["hello!how", "are", "you?"])
|
||||
|
||||
def test_wordpiece_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", "##!", "!"
|
||||
]
|
||||
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
||||
|
||||
self.assertAllEqual(tokenizer.tokenize(""), [])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwanted running"),
|
||||
["un", "##want", "##ed", "runn", "##ing"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwanted running !"),
|
||||
["un", "##want", "##ed", "runn", "##ing", "!"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwanted running!"),
|
||||
["un", "##want", "##ed", "runn", "##ing", "##!"])
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||
|
||||
def test_convert_tokens_to_ids(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
|
||||
self.assertAllEqual(
|
||||
tokenization.convert_tokens_to_ids(
|
||||
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
||||
|
||||
def test_is_whitespace(self):
|
||||
self.assertTrue(tokenization._is_whitespace(u" "))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
||||
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
||||
|
||||
self.assertFalse(tokenization._is_whitespace(u"A"))
|
||||
self.assertFalse(tokenization._is_whitespace(u"-"))
|
||||
|
||||
def test_is_control(self):
|
||||
self.assertTrue(tokenization._is_control(u"\u0005"))
|
||||
|
||||
self.assertFalse(tokenization._is_control(u"A"))
|
||||
self.assertFalse(tokenization._is_control(u" "))
|
||||
self.assertFalse(tokenization._is_control(u"\t"))
|
||||
self.assertFalse(tokenization._is_control(u"\r"))
|
||||
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
|
||||
|
||||
def test_is_punctuation(self):
|
||||
self.assertTrue(tokenization._is_punctuation(u"-"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"$"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"`"))
|
||||
self.assertTrue(tokenization._is_punctuation(u"."))
|
||||
|
||||
self.assertFalse(tokenization._is_punctuation(u"A"))
|
||||
self.assertFalse(tokenization._is_punctuation(u" "))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+676
@@ -0,0 +1,676 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""BERT library to process data for classification task."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import csv
|
||||
import os
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from official.nlp.bert import tokenization
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for simple sequence classification."""
|
||||
|
||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.guid = guid
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
self.label = label
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
label_id,
|
||||
is_real_example=True):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_id = label_id
|
||||
self.is_real_example = is_real_example
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
"""Base class for data converters for sequence classification data sets."""
|
||||
|
||||
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
||||
self.process_text_fn = process_text_fn
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for prediction."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""Gets the string identifier of the processor."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _read_tsv(cls, input_file, quotechar=None):
|
||||
"""Reads a tab separated value file."""
|
||||
with tf.io.gfile.GFile(input_file, "r") as f:
|
||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||
lines = []
|
||||
for line in reader:
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
class XnliProcessor(DataProcessor):
|
||||
"""Processor for the XNLI data set."""
|
||||
|
||||
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
||||
super(XnliProcessor, self).__init__(process_text_fn)
|
||||
self.language = "zh"
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
lines = self._read_tsv(
|
||||
os.path.join(data_dir, "multinli",
|
||||
"multinli.train.%s.tsv" % self.language))
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "train-%d" % (i)
|
||||
text_a = self.process_text_fn(line[0])
|
||||
text_b = self.process_text_fn(line[1])
|
||||
label = self.process_text_fn(line[2])
|
||||
if label == self.process_text_fn("contradictory"):
|
||||
label = self.process_text_fn("contradiction")
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "dev-%d" % (i)
|
||||
language = self.process_text_fn(line[0])
|
||||
if language != self.process_text_fn(self.language):
|
||||
continue
|
||||
text_a = self.process_text_fn(line[6])
|
||||
text_b = self.process_text_fn(line[7])
|
||||
label = self.process_text_fn(line[1])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "XNLI"
|
||||
|
||||
|
||||
class MnliProcessor(DataProcessor):
|
||||
"""Processor for the MultiNLI data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
|
||||
"dev_matched")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "MNLI"
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
|
||||
text_a = self.process_text_fn(line[8])
|
||||
text_b = self.process_text_fn(line[9])
|
||||
if set_type == "test":
|
||||
label = "contradiction"
|
||||
else:
|
||||
label = self.process_text_fn(line[-1])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class MrpcProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "MRPC"
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = self.process_text_fn(line[3])
|
||||
text_b = self.process_text_fn(line[4])
|
||||
if set_type == "test":
|
||||
label = "0"
|
||||
else:
|
||||
label = self.process_text_fn(line[0])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class ColaProcessor(DataProcessor):
|
||||
"""Processor for the CoLA data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "COLA"
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
# Only the test set has a header
|
||||
if set_type == "test" and i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
if set_type == "test":
|
||||
text_a = self.process_text_fn(line[1])
|
||||
label = "0"
|
||||
else:
|
||||
text_a = self.process_text_fn(line[3])
|
||||
label = self.process_text_fn(line[1])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class SstProcessor(DataProcessor):
|
||||
"""Processor for the SST-2 data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "SST-2"
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
if set_type == "test":
|
||||
text_a = tokenization.convert_to_unicode(line[1])
|
||||
label = "0"
|
||||
else:
|
||||
text_a = tokenization.convert_to_unicode(line[0])
|
||||
label = tokenization.convert_to_unicode(line[1])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class QnliProcessor(DataProcessor):
|
||||
"""Processor for the QNLI data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
@staticmethod
|
||||
def get_processor_name():
|
||||
"""See base class."""
|
||||
return "QNLI"
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, 1)
|
||||
if set_type == "test":
|
||||
text_a = tokenization.convert_to_unicode(line[1])
|
||||
text_b = tokenization.convert_to_unicode(line[2])
|
||||
label = "entailment"
|
||||
else:
|
||||
text_a = tokenization.convert_to_unicode(line[1])
|
||||
text_b = tokenization.convert_to_unicode(line[2])
|
||||
label = tokenization.convert_to_unicode(line[-1])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class TfdsProcessor(DataProcessor):
|
||||
"""Processor for generic text classification TFDS data set.
|
||||
|
||||
The TFDS parameters are expected to be provided in the tfds_params string, in
|
||||
a comma-separated list of parameter assignments.
|
||||
Examples:
|
||||
tfds_params="dataset=scicite,text_key=string"
|
||||
tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
|
||||
tfds_params="dataset=glue/cola,text_key=sentence"
|
||||
tfds_params="dataset=glue/sst2,text_key=sentence"
|
||||
tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
|
||||
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
|
||||
Possible parameters (please refer to the documentation of Tensorflow Datasets
|
||||
(TFDS) for the meaning of individual parameters):
|
||||
dataset: Required dataset name (potentially with subset and version number).
|
||||
data_dir: Optional TFDS source root directory.
|
||||
train_split: Name of the train split (defaults to `train`).
|
||||
dev_split: Name of the dev split (defaults to `validation`).
|
||||
test_split: Name of the test split (defaults to `test`).
|
||||
text_key: Key of the text_a feature (defaults to `text`).
|
||||
text_b_key: Key of the second text feature if available.
|
||||
label_key: Key of the label feature (defaults to `label`).
|
||||
test_text_key: Key of the text feature to use in test set.
|
||||
test_text_b_key: Key of the second text feature to use in test set.
|
||||
test_label: String to be used as the label for all test examples.
|
||||
"""
|
||||
|
||||
def __init__(self, tfds_params,
|
||||
process_text_fn=tokenization.convert_to_unicode):
|
||||
super(TfdsProcessor, self).__init__(process_text_fn)
|
||||
self._process_tfds_params_str(tfds_params)
|
||||
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
|
||||
with_info=True)
|
||||
self._labels = list(range(info.features[self.label_key].num_classes))
|
||||
|
||||
def _process_tfds_params_str(self, params_str):
|
||||
"""Extracts TFDS parameters from a comma-separated assignements string."""
|
||||
tuples = [x.split("=") for x in params_str.split(",")]
|
||||
d = {k.strip(): v.strip() for k, v in tuples}
|
||||
self.dataset_name = d["dataset"] # Required.
|
||||
self.data_dir = d.get("data_dir", None)
|
||||
self.train_split = d.get("train_split", "train")
|
||||
self.dev_split = d.get("dev_split", "validation")
|
||||
self.test_split = d.get("test_split", "test")
|
||||
self.text_key = d.get("text_key", "text")
|
||||
self.text_b_key = d.get("text_b_key", None)
|
||||
self.label_key = d.get("label_key", "label")
|
||||
self.test_text_key = d.get("test_text_key", self.text_key)
|
||||
self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
|
||||
self.test_label = d.get("test_label", "test_example")
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
assert data_dir is None
|
||||
return self._create_examples(self.train_split, "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
assert data_dir is None
|
||||
return self._create_examples(self.dev_split, "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
assert data_dir is None
|
||||
return self._create_examples(self.test_split, "test")
|
||||
|
||||
def get_labels(self):
|
||||
return self._labels
|
||||
|
||||
def get_processor_name(self):
|
||||
return "TFDS_" + self.dataset_name
|
||||
|
||||
def _create_examples(self, split_name, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if split_name not in self.dataset:
|
||||
raise ValueError("Split {} not available.".format(split_name))
|
||||
dataset = self.dataset[split_name].as_numpy_iterator()
|
||||
examples = []
|
||||
text_b = None
|
||||
for i, example in enumerate(dataset):
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
if set_type == "test":
|
||||
text_a = self.process_text_fn(example[self.test_text_key])
|
||||
if self.test_text_b_key:
|
||||
text_b = self.process_text_fn(example[self.test_text_b_key])
|
||||
label = self.test_label
|
||||
else:
|
||||
text_a = self.process_text_fn(example[self.text_key])
|
||||
if self.text_b_key:
|
||||
text_b = self.process_text_fn(example[self.text_b_key])
|
||||
label = int(example[self.label_key])
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_single_example(ex_index, example, label_list, max_seq_length,
|
||||
tokenizer):
|
||||
"""Converts a single `InputExample` into a single `InputFeatures`."""
|
||||
label_map = {}
|
||||
for (i, label) in enumerate(label_list):
|
||||
label_map[label] = i
|
||||
|
||||
tokens_a = tokenizer.tokenize(example.text_a)
|
||||
tokens_b = None
|
||||
if example.text_b:
|
||||
tokens_b = tokenizer.tokenize(example.text_b)
|
||||
|
||||
if tokens_b:
|
||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||
# length is less than the specified length.
|
||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
||||
else:
|
||||
# Account for [CLS] and [SEP] with "- 2"
|
||||
if len(tokens_a) > max_seq_length - 2:
|
||||
tokens_a = tokens_a[0:(max_seq_length - 2)]
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||
# (b) For single sequences:
|
||||
# tokens: [CLS] the dog is hairy . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0
|
||||
#
|
||||
# Where "type_ids" are used to indicate whether this is the first
|
||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||
# embedding vector (and position vector). This is not *strictly* necessary
|
||||
# since the [SEP] token unambiguously separates the sequences, but it makes
|
||||
# it easier for the model to learn the concept of sequences.
|
||||
#
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
if tokens_b:
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
label_id = label_map[example.label]
|
||||
if ex_index < 5:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("guid: %s", (example.guid))
|
||||
logging.info("tokens: %s",
|
||||
" ".join([tokenization.printable_text(x) for x in tokens]))
|
||||
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
||||
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
||||
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
||||
logging.info("label: %s (id = %d)", example.label, label_id)
|
||||
|
||||
feature = InputFeatures(
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
label_id=label_id,
|
||||
is_real_example=True)
|
||||
return feature
|
||||
|
||||
|
||||
def file_based_convert_examples_to_features(examples, label_list,
|
||||
max_seq_length, tokenizer,
|
||||
output_file):
|
||||
"""Convert a set of `InputExample`s to a TFRecord file."""
|
||||
|
||||
tf.io.gfile.makedirs(os.path.dirname(output_file))
|
||||
writer = tf.io.TFRecordWriter(output_file)
|
||||
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
if ex_index % 10000 == 0:
|
||||
logging.info("Writing example %d of %d", ex_index, len(examples))
|
||||
|
||||
feature = convert_single_example(ex_index, example, label_list,
|
||||
max_seq_length, tokenizer)
|
||||
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return f
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(feature.input_ids)
|
||||
features["input_mask"] = create_int_feature(feature.input_mask)
|
||||
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||
features["label_ids"] = create_int_feature([feature.label_id])
|
||||
features["is_real_example"] = create_int_feature(
|
||||
[int(feature.is_real_example)])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
writer.write(tf_example.SerializeToString())
|
||||
writer.close()
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def generate_tf_record_from_data_file(processor,
|
||||
data_dir,
|
||||
tokenizer,
|
||||
train_data_output_path=None,
|
||||
eval_data_output_path=None,
|
||||
max_seq_length=128):
|
||||
"""Generates and saves training data into a tf record file.
|
||||
|
||||
Arguments:
|
||||
processor: Input processor object to be used for generating data. Subclass
|
||||
of `DataProcessor`.
|
||||
data_dir: Directory that contains train/eval data to process. Data files
|
||||
should be in from "dev.tsv", "test.tsv", or "train.tsv".
|
||||
tokenizer: The tokenizer to be applied on the data.
|
||||
train_data_output_path: Output to which processed tf record for training
|
||||
will be saved.
|
||||
eval_data_output_path: Output to which processed tf record for evaluation
|
||||
will be saved.
|
||||
max_seq_length: Maximum sequence length of the to be generated
|
||||
training/eval data.
|
||||
|
||||
Returns:
|
||||
A dictionary containing input meta data.
|
||||
"""
|
||||
assert train_data_output_path or eval_data_output_path
|
||||
|
||||
label_list = processor.get_labels()
|
||||
assert train_data_output_path
|
||||
train_input_data_examples = processor.get_train_examples(data_dir)
|
||||
file_based_convert_examples_to_features(train_input_data_examples, label_list,
|
||||
max_seq_length, tokenizer,
|
||||
train_data_output_path)
|
||||
num_training_data = len(train_input_data_examples)
|
||||
|
||||
if eval_data_output_path:
|
||||
eval_input_data_examples = processor.get_dev_examples(data_dir)
|
||||
file_based_convert_examples_to_features(eval_input_data_examples,
|
||||
label_list, max_seq_length,
|
||||
tokenizer, eval_data_output_path)
|
||||
|
||||
meta_data = {
|
||||
"task_type": "bert_classification",
|
||||
"processor_type": processor.get_processor_name(),
|
||||
"num_labels": len(processor.get_labels()),
|
||||
"train_data_size": num_training_data,
|
||||
"max_seq_length": max_seq_length,
|
||||
}
|
||||
|
||||
if eval_data_output_path:
|
||||
meta_data["eval_data_size"] = len(eval_input_data_examples)
|
||||
|
||||
return meta_data
|
||||
+203
@@ -0,0 +1,203 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""BERT finetuning task dataset generator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
from official.nlp.bert import tokenization
|
||||
from official.nlp.data import classifier_data_lib
|
||||
# word-piece tokenizer based squad_lib
|
||||
from official.nlp.data import squad_lib as squad_lib_wp
|
||||
# sentence-piece tokenizer based squad_lib
|
||||
from official.nlp.data import squad_lib_sp
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_enum(
|
||||
"fine_tuning_task_type", "classification", ["classification", "squad"],
|
||||
"The name of the BERT fine tuning task for which data "
|
||||
"will be generated..")
|
||||
|
||||
# BERT classification specific flags.
|
||||
flags.DEFINE_string(
|
||||
"input_data_dir", None,
|
||||
"The input data dir. Should contain the .tsv files (or other data files) "
|
||||
"for the task.")
|
||||
|
||||
flags.DEFINE_enum("classification_task_name", "MNLI",
|
||||
["COLA", "MNLI", "MRPC", "QNLI", "SST-2", "XNLI"],
|
||||
"The name of the task to train BERT classifier.")
|
||||
|
||||
# BERT Squad task specific flags.
|
||||
flags.DEFINE_string(
|
||||
"squad_data_file", None,
|
||||
"The input data file in for generating training data for BERT squad task.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"doc_stride", 128,
|
||||
"When splitting up a long document into chunks, how much stride to "
|
||||
"take between chunks.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_query_length", 64,
|
||||
"The maximum number of tokens for the question. Questions longer than "
|
||||
"this will be truncated to this length.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"version_2_with_negative", False,
|
||||
"If true, the SQuAD examples contain some that do not have an answer.")
|
||||
|
||||
# Shared flags across BERT fine-tuning tasks.
|
||||
flags.DEFINE_string("vocab_file", None,
|
||||
"The vocabulary file that the BERT model was trained on.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"train_data_output_path", None,
|
||||
"The path in which generated training input data will be written as tf"
|
||||
" records.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"eval_data_output_path", None,
|
||||
"The path in which generated training input data will be written as tf"
|
||||
" records.")
|
||||
|
||||
flags.DEFINE_string("meta_data_file_path", None,
|
||||
"The path in which input meta data will be written.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_lower_case", True,
|
||||
"Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_seq_length", 128,
|
||||
"The maximum total input sequence length after WordPiece tokenization. "
|
||||
"Sequences longer than this will be truncated, and sequences shorter "
|
||||
"than this will be padded.")
|
||||
|
||||
flags.DEFINE_string("sp_model_file", "",
|
||||
"The path to the model used by sentence piece tokenizer.")
|
||||
|
||||
flags.DEFINE_enum(
|
||||
"tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
|
||||
"Specifies the tokenizer implementation, i.e., whehter to use word_piece "
|
||||
"or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
|
||||
"while ALBERT uses sentence_piece tokenizer.")
|
||||
|
||||
flags.DEFINE_string("tfds_params", "",
|
||||
"Comma-separated list of TFDS parameter assigments for "
|
||||
"generic classfication data import (for more details "
|
||||
"see the TfdsProcessor class documentation).")
|
||||
|
||||
|
||||
def generate_classifier_dataset():
|
||||
"""Generates classifier dataset and returns input meta data."""
|
||||
assert (FLAGS.input_data_dir and FLAGS.classification_task_name
|
||||
or FLAGS.tfds_params)
|
||||
|
||||
if FLAGS.tokenizer_impl == "word_piece":
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
processor_text_fn = tokenization.convert_to_unicode
|
||||
else:
|
||||
assert FLAGS.tokenizer_impl == "sentence_piece"
|
||||
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
|
||||
processor_text_fn = functools.partial(
|
||||
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
|
||||
|
||||
if FLAGS.tfds_params:
|
||||
processor = classifier_data_lib.TfdsProcessor(
|
||||
tfds_params=FLAGS.tfds_params,
|
||||
process_text_fn=processor_text_fn)
|
||||
return classifier_data_lib.generate_tf_record_from_data_file(
|
||||
processor,
|
||||
None,
|
||||
tokenizer,
|
||||
train_data_output_path=FLAGS.train_data_output_path,
|
||||
eval_data_output_path=FLAGS.eval_data_output_path,
|
||||
max_seq_length=FLAGS.max_seq_length)
|
||||
else:
|
||||
processors = {
|
||||
"cola": classifier_data_lib.ColaProcessor,
|
||||
"mnli": classifier_data_lib.MnliProcessor,
|
||||
"mrpc": classifier_data_lib.MrpcProcessor,
|
||||
"qnli": classifier_data_lib.QnliProcessor,
|
||||
"sst-2": classifier_data_lib.SstProcessor,
|
||||
"xnli": classifier_data_lib.XnliProcessor,
|
||||
}
|
||||
task_name = FLAGS.classification_task_name.lower()
|
||||
if task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name](processor_text_fn)
|
||||
return classifier_data_lib.generate_tf_record_from_data_file(
|
||||
processor,
|
||||
FLAGS.input_data_dir,
|
||||
tokenizer,
|
||||
train_data_output_path=FLAGS.train_data_output_path,
|
||||
eval_data_output_path=FLAGS.eval_data_output_path,
|
||||
max_seq_length=FLAGS.max_seq_length)
|
||||
|
||||
|
||||
def generate_squad_dataset():
|
||||
"""Generates squad training dataset and returns input meta data."""
|
||||
assert FLAGS.squad_data_file
|
||||
if FLAGS.tokenizer_impl == "word_piece":
|
||||
return squad_lib_wp.generate_tf_record_from_json_file(
|
||||
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
|
||||
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
|
||||
FLAGS.doc_stride, FLAGS.version_2_with_negative)
|
||||
else:
|
||||
assert FLAGS.tokenizer_impl == "sentence_piece"
|
||||
return squad_lib_sp.generate_tf_record_from_json_file(
|
||||
FLAGS.squad_data_file, FLAGS.sp_model_file,
|
||||
FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
|
||||
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.tokenizer_impl == "word_piece":
|
||||
if not FLAGS.vocab_file:
|
||||
raise ValueError(
|
||||
"FLAG vocab_file for word-piece tokenizer is not specified.")
|
||||
else:
|
||||
assert FLAGS.tokenizer_impl == "sentence_piece"
|
||||
if not FLAGS.sp_model_file:
|
||||
raise ValueError(
|
||||
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
|
||||
|
||||
if FLAGS.fine_tuning_task_type == "classification":
|
||||
input_meta_data = generate_classifier_dataset()
|
||||
else:
|
||||
input_meta_data = generate_squad_dataset()
|
||||
|
||||
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
|
||||
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
|
||||
writer.write(json.dumps(input_meta_data, indent=4) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("train_data_output_path")
|
||||
flags.mark_flag_as_required("meta_data_file_path")
|
||||
app.run(main)
|
||||
+486
@@ -0,0 +1,486 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.bert import tokenization
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("input_file", None,
|
||||
"Input raw text file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"output_file", None,
|
||||
"Output TF example file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string("vocab_file", None,
|
||||
"The vocabulary file that the BERT model was trained on.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_lower_case", True,
|
||||
"Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_whole_word_mask", False,
|
||||
"Whether to use whole word masking rather than per-WordPiece masking.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"gzip_compress", False,
|
||||
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
|
||||
|
||||
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
||||
|
||||
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
||||
"Maximum number of masked LM predictions per sequence.")
|
||||
|
||||
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"dupe_factor", 10,
|
||||
"Number of times to duplicate the input data (with different masks).")
|
||||
|
||||
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"short_seq_prob", 0.1,
|
||||
"Probability of creating sequences which are shorter than the "
|
||||
"maximum length.")
|
||||
|
||||
|
||||
class TrainingInstance(object):
|
||||
"""A single training instance (sentence pair)."""
|
||||
|
||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||
is_random_next):
|
||||
self.tokens = tokens
|
||||
self.segment_ids = segment_ids
|
||||
self.is_random_next = is_random_next
|
||||
self.masked_lm_positions = masked_lm_positions
|
||||
self.masked_lm_labels = masked_lm_labels
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.tokens]))
|
||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||
s += "is_random_next: %s\n" % self.is_random_next
|
||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||
[str(x) for x in self.masked_lm_positions]))
|
||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||
max_predictions_per_seq, output_files,
|
||||
gzip_compress):
|
||||
"""Create TF example files from `TrainingInstance`s."""
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(
|
||||
tf.io.TFRecordWriter(
|
||||
output_file, options="GZIP" if gzip_compress else ""))
|
||||
|
||||
writer_index = 0
|
||||
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = list(instance.segment_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
masked_lm_positions = list(instance.masked_lm_positions)
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
|
||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(input_ids)
|
||||
features["input_mask"] = create_int_feature(input_mask)
|
||||
features["segment_ids"] = create_int_feature(segment_ids)
|
||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
|
||||
writers[writer_index].write(tf_example.SerializeToString())
|
||||
writer_index = (writer_index + 1) % len(writers)
|
||||
|
||||
total_written += 1
|
||||
|
||||
if inst_index < 20:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("tokens: %s", " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
values = []
|
||||
if feature.int64_list.value:
|
||||
values = feature.int64_list.value
|
||||
elif feature.float_list.value:
|
||||
values = feature.float_list.value
|
||||
logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_float_feature(values):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_training_instances(input_files,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
dupe_factor,
|
||||
short_seq_prob,
|
||||
masked_lm_prob,
|
||||
max_predictions_per_seq,
|
||||
rng,
|
||||
do_whole_word_mask=False):
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
|
||||
# Input file format:
|
||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||
# sentence boundaries for the "next sentence prediction" task).
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
with tf.io.gfile.GFile(input_file, "rb") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
|
||||
# Empty lines are used as document delimiters
|
||||
if not line:
|
||||
all_documents.append([])
|
||||
tokens = tokenizer.tokenize(line)
|
||||
if tokens:
|
||||
all_documents[-1].append(tokens)
|
||||
|
||||
# Remove empty documents
|
||||
all_documents = [x for x in all_documents if x]
|
||||
rng.shuffle(all_documents)
|
||||
|
||||
vocab_words = list(tokenizer.vocab.keys())
|
||||
instances = []
|
||||
for _ in range(dupe_factor):
|
||||
for document_index in range(len(all_documents)):
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
||||
do_whole_word_mask))
|
||||
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
|
||||
|
||||
def create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
||||
do_whole_word_mask=False):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
|
||||
# Account for [CLS], [SEP], [SEP]
|
||||
max_num_tokens = max_seq_length - 3
|
||||
|
||||
# We *usually* want to fill up the entire sequence since we are padding
|
||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||
# computation. However, we *sometimes*
|
||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||
# The `target_seq_length` is just a rough target however, whereas
|
||||
# `max_seq_length` is a hard limit.
|
||||
target_seq_length = max_num_tokens
|
||||
if rng.random() < short_seq_prob:
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
|
||||
# We DON'T just concatenate all of the tokens from a document into a long
|
||||
# sequence and choose an arbitrary split point because this would make the
|
||||
# next sentence prediction task too easy. Instead, we split the input into
|
||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||
# input.
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(document):
|
||||
segment = document[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||
# (first) sentence.
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
|
||||
tokens_b = []
|
||||
# Random next
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
|
||||
# This should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document
|
||||
# we're processing.
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||
if random_document_index != document_index:
|
||||
break
|
||||
|
||||
random_document = all_documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
|
||||
# We didn't actually use these segments so we "put them back" so
|
||||
# they don't go to waste.
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
# Actual next
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
||||
do_whole_word_mask)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
is_random_next=is_random_next,
|
||||
masked_lm_positions=masked_lm_positions,
|
||||
masked_lm_labels=masked_lm_labels)
|
||||
instances.append(instance)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
|
||||
return instances
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||
["index", "label"])
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||
max_predictions_per_seq, vocab_words, rng,
|
||||
do_whole_word_mask):
|
||||
"""Creates the predictions for the masked LM objective."""
|
||||
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
# Whole Word Masking means that if we mask all of the wordpieces
|
||||
# corresponding to an original word. When a word has been split into
|
||||
# WordPieces, the first token does not have any marker and any subsequence
|
||||
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||
# append it to the previous set of word indexes.
|
||||
#
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
token.startswith("##")):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
|
||||
rng.shuffle(cand_indexes)
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index_set in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
assert len(masked_lms) <= num_to_predict
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||
|
||||
|
||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
|
||||
|
||||
# We want to sometimes truncate from the front and sometimes from the
|
||||
# back to add more randomness and avoid biases.
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
|
||||
|
||||
def main(_):
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.io.gfile.glob(input_pattern))
|
||||
|
||||
logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
logging.info(" %s", input_file)
|
||||
|
||||
rng = random.Random(FLAGS.random_seed)
|
||||
instances = create_training_instances(
|
||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
rng, FLAGS.do_whole_word_mask)
|
||||
|
||||
output_files = FLAGS.output_file.split(",")
|
||||
logging.info("*** Writing to output files ***")
|
||||
for output_file in output_files:
|
||||
logging.info(" %s", output_file)
|
||||
|
||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq, output_files,
|
||||
FLAGS.gzip_compress)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
app.run(main)
|
||||
+880
@@ -0,0 +1,880 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import six
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.bert import tokenization
|
||||
|
||||
|
||||
class SquadExample(object):
|
||||
"""A single training/test example for simple sequence classification.
|
||||
|
||||
For examples without an answer, the start and end position are -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=False):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.doc_tokens = doc_tokens
|
||||
self.orig_answer_text = orig_answer_text
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
s = ""
|
||||
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
|
||||
s += ", question_text: %s" % (
|
||||
tokenization.printable_text(self.question_text))
|
||||
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
||||
if self.start_position:
|
||||
s += ", start_position: %d" % (self.start_position)
|
||||
if self.start_position:
|
||||
s += ", end_position: %d" % (self.end_position)
|
||||
if self.start_position:
|
||||
s += ", is_impossible: %r" % (self.is_impossible)
|
||||
return s
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
self.tokens = tokens
|
||||
self.token_to_orig_map = token_to_orig_map
|
||||
self.token_is_max_context = token_is_max_context
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
|
||||
class FeatureWriter(object):
|
||||
"""Writes InputFeature to TF example file."""
|
||||
|
||||
def __init__(self, filename, is_training):
|
||||
self.filename = filename
|
||||
self.is_training = is_training
|
||||
self.num_features = 0
|
||||
self._writer = tf.io.TFRecordWriter(filename)
|
||||
|
||||
def process_feature(self, feature):
|
||||
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
|
||||
self.num_features += 1
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["unique_ids"] = create_int_feature([feature.unique_id])
|
||||
features["input_ids"] = create_int_feature(feature.input_ids)
|
||||
features["input_mask"] = create_int_feature(feature.input_mask)
|
||||
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||
|
||||
if self.is_training:
|
||||
features["start_positions"] = create_int_feature([feature.start_position])
|
||||
features["end_positions"] = create_int_feature([feature.end_position])
|
||||
impossible = 0
|
||||
if feature.is_impossible:
|
||||
impossible = 1
|
||||
features["is_impossible"] = create_int_feature([impossible])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
self._writer.write(tf_example.SerializeToString())
|
||||
|
||||
def close(self):
|
||||
self._writer.close()
|
||||
|
||||
|
||||
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
"""Read a SQuAD json file into a list of SquadExample."""
|
||||
with tf.io.gfile.GFile(input_file, "r") as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
|
||||
def is_whitespace(c):
|
||||
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
||||
return True
|
||||
return False
|
||||
|
||||
examples = []
|
||||
for entry in input_data:
|
||||
for paragraph in entry["paragraphs"]:
|
||||
paragraph_text = paragraph["context"]
|
||||
doc_tokens = []
|
||||
char_to_word_offset = []
|
||||
prev_is_whitespace = True
|
||||
for c in paragraph_text:
|
||||
if is_whitespace(c):
|
||||
prev_is_whitespace = True
|
||||
else:
|
||||
if prev_is_whitespace:
|
||||
doc_tokens.append(c)
|
||||
else:
|
||||
doc_tokens[-1] += c
|
||||
prev_is_whitespace = False
|
||||
char_to_word_offset.append(len(doc_tokens) - 1)
|
||||
|
||||
for qa in paragraph["qas"]:
|
||||
qas_id = qa["id"]
|
||||
question_text = qa["question"]
|
||||
start_position = None
|
||||
end_position = None
|
||||
orig_answer_text = None
|
||||
is_impossible = False
|
||||
if is_training:
|
||||
|
||||
if version_2_with_negative:
|
||||
is_impossible = qa["is_impossible"]
|
||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||
raise ValueError(
|
||||
"For training, each question should have exactly 1 answer.")
|
||||
if not is_impossible:
|
||||
answer = qa["answers"][0]
|
||||
orig_answer_text = answer["text"]
|
||||
answer_offset = answer["answer_start"]
|
||||
answer_length = len(orig_answer_text)
|
||||
start_position = char_to_word_offset[answer_offset]
|
||||
end_position = char_to_word_offset[answer_offset + answer_length -
|
||||
1]
|
||||
# Only add answers where the text can be exactly recovered from the
|
||||
# document. If this CAN'T happen it's likely due to weird Unicode
|
||||
# stuff so we will just skip the example.
|
||||
#
|
||||
# Note that this means for training mode, every example is NOT
|
||||
# guaranteed to be preserved.
|
||||
actual_text = " ".join(
|
||||
doc_tokens[start_position:(end_position + 1)])
|
||||
cleaned_answer_text = " ".join(
|
||||
tokenization.whitespace_tokenize(orig_answer_text))
|
||||
if actual_text.find(cleaned_answer_text) == -1:
|
||||
logging.warning("Could not find answer: '%s' vs. '%s'",
|
||||
actual_text, cleaned_answer_text)
|
||||
continue
|
||||
else:
|
||||
start_position = -1
|
||||
end_position = -1
|
||||
orig_answer_text = ""
|
||||
|
||||
example = SquadExample(
|
||||
qas_id=qas_id,
|
||||
question_text=question_text,
|
||||
doc_tokens=doc_tokens,
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=is_impossible)
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
doc_stride,
|
||||
max_query_length,
|
||||
is_training,
|
||||
output_fn,
|
||||
batch_size=None):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
base_id = 1000000000
|
||||
unique_id = base_id
|
||||
feature = None
|
||||
for (example_index, example) in enumerate(examples):
|
||||
query_tokens = tokenizer.tokenize(example.question_text)
|
||||
|
||||
if len(query_tokens) > max_query_length:
|
||||
query_tokens = query_tokens[0:max_query_length]
|
||||
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for (i, token) in enumerate(example.doc_tokens):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
|
||||
tok_start_position = None
|
||||
tok_end_position = None
|
||||
if is_training and example.is_impossible:
|
||||
tok_start_position = -1
|
||||
tok_end_position = -1
|
||||
if is_training and not example.is_impossible:
|
||||
tok_start_position = orig_to_tok_index[example.start_position]
|
||||
if example.end_position < len(example.doc_tokens) - 1:
|
||||
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
||||
else:
|
||||
tok_end_position = len(all_doc_tokens) - 1
|
||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
||||
example.orig_answer_text)
|
||||
|
||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||
|
||||
# We can have documents that are longer than the maximum sequence length.
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
# of the up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"DocSpan", ["start", "length"])
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
if length > max_tokens_for_doc:
|
||||
length = max_tokens_for_doc
|
||||
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, doc_stride)
|
||||
|
||||
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
start_position = None
|
||||
end_position = None
|
||||
if is_training and not example.is_impossible:
|
||||
# For training, if our document chunk does not contain an annotation
|
||||
# we throw it out, since there is nothing to predict.
|
||||
doc_start = doc_span.start
|
||||
doc_end = doc_span.start + doc_span.length - 1
|
||||
out_of_span = False
|
||||
if not (tok_start_position >= doc_start and
|
||||
tok_end_position <= doc_end):
|
||||
out_of_span = True
|
||||
if out_of_span:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
if is_training and example.is_impossible:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
|
||||
if example_index < 20:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("unique_id: %s", (unique_id))
|
||||
logging.info("example_index: %s", (example_index))
|
||||
logging.info("doc_span_index: %s", (doc_span_index))
|
||||
logging.info("tokens: %s",
|
||||
" ".join([tokenization.printable_text(x) for x in tokens]))
|
||||
logging.info(
|
||||
"token_to_orig_map: %s", " ".join([
|
||||
"%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
|
||||
]))
|
||||
logging.info(
|
||||
"token_is_max_context: %s", " ".join([
|
||||
"%d:%s" % (x, y)
|
||||
for (x, y) in six.iteritems(token_is_max_context)
|
||||
]))
|
||||
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
||||
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
||||
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
||||
if is_training and example.is_impossible:
|
||||
logging.info("impossible example")
|
||||
if is_training and not example.is_impossible:
|
||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||
logging.info("start_position: %d", (start_position))
|
||||
logging.info("end_position: %d", (end_position))
|
||||
logging.info("answer: %s", tokenization.printable_text(answer_text))
|
||||
|
||||
feature = InputFeatures(
|
||||
unique_id=unique_id,
|
||||
example_index=example_index,
|
||||
doc_span_index=doc_span_index,
|
||||
tokens=tokens,
|
||||
token_to_orig_map=token_to_orig_map,
|
||||
token_is_max_context=token_is_max_context,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=example.is_impossible)
|
||||
|
||||
# Run callback
|
||||
if is_training:
|
||||
output_fn(feature)
|
||||
else:
|
||||
output_fn(feature, is_padding=False)
|
||||
|
||||
unique_id += 1
|
||||
|
||||
if not is_training and feature:
|
||||
assert batch_size
|
||||
num_padding = 0
|
||||
num_examples = unique_id - base_id
|
||||
if unique_id % batch_size != 0:
|
||||
num_padding = batch_size - (num_examples % batch_size)
|
||||
logging.info("Adding padding examples to make sure no partial batch.")
|
||||
logging.info("Adds %d padding examples for inference.", num_padding)
|
||||
dummy_feature = copy.deepcopy(feature)
|
||||
for _ in range(num_padding):
|
||||
dummy_feature.unique_id = unique_id
|
||||
|
||||
# Run callback
|
||||
output_fn(feature, is_padding=True)
|
||||
unique_id += 1
|
||||
return unique_id - base_id
|
||||
|
||||
|
||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
||||
orig_answer_text):
|
||||
"""Returns tokenized answer spans that better match the annotated answer."""
|
||||
|
||||
# The SQuAD annotations are character based. We first project them to
|
||||
# whitespace-tokenized words. But then after WordPiece tokenization, we can
|
||||
# often find a "better match". For example:
|
||||
#
|
||||
# Question: What year was John Smith born?
|
||||
# Context: The leader was John Smith (1895-1943).
|
||||
# Answer: 1895
|
||||
#
|
||||
# The original whitespace-tokenized answer will be "(1895-1943).". However
|
||||
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
|
||||
# the exact answer, 1895.
|
||||
#
|
||||
# However, this is not always possible. Consider the following:
|
||||
#
|
||||
# Question: What country is the top exporter of electornics?
|
||||
# Context: The Japanese electronics industry is the lagest in the world.
|
||||
# Answer: Japan
|
||||
#
|
||||
# In this case, the annotator chose "Japan" as a character sub-span of
|
||||
# the word "Japanese". Since our WordPiece tokenizer does not split
|
||||
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
|
||||
# in SQuAD, but does happen.
|
||||
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
||||
|
||||
for new_start in range(input_start, input_end + 1):
|
||||
for new_end in range(input_end, new_start - 1, -1):
|
||||
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
|
||||
if text_span == tok_answer_text:
|
||||
return (new_start, new_end)
|
||||
|
||||
return (input_start, input_end)
|
||||
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
"""Check if this is the 'max context' doc span for the token."""
|
||||
|
||||
# Because of the sliding window approach taken to scoring documents, a single
|
||||
# token can appear in multiple documents. E.g.
|
||||
# Doc: the man went to the store and bought a gallon of milk
|
||||
# Span A: the man went to the
|
||||
# Span B: to the store and bought
|
||||
# Span C: and bought a gallon of
|
||||
# ...
|
||||
#
|
||||
# Now the word 'bought' will have two scores from spans B and C. We only
|
||||
# want to consider the score with "maximum context", which we define as
|
||||
# the *minimum* of its left and right context (the *sum* of left and
|
||||
# right context will always be the same, of course).
|
||||
#
|
||||
# In the example the maximum context for 'bought' would be span C since
|
||||
# it has 1 left context and 3 right context, while span B has 4 left context
|
||||
# and 0 right context.
|
||||
best_score = None
|
||||
best_span_index = None
|
||||
for (span_index, doc_span) in enumerate(doc_spans):
|
||||
end = doc_span.start + doc_span.length - 1
|
||||
if position < doc_span.start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_span.start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = span_index
|
||||
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
def write_predictions(all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
version_2_with_negative=False,
|
||||
null_score_diff_threshold=0.0,
|
||||
verbose=False):
|
||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||
logging.info("Writing predictions to: %s", (output_prediction_file))
|
||||
logging.info("Writing nbest to: %s", (output_nbest_file))
|
||||
|
||||
all_predictions, all_nbest_json, scores_diff_json = (
|
||||
postprocess_output(all_examples=all_examples,
|
||||
all_features=all_features,
|
||||
all_results=all_results,
|
||||
n_best_size=n_best_size,
|
||||
max_answer_length=max_answer_length,
|
||||
do_lower_case=do_lower_case,
|
||||
version_2_with_negative=version_2_with_negative,
|
||||
null_score_diff_threshold=null_score_diff_threshold,
|
||||
verbose=verbose))
|
||||
|
||||
write_to_json_files(all_predictions, output_prediction_file)
|
||||
write_to_json_files(all_nbest_json, output_nbest_file)
|
||||
if version_2_with_negative:
|
||||
write_to_json_files(scores_diff_json, output_null_log_odds_file)
|
||||
|
||||
|
||||
def postprocess_output(all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
do_lower_case,
|
||||
version_2_with_negative=False,
|
||||
null_score_diff_threshold=0.0,
|
||||
verbose=False):
|
||||
"""Postprocess model output, to form predicton results."""
|
||||
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
score_null = 1000000 # large and positive
|
||||
min_null_feature_index = 0 # the paragraph slice with min mull score
|
||||
null_start_logit = 0 # the start logit at the slice with min null score
|
||||
null_end_logit = 0 # the end logit at the slice with min null score
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
if version_2_with_negative:
|
||||
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
||||
if feature_null_score < score_null:
|
||||
score_null = feature_null_score
|
||||
min_null_feature_index = feature_index
|
||||
null_start_logit = result.start_logits[0]
|
||||
null_end_logit = result.end_logits[0]
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index >= len(feature.tokens):
|
||||
continue
|
||||
if end_index >= len(feature.tokens):
|
||||
continue
|
||||
if start_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if end_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
|
||||
if version_2_with_negative:
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=min_null_feature_index,
|
||||
start_index=0,
|
||||
end_index=0,
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens)
|
||||
|
||||
# De-tokenize WordPieces that have been split off.
|
||||
tok_text = tok_text.replace(" ##", "")
|
||||
tok_text = tok_text.replace("##", "")
|
||||
|
||||
# Clean whitespace
|
||||
tok_text = tok_text.strip()
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(
|
||||
tok_text, orig_text, do_lower_case, verbose=verbose)
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
else:
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
|
||||
# if we didn't inlude the empty option in the n-best, inlcude it
|
||||
if version_2_with_negative:
|
||||
if "" not in seen_predictions:
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="", start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_logit + entry.end_logit)
|
||||
if not best_non_null_entry:
|
||||
if entry.text:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_logit"] = entry.start_logit
|
||||
output["end_logit"] = entry.end_logit
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
|
||||
if not version_2_with_negative:
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
else:
|
||||
# pytype: disable=attribute-error
|
||||
# predict "" iff the null score - the score of best non-null > threshold
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||
best_non_null_entry.end_logit)
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example.qas_id] = ""
|
||||
else:
|
||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||
# pytype: enable=attribute-error
|
||||
|
||||
all_nbest_json[example.qas_id] = nbest_json
|
||||
|
||||
return all_predictions, all_nbest_json, scores_diff_json
|
||||
|
||||
|
||||
def write_to_json_files(json_records, json_file):
|
||||
with tf.io.gfile.GFile(json_file, "w") as writer:
|
||||
writer.write(json.dumps(json_records, indent=4) + "\n")
|
||||
|
||||
|
||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
|
||||
"""Project the tokenized prediction back to the original text."""
|
||||
|
||||
# When we created the data, we kept track of the alignment between original
|
||||
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
||||
# now `orig_text` contains the span of our original text corresponding to the
|
||||
# span that we predicted.
|
||||
#
|
||||
# However, `orig_text` may contain extra characters that we don't want in
|
||||
# our prediction.
|
||||
#
|
||||
# For example, let's say:
|
||||
# pred_text = steve smith
|
||||
# orig_text = Steve Smith's
|
||||
#
|
||||
# We don't want to return `orig_text` because it contains the extra "'s".
|
||||
#
|
||||
# We don't want to return `pred_text` because it's already been normalized
|
||||
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
||||
# our tokenizer does additional normalization like stripping accent
|
||||
# characters).
|
||||
#
|
||||
# What we really want to return is "Steve Smith".
|
||||
#
|
||||
# Therefore, we have to apply a semi-complicated alignment heruistic between
|
||||
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
|
||||
# can fail in certain cases in which case we just return `orig_text`.
|
||||
|
||||
def _strip_spaces(text):
|
||||
ns_chars = []
|
||||
ns_to_s_map = collections.OrderedDict()
|
||||
for (i, c) in enumerate(text):
|
||||
if c == " ":
|
||||
continue
|
||||
ns_to_s_map[len(ns_chars)] = i
|
||||
ns_chars.append(c)
|
||||
ns_text = "".join(ns_chars)
|
||||
return (ns_text, ns_to_s_map)
|
||||
|
||||
# We first tokenize `orig_text`, strip whitespace from the result
|
||||
# and `pred_text`, and check if they are the same length. If they are
|
||||
# NOT the same length, the heuristic has failed. If they are the same
|
||||
# length, we assume the characters are one-to-one aligned.
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
|
||||
|
||||
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
||||
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
if verbose:
|
||||
logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
||||
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
||||
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
if verbose:
|
||||
logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||
orig_ns_text, tok_ns_text)
|
||||
return orig_text
|
||||
|
||||
# We then project the characters in `pred_text` back to `orig_text` using
|
||||
# the character-to-character alignment.
|
||||
tok_s_to_ns_map = {}
|
||||
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
|
||||
tok_s_to_ns_map[tok_index] = i
|
||||
|
||||
orig_start_position = None
|
||||
if start_position in tok_s_to_ns_map:
|
||||
ns_start_position = tok_s_to_ns_map[start_position]
|
||||
if ns_start_position in orig_ns_to_s_map:
|
||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||
|
||||
if orig_start_position is None:
|
||||
if verbose:
|
||||
logging.info("Couldn't map start position")
|
||||
return orig_text
|
||||
|
||||
orig_end_position = None
|
||||
if end_position in tok_s_to_ns_map:
|
||||
ns_end_position = tok_s_to_ns_map[end_position]
|
||||
if ns_end_position in orig_ns_to_s_map:
|
||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||
|
||||
if orig_end_position is None:
|
||||
if verbose:
|
||||
logging.info("Couldn't map end position")
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
def _get_best_indexes(logits, n_best_size):
|
||||
"""Get the n-best logits from a list."""
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
|
||||
best_indexes = []
|
||||
for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate
|
||||
if i >= n_best_size:
|
||||
break
|
||||
best_indexes.append(index_and_score[i][0])
|
||||
return best_indexes
|
||||
|
||||
|
||||
def _compute_softmax(scores):
|
||||
"""Compute softmax probability over raw logits."""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
max_score = None
|
||||
for score in scores:
|
||||
if max_score is None or score > max_score:
|
||||
max_score = score
|
||||
|
||||
exp_scores = []
|
||||
total_sum = 0.0
|
||||
for score in scores:
|
||||
x = math.exp(score - max_score)
|
||||
exp_scores.append(x)
|
||||
total_sum += x
|
||||
|
||||
probs = []
|
||||
for score in exp_scores:
|
||||
probs.append(score / total_sum)
|
||||
return probs
|
||||
|
||||
|
||||
def generate_tf_record_from_json_file(input_file_path,
|
||||
vocab_file_path,
|
||||
output_path,
|
||||
max_seq_length=384,
|
||||
do_lower_case=True,
|
||||
max_query_length=64,
|
||||
doc_stride=128,
|
||||
version_2_with_negative=False):
|
||||
"""Generates and saves training data into a tf record file."""
|
||||
train_examples = read_squad_examples(
|
||||
input_file=input_file_path,
|
||||
is_training=True,
|
||||
version_2_with_negative=version_2_with_negative)
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=vocab_file_path, do_lower_case=do_lower_case)
|
||||
train_writer = FeatureWriter(filename=output_path, is_training=True)
|
||||
number_of_examples = convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=max_seq_length,
|
||||
doc_stride=doc_stride,
|
||||
max_query_length=max_query_length,
|
||||
is_training=True,
|
||||
output_fn=train_writer.process_feature)
|
||||
train_writer.close()
|
||||
|
||||
meta_data = {
|
||||
"task_type": "bert_squad",
|
||||
"train_data_size": number_of_examples,
|
||||
"max_seq_length": max_seq_length,
|
||||
"max_query_length": max_query_length,
|
||||
"doc_stride": doc_stride,
|
||||
"version_2_with_negative": version_2_with_negative,
|
||||
}
|
||||
|
||||
return meta_data
|
||||
+890
@@ -0,0 +1,890 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
|
||||
|
||||
The file is forked from:
|
||||
|
||||
https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.bert import tokenization
|
||||
|
||||
|
||||
class SquadExample(object):
|
||||
"""A single training/test example for simple sequence classification.
|
||||
|
||||
For examples without an answer, the start and end position are -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
paragraph_text,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=False):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.paragraph_text = paragraph_text
|
||||
self.orig_answer_text = orig_answer_text
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
s = ""
|
||||
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
|
||||
s += ", question_text: %s" % (
|
||||
tokenization.printable_text(self.question_text))
|
||||
s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
|
||||
if self.start_position:
|
||||
s += ", start_position: %d" % (self.start_position)
|
||||
if self.start_position:
|
||||
s += ", end_position: %d" % (self.end_position)
|
||||
if self.start_position:
|
||||
s += ", is_impossible: %r" % (self.is_impossible)
|
||||
return s
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tok_start_to_orig_index,
|
||||
tok_end_to_orig_index,
|
||||
token_is_max_context,
|
||||
tokens,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
self.tok_start_to_orig_index = tok_start_to_orig_index
|
||||
self.tok_end_to_orig_index = tok_end_to_orig_index
|
||||
self.token_is_max_context = token_is_max_context
|
||||
self.tokens = tokens
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.paragraph_len = paragraph_len
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
|
||||
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
"""Read a SQuAD json file into a list of SquadExample."""
|
||||
del version_2_with_negative
|
||||
with tf.io.gfile.GFile(input_file, "r") as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
|
||||
examples = []
|
||||
for entry in input_data:
|
||||
for paragraph in entry["paragraphs"]:
|
||||
paragraph_text = paragraph["context"]
|
||||
|
||||
for qa in paragraph["qas"]:
|
||||
qas_id = qa["id"]
|
||||
question_text = qa["question"]
|
||||
start_position = None
|
||||
orig_answer_text = None
|
||||
is_impossible = False
|
||||
|
||||
if is_training:
|
||||
is_impossible = qa.get("is_impossible", False)
|
||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||
raise ValueError(
|
||||
"For training, each question should have exactly 1 answer.")
|
||||
if not is_impossible:
|
||||
answer = qa["answers"][0]
|
||||
orig_answer_text = answer["text"]
|
||||
start_position = answer["answer_start"]
|
||||
else:
|
||||
start_position = -1
|
||||
orig_answer_text = ""
|
||||
|
||||
example = SquadExample(
|
||||
qas_id=qas_id,
|
||||
question_text=question_text,
|
||||
paragraph_text=paragraph_text,
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
is_impossible=is_impossible)
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def _convert_index(index, pos, m=None, is_start=True):
|
||||
"""Converts index."""
|
||||
if index[pos] is not None:
|
||||
return index[pos]
|
||||
n = len(index)
|
||||
rear = pos
|
||||
while rear < n - 1 and index[rear] is None:
|
||||
rear += 1
|
||||
front = pos
|
||||
while front > 0 and index[front] is None:
|
||||
front -= 1
|
||||
assert index[front] is not None or index[rear] is not None
|
||||
if index[front] is None:
|
||||
if index[rear] >= 1:
|
||||
if is_start:
|
||||
return 0
|
||||
else:
|
||||
return index[rear] - 1
|
||||
return index[rear]
|
||||
if index[rear] is None:
|
||||
if m is not None and index[front] < m - 1:
|
||||
if is_start:
|
||||
return index[front] + 1
|
||||
else:
|
||||
return m - 1
|
||||
return index[front]
|
||||
if is_start:
|
||||
if index[rear] > index[front] + 1:
|
||||
return index[front] + 1
|
||||
else:
|
||||
return index[rear]
|
||||
else:
|
||||
if index[rear] > index[front] + 1:
|
||||
return index[rear] - 1
|
||||
else:
|
||||
return index[front]
|
||||
|
||||
|
||||
def convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
doc_stride,
|
||||
max_query_length,
|
||||
is_training,
|
||||
output_fn,
|
||||
do_lower_case,
|
||||
batch_size=None):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
cnt_pos, cnt_neg = 0, 0
|
||||
base_id = 1000000000
|
||||
unique_id = base_id
|
||||
max_n, max_m = 1024, 1024
|
||||
f = np.zeros((max_n, max_m), dtype=np.float32)
|
||||
|
||||
for (example_index, example) in enumerate(examples):
|
||||
|
||||
if example_index % 100 == 0:
|
||||
logging.info("Converting %d/%d pos %d neg %d", example_index,
|
||||
len(examples), cnt_pos, cnt_neg)
|
||||
|
||||
query_tokens = tokenization.encode_ids(
|
||||
tokenizer.sp_model,
|
||||
tokenization.preprocess_text(
|
||||
example.question_text, lower=do_lower_case))
|
||||
|
||||
if len(query_tokens) > max_query_length:
|
||||
query_tokens = query_tokens[0:max_query_length]
|
||||
|
||||
paragraph_text = example.paragraph_text
|
||||
para_tokens = tokenization.encode_pieces(
|
||||
tokenizer.sp_model,
|
||||
tokenization.preprocess_text(
|
||||
example.paragraph_text, lower=do_lower_case))
|
||||
|
||||
chartok_to_tok_index = []
|
||||
tok_start_to_chartok_index = []
|
||||
tok_end_to_chartok_index = []
|
||||
char_cnt = 0
|
||||
for i, token in enumerate(para_tokens):
|
||||
new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
|
||||
chartok_to_tok_index.extend([i] * len(new_token))
|
||||
tok_start_to_chartok_index.append(char_cnt)
|
||||
char_cnt += len(new_token)
|
||||
tok_end_to_chartok_index.append(char_cnt - 1)
|
||||
|
||||
tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
|
||||
" ")
|
||||
n, m = len(paragraph_text), len(tok_cat_text)
|
||||
|
||||
if n > max_n or m > max_m:
|
||||
max_n = max(n, max_n)
|
||||
max_m = max(m, max_m)
|
||||
f = np.zeros((max_n, max_m), dtype=np.float32)
|
||||
|
||||
g = {}
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def _lcs_match(max_dist, n=n, m=m):
|
||||
"""Longest-common-substring algorithm."""
|
||||
f.fill(0)
|
||||
g.clear()
|
||||
|
||||
### longest common sub sequence
|
||||
# f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
|
||||
for i in range(n):
|
||||
|
||||
# unlike standard LCS, this is specifically optimized for the setting
|
||||
# because the mismatch between sentence pieces and original text will
|
||||
# be small
|
||||
for j in range(i - max_dist, i + max_dist):
|
||||
if j >= m or j < 0:
|
||||
continue
|
||||
|
||||
if i > 0:
|
||||
g[(i, j)] = 0
|
||||
f[i, j] = f[i - 1, j]
|
||||
|
||||
if j > 0 and f[i, j - 1] > f[i, j]:
|
||||
g[(i, j)] = 1
|
||||
f[i, j] = f[i, j - 1]
|
||||
|
||||
f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
|
||||
if (tokenization.preprocess_text(
|
||||
paragraph_text[i], lower=do_lower_case,
|
||||
remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
|
||||
g[(i, j)] = 2
|
||||
f[i, j] = f_prev + 1
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
max_dist = abs(n - m) + 5
|
||||
for _ in range(2):
|
||||
_lcs_match(max_dist)
|
||||
if f[n - 1, m - 1] > 0.8 * n:
|
||||
break
|
||||
max_dist *= 2
|
||||
|
||||
orig_to_chartok_index = [None] * n
|
||||
chartok_to_orig_index = [None] * m
|
||||
i, j = n - 1, m - 1
|
||||
while i >= 0 and j >= 0:
|
||||
if (i, j) not in g:
|
||||
break
|
||||
if g[(i, j)] == 2:
|
||||
orig_to_chartok_index[i] = j
|
||||
chartok_to_orig_index[j] = i
|
||||
i, j = i - 1, j - 1
|
||||
elif g[(i, j)] == 1:
|
||||
j = j - 1
|
||||
else:
|
||||
i = i - 1
|
||||
|
||||
if (all(v is None for v in orig_to_chartok_index) or
|
||||
f[n - 1, m - 1] < 0.8 * n):
|
||||
logging.info("MISMATCH DETECTED!")
|
||||
continue
|
||||
|
||||
tok_start_to_orig_index = []
|
||||
tok_end_to_orig_index = []
|
||||
for i in range(len(para_tokens)):
|
||||
start_chartok_pos = tok_start_to_chartok_index[i]
|
||||
end_chartok_pos = tok_end_to_chartok_index[i]
|
||||
start_orig_pos = _convert_index(
|
||||
chartok_to_orig_index, start_chartok_pos, n, is_start=True)
|
||||
end_orig_pos = _convert_index(
|
||||
chartok_to_orig_index, end_chartok_pos, n, is_start=False)
|
||||
|
||||
tok_start_to_orig_index.append(start_orig_pos)
|
||||
tok_end_to_orig_index.append(end_orig_pos)
|
||||
|
||||
if not is_training:
|
||||
tok_start_position = tok_end_position = None
|
||||
|
||||
if is_training and example.is_impossible:
|
||||
tok_start_position = 0
|
||||
tok_end_position = 0
|
||||
|
||||
if is_training and not example.is_impossible:
|
||||
start_position = example.start_position
|
||||
end_position = start_position + len(example.orig_answer_text) - 1
|
||||
|
||||
start_chartok_pos = _convert_index(
|
||||
orig_to_chartok_index, start_position, is_start=True)
|
||||
tok_start_position = chartok_to_tok_index[start_chartok_pos]
|
||||
|
||||
end_chartok_pos = _convert_index(
|
||||
orig_to_chartok_index, end_position, is_start=False)
|
||||
tok_end_position = chartok_to_tok_index[end_chartok_pos]
|
||||
assert tok_start_position <= tok_end_position
|
||||
|
||||
def _piece_to_id(x):
|
||||
return tokenizer.sp_model.PieceToId(x)
|
||||
|
||||
all_doc_tokens = list(map(_piece_to_id, para_tokens))
|
||||
|
||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||
|
||||
# We can have documents that are longer than the maximum sequence length.
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
# of the up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"DocSpan", ["start", "length"])
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
if length > max_tokens_for_doc:
|
||||
length = max_tokens_for_doc
|
||||
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, doc_stride)
|
||||
|
||||
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
|
||||
cur_tok_start_to_orig_index = []
|
||||
cur_tok_end_to_orig_index = []
|
||||
|
||||
tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
|
||||
segment_ids.append(0)
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
|
||||
segment_ids.append(0)
|
||||
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
|
||||
cur_tok_start_to_orig_index.append(
|
||||
tok_start_to_orig_index[split_token_index])
|
||||
cur_tok_end_to_orig_index.append(
|
||||
tok_end_to_orig_index[split_token_index])
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
|
||||
segment_ids.append(1)
|
||||
|
||||
paragraph_len = len(tokens)
|
||||
input_ids = tokens
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
span_is_impossible = example.is_impossible
|
||||
start_position = None
|
||||
end_position = None
|
||||
if is_training and not span_is_impossible:
|
||||
# For training, if our document chunk does not contain an annotation
|
||||
# we throw it out, since there is nothing to predict.
|
||||
doc_start = doc_span.start
|
||||
doc_end = doc_span.start + doc_span.length - 1
|
||||
out_of_span = False
|
||||
if not (tok_start_position >= doc_start and
|
||||
tok_end_position <= doc_end):
|
||||
out_of_span = True
|
||||
if out_of_span:
|
||||
# continue
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
span_is_impossible = True
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
if is_training and span_is_impossible:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
|
||||
if example_index < 20:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("unique_id: %s", (unique_id))
|
||||
logging.info("example_index: %s", (example_index))
|
||||
logging.info("doc_span_index: %s", (doc_span_index))
|
||||
logging.info("tok_start_to_orig_index: %s",
|
||||
" ".join([str(x) for x in cur_tok_start_to_orig_index]))
|
||||
logging.info("tok_end_to_orig_index: %s",
|
||||
" ".join([str(x) for x in cur_tok_end_to_orig_index]))
|
||||
logging.info(
|
||||
"token_is_max_context: %s", " ".join(
|
||||
["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
|
||||
logging.info(
|
||||
"input_pieces: %s",
|
||||
" ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
|
||||
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
||||
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
||||
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
||||
|
||||
if is_training and span_is_impossible:
|
||||
logging.info("impossible example span")
|
||||
|
||||
if is_training and not span_is_impossible:
|
||||
pieces = [
|
||||
tokenizer.sp_model.IdToPiece(token)
|
||||
for token in tokens[start_position:(end_position + 1)]
|
||||
]
|
||||
answer_text = tokenizer.sp_model.DecodePieces(pieces)
|
||||
logging.info("start_position: %d", (start_position))
|
||||
logging.info("end_position: %d", (end_position))
|
||||
logging.info("answer: %s", (tokenization.printable_text(answer_text)))
|
||||
|
||||
# With multi processing, the example_index is actually the index
|
||||
# within the current process therefore we use example_index=None
|
||||
# to avoid being used in the future.
|
||||
# The current code does not use example_index of training data.
|
||||
if is_training:
|
||||
feat_example_index = None
|
||||
else:
|
||||
feat_example_index = example_index
|
||||
|
||||
feature = InputFeatures(
|
||||
unique_id=unique_id,
|
||||
example_index=feat_example_index,
|
||||
doc_span_index=doc_span_index,
|
||||
tok_start_to_orig_index=cur_tok_start_to_orig_index,
|
||||
tok_end_to_orig_index=cur_tok_end_to_orig_index,
|
||||
token_is_max_context=token_is_max_context,
|
||||
tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
paragraph_len=paragraph_len,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible)
|
||||
|
||||
# Run callback
|
||||
if is_training:
|
||||
output_fn(feature)
|
||||
else:
|
||||
output_fn(feature, is_padding=False)
|
||||
|
||||
unique_id += 1
|
||||
if span_is_impossible:
|
||||
cnt_neg += 1
|
||||
else:
|
||||
cnt_pos += 1
|
||||
|
||||
if not is_training and feature:
|
||||
assert batch_size
|
||||
num_padding = 0
|
||||
num_examples = unique_id - base_id
|
||||
if unique_id % batch_size != 0:
|
||||
num_padding = batch_size - (num_examples % batch_size)
|
||||
dummy_feature = copy.deepcopy(feature)
|
||||
for _ in range(num_padding):
|
||||
dummy_feature.unique_id = unique_id
|
||||
|
||||
# Run callback
|
||||
output_fn(feature, is_padding=True)
|
||||
unique_id += 1
|
||||
|
||||
logging.info("Total number of instances: %d = pos %d neg %d",
|
||||
cnt_pos + cnt_neg, cnt_pos, cnt_neg)
|
||||
return unique_id - base_id
|
||||
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
"""Check if this is the 'max context' doc span for the token."""
|
||||
|
||||
# Because of the sliding window approach taken to scoring documents, a single
|
||||
# token can appear in multiple documents. E.g.
|
||||
# Doc: the man went to the store and bought a gallon of milk
|
||||
# Span A: the man went to the
|
||||
# Span B: to the store and bought
|
||||
# Span C: and bought a gallon of
|
||||
# ...
|
||||
#
|
||||
# Now the word 'bought' will have two scores from spans B and C. We only
|
||||
# want to consider the score with "maximum context", which we define as
|
||||
# the *minimum* of its left and right context (the *sum* of left and
|
||||
# right context will always be the same, of course).
|
||||
#
|
||||
# In the example the maximum context for 'bought' would be span C since
|
||||
# it has 1 left context and 3 right context, while span B has 4 left context
|
||||
# and 0 right context.
|
||||
best_score = None
|
||||
best_span_index = None
|
||||
for (span_index, doc_span) in enumerate(doc_spans):
|
||||
end = doc_span.start + doc_span.length - 1
|
||||
if position < doc_span.start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_span.start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = span_index
|
||||
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
def write_predictions(all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
do_lower_case,
|
||||
output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file,
|
||||
version_2_with_negative=False,
|
||||
null_score_diff_threshold=0.0,
|
||||
verbose=False):
|
||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||
logging.info("Writing predictions to: %s", (output_prediction_file))
|
||||
logging.info("Writing nbest to: %s", (output_nbest_file))
|
||||
|
||||
all_predictions, all_nbest_json, scores_diff_json = (
|
||||
postprocess_output(all_examples=all_examples,
|
||||
all_features=all_features,
|
||||
all_results=all_results,
|
||||
n_best_size=n_best_size,
|
||||
max_answer_length=max_answer_length,
|
||||
do_lower_case=do_lower_case,
|
||||
version_2_with_negative=version_2_with_negative,
|
||||
null_score_diff_threshold=null_score_diff_threshold,
|
||||
verbose=verbose))
|
||||
|
||||
write_to_json_files(all_predictions, output_prediction_file)
|
||||
write_to_json_files(all_nbest_json, output_nbest_file)
|
||||
if version_2_with_negative:
|
||||
write_to_json_files(scores_diff_json, output_null_log_odds_file)
|
||||
|
||||
|
||||
def postprocess_output(all_examples,
|
||||
all_features,
|
||||
all_results,
|
||||
n_best_size,
|
||||
max_answer_length,
|
||||
do_lower_case,
|
||||
version_2_with_negative=False,
|
||||
null_score_diff_threshold=0.0,
|
||||
verbose=False):
|
||||
"""Postprocess model output, to form predicton results."""
|
||||
|
||||
del do_lower_case, verbose
|
||||
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
score_null = 1000000 # large and positive
|
||||
min_null_feature_index = 0 # the paragraph slice with min mull score
|
||||
null_start_logit = 0 # the start logit at the slice with min null score
|
||||
null_end_logit = 0 # the end logit at the slice with min null score
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
if version_2_with_negative:
|
||||
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
||||
if feature_null_score < score_null:
|
||||
score_null = feature_null_score
|
||||
min_null_feature_index = feature_index
|
||||
null_start_logit = result.start_logits[0]
|
||||
null_end_logit = result.end_logits[0]
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
doc_offset = feature.tokens.index("[SEP]") + 1
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
|
||||
continue
|
||||
if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
|
||||
continue
|
||||
# if start_index not in feature.tok_start_to_orig_index:
|
||||
# continue
|
||||
# if end_index not in feature.tok_end_to_orig_index:
|
||||
# continue
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index - doc_offset,
|
||||
end_index=end_index - doc_offset,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
|
||||
if version_2_with_negative:
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=min_null_feature_index,
|
||||
start_index=-1,
|
||||
end_index=-1,
|
||||
start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index >= 0: # this is a non-null prediction
|
||||
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||
|
||||
paragraph_text = example.paragraph_text
|
||||
final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
else:
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
|
||||
# if we didn't inlude the empty option in the n-best, inlcude it
|
||||
if version_2_with_negative:
|
||||
if "" not in seen_predictions:
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="", start_logit=null_start_logit,
|
||||
end_logit=null_end_logit))
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_logit + entry.end_logit)
|
||||
if not best_non_null_entry:
|
||||
if entry.text:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_logit"] = entry.start_logit
|
||||
output["end_logit"] = entry.end_logit
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
|
||||
if not version_2_with_negative:
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
else:
|
||||
assert best_non_null_entry is not None
|
||||
# predict "" iff the null score - the score of best non-null > threshold
|
||||
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||
best_non_null_entry.end_logit)
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example.qas_id] = ""
|
||||
else:
|
||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||
|
||||
all_nbest_json[example.qas_id] = nbest_json
|
||||
|
||||
return all_predictions, all_nbest_json, scores_diff_json
|
||||
|
||||
|
||||
def write_to_json_files(json_records, json_file):
|
||||
with tf.io.gfile.GFile(json_file, "w") as writer:
|
||||
writer.write(json.dumps(json_records, indent=4) + "\n")
|
||||
|
||||
|
||||
def _get_best_indexes(logits, n_best_size):
|
||||
"""Get the n-best logits from a list."""
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
|
||||
best_indexes = []
|
||||
for i in range(len(index_and_score)):
|
||||
if i >= n_best_size:
|
||||
break
|
||||
best_indexes.append(index_and_score[i][0])
|
||||
return best_indexes
|
||||
|
||||
|
||||
def _compute_softmax(scores):
|
||||
"""Compute softmax probability over raw logits."""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
max_score = None
|
||||
for score in scores:
|
||||
if max_score is None or score > max_score:
|
||||
max_score = score
|
||||
|
||||
exp_scores = []
|
||||
total_sum = 0.0
|
||||
for score in scores:
|
||||
x = math.exp(score - max_score)
|
||||
exp_scores.append(x)
|
||||
total_sum += x
|
||||
|
||||
probs = []
|
||||
for score in exp_scores:
|
||||
probs.append(score / total_sum)
|
||||
return probs
|
||||
|
||||
|
||||
class FeatureWriter(object):
|
||||
"""Writes InputFeature to TF example file."""
|
||||
|
||||
def __init__(self, filename, is_training):
|
||||
self.filename = filename
|
||||
self.is_training = is_training
|
||||
self.num_features = 0
|
||||
self._writer = tf.io.TFRecordWriter(filename)
|
||||
|
||||
def process_feature(self, feature):
|
||||
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
|
||||
self.num_features += 1
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["unique_ids"] = create_int_feature([feature.unique_id])
|
||||
features["input_ids"] = create_int_feature(feature.input_ids)
|
||||
features["input_mask"] = create_int_feature(feature.input_mask)
|
||||
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||
|
||||
if self.is_training:
|
||||
features["start_positions"] = create_int_feature([feature.start_position])
|
||||
features["end_positions"] = create_int_feature([feature.end_position])
|
||||
impossible = 0
|
||||
if feature.is_impossible:
|
||||
impossible = 1
|
||||
features["is_impossible"] = create_int_feature([impossible])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
self._writer.write(tf_example.SerializeToString())
|
||||
|
||||
def close(self):
|
||||
self._writer.close()
|
||||
|
||||
|
||||
def generate_tf_record_from_json_file(input_file_path,
|
||||
sp_model_file,
|
||||
output_path,
|
||||
max_seq_length=384,
|
||||
do_lower_case=True,
|
||||
max_query_length=64,
|
||||
doc_stride=128,
|
||||
version_2_with_negative=False):
|
||||
"""Generates and saves training data into a tf record file."""
|
||||
train_examples = read_squad_examples(
|
||||
input_file=input_file_path,
|
||||
is_training=True,
|
||||
version_2_with_negative=version_2_with_negative)
|
||||
tokenizer = tokenization.FullSentencePieceTokenizer(
|
||||
sp_model_file=sp_model_file)
|
||||
train_writer = FeatureWriter(filename=output_path, is_training=True)
|
||||
number_of_examples = convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=max_seq_length,
|
||||
doc_stride=doc_stride,
|
||||
max_query_length=max_query_length,
|
||||
is_training=True,
|
||||
output_fn=train_writer.process_feature,
|
||||
do_lower_case=do_lower_case)
|
||||
train_writer.close()
|
||||
|
||||
meta_data = {
|
||||
"task_type": "bert_squad",
|
||||
"train_data_size": number_of_examples,
|
||||
"max_seq_length": max_seq_length,
|
||||
"max_query_length": max_query_length,
|
||||
"doc_stride": doc_stride,
|
||||
"version_2_with_negative": version_2_with_negative,
|
||||
}
|
||||
|
||||
return meta_data
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
# NLP Modeling Library
|
||||
|
||||
This libary provides a set of Keras primitives (Layers, Networks, and Models)
|
||||
that can be assembled into transformer-based models. They are
|
||||
flexible, validated, interoperable, and both TF1 and TF2 compatible.
|
||||
|
||||
* [`layers`](layers) are the fundamental building blocks for NLP models.
|
||||
They can be used to assemble new layers, networks, or models.
|
||||
|
||||
* [`networks`](networks) are combinations of layers (and possibly other networks). They are sub-units of models that would not be trained alone. They
|
||||
encapsulate common network structures like a classification head
|
||||
or a transformer encoder into an easily handled object with a
|
||||
standardized configuration.
|
||||
|
||||
* [`models`](models) are combinations of layers and networks that would be trained. Pre-built canned models are provided as both convenience functions and canonical examples.
|
||||
|
||||
* [`losses`](losses) contains common loss computation used in NLP tasks.
|
||||
|
||||
Besides the pre-defined primitives, it also provides scaffold classes to allow
|
||||
easy experimentation with noval achitectures, e.g., you don’t need to fork a whole Transformer object to try a different kind of attention primitive, for instance.
|
||||
|
||||
* [`TransformerScaffold`](layers/transformer_scaffold.py) implements the
|
||||
Transformer from ["Attention Is All You Need"]
|
||||
(https://arxiv.org/abs/1706.03762), with a customizable attention layer
|
||||
option. Users can pass a class to `attention_cls` and associated config to
|
||||
`attention_cfg`, in which case the scaffold will instantiate the class with
|
||||
the config, or pass a class instance to `attention_cls`.
|
||||
|
||||
* [`EncoderScaffold`](networks/encoder_scaffold.py) implements the transformer
|
||||
encoder from ["BERT: Pre-training of Deep Bidirectional Transformers for
|
||||
Language Understanding"](https://arxiv.org/abs/1810.04805), with customizable
|
||||
embedding subnetwork (which will replace the standard embedding logic) and/or a
|
||||
custom hidden layer (which will replace the Transformer instantiation in the
|
||||
encoder).
|
||||
|
||||
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
|
||||
+29
@@ -0,0 +1,29 @@
|
||||
# Layers
|
||||
Layers are the fundamental building blocks for NLP models. They can be used to
|
||||
assemble new layers, networks, or models.
|
||||
|
||||
* [DenseEinsum](dense_einsum.py) implements a feedforward network using tf.einsum. This layer contains the einsum op, the associated weight, and the
|
||||
logic required to generate the einsum expression for the given initialization
|
||||
parameters.
|
||||
|
||||
* [MultiHeadAttention](attention.py) implements an optionally masked attention
|
||||
between two tensors, from_tensor and to_tensor, as described in
|
||||
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
|
||||
If `from_tensor` and `to_tensor` are the same, then this is self-attention.
|
||||
|
||||
* [CachedAttention](attention.py) implements an attention layer with cache used
|
||||
for auto-agressive decoding.
|
||||
|
||||
* [Transformer](transformer.py) implements an optionally masked transformer as
|
||||
described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
|
||||
|
||||
* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding lookups designed for TPU-based models.
|
||||
|
||||
* [PositionalEmbedding](position_embedding.py) creates a positional embedding
|
||||
as described in ["BERT: Pre-training
|
||||
of Deep Bidirectional Transformers for Language Understanding"]
|
||||
(https://arxiv.org/abs/1810.04805).
|
||||
|
||||
* [SelfAttentionMask](self_attention_mask.py) creates a 3D attention mask from a 2D tensor mask.
|
||||
|
||||
* [MaskedSoftmax](masked_softmax.py) implements a softmax with an optional masking input. If no mask is provided to this layer, it performs a standard softmax; however, if a mask tensor is applied (which should be 1 in positions where the data should be allowed through, and 0 where the data should be masked), the output will have masked positions set to approximately zero.
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Layers package definition."""
|
||||
from official.nlp.modeling.layers.attention import * # pylint: disable=wildcard-import
|
||||
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
|
||||
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
|
||||
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
|
||||
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
|
||||
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
|
||||
from official.nlp.modeling.layers.transformer import Transformer
|
||||
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
|
||||
+264
@@ -0,0 +1,264 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras-based attention layer."""
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.modeling.layers import dense_einsum
|
||||
from official.nlp.modeling.layers import masked_softmax
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package="Text")
|
||||
class MultiHeadAttention(tf.keras.layers.Layer):
|
||||
"""MultiHeadAttention layer.
|
||||
|
||||
This is an implementation of multi-headed attention based on "Attention
|
||||
is all you Need". If `from_tensor` and `to_tensor` are the same, then
|
||||
this is self-attention. Each timestep in `from_tensor` attends to the
|
||||
corresponding sequence in `to_tensor`, and returns a fixed-width vector.
|
||||
|
||||
This function first projects `from_tensor` into a "query" tensor and
|
||||
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
|
||||
of tensors of length `num_attention_heads`, where each tensor is of shape
|
||||
[batch_size, seq_length, size_per_head].
|
||||
|
||||
Then, the query and key tensors are dot-producted and scaled. These are
|
||||
softmaxed to obtain attention probabilities. The value tensors are then
|
||||
interpolated by these probabilities, then concatenated back to a single
|
||||
tensor and returned.
|
||||
|
||||
Arguments:
|
||||
num_heads: Number of attention heads.
|
||||
head_size: Size of each attention head.
|
||||
dropout: Dropout probability.
|
||||
kernel_initializer: Initializer for dense layer kernels.
|
||||
bias_initializer: Initializer for dense layer biases.
|
||||
kernel_regularizer: Regularizer for dense layer kernels.
|
||||
bias_regularizer: Regularizer for dense layer biases.
|
||||
activity_regularizer: Regularizer for dense layer activity.
|
||||
kernel_constraint: Constraint for dense layer kernels.
|
||||
bias_constraint: Constraint for dense layer kernels.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
head_size,
|
||||
dropout_rate=0.0,
|
||||
kernel_initializer="glorot_uniform",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(MultiHeadAttention, self).__init__(**kwargs)
|
||||
self._num_heads = num_heads
|
||||
self._head_size = head_size
|
||||
self._dropout_rate = dropout_rate
|
||||
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
|
||||
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
|
||||
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
|
||||
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
|
||||
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
|
||||
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
|
||||
|
||||
self._query_dense = dense_einsum.DenseEinsum(
|
||||
output_shape=(self._num_heads, self._head_size),
|
||||
kernel_initializer=self._kernel_initializer,
|
||||
bias_initializer=self._bias_initializer,
|
||||
kernel_regularizer=self._kernel_regularizer,
|
||||
bias_regularizer=self._bias_regularizer,
|
||||
activity_regularizer=self._activity_regularizer,
|
||||
kernel_constraint=self._kernel_constraint,
|
||||
bias_constraint=self._bias_constraint,
|
||||
name="query")
|
||||
|
||||
self._key_dense = dense_einsum.DenseEinsum(
|
||||
output_shape=(self._num_heads, self._head_size),
|
||||
kernel_initializer=self._kernel_initializer,
|
||||
bias_initializer=self._bias_initializer,
|
||||
kernel_regularizer=self._kernel_regularizer,
|
||||
bias_regularizer=self._bias_regularizer,
|
||||
activity_regularizer=self._activity_regularizer,
|
||||
kernel_constraint=self._kernel_constraint,
|
||||
bias_constraint=self._bias_constraint,
|
||||
name="key")
|
||||
|
||||
self._value_dense = dense_einsum.DenseEinsum(
|
||||
output_shape=(self._num_heads, self._head_size),
|
||||
kernel_initializer=self._kernel_initializer,
|
||||
bias_initializer=self._bias_initializer,
|
||||
kernel_regularizer=self._kernel_regularizer,
|
||||
bias_regularizer=self._bias_regularizer,
|
||||
activity_regularizer=self._activity_regularizer,
|
||||
kernel_constraint=self._kernel_constraint,
|
||||
bias_constraint=self._bias_constraint,
|
||||
name="value")
|
||||
|
||||
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
|
||||
|
||||
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_heads":
|
||||
self._num_heads,
|
||||
"head_size":
|
||||
self._head_size,
|
||||
"dropout_rate":
|
||||
self._dropout_rate,
|
||||
"kernel_initializer":
|
||||
tf.keras.initializers.serialize(self._kernel_initializer),
|
||||
"bias_initializer":
|
||||
tf.keras.initializers.serialize(self._bias_initializer),
|
||||
"kernel_regularizer":
|
||||
tf.keras.regularizers.serialize(self._kernel_regularizer),
|
||||
"bias_regularizer":
|
||||
tf.keras.regularizers.serialize(self._bias_regularizer),
|
||||
"activity_regularizer":
|
||||
tf.keras.regularizers.serialize(self._activity_regularizer),
|
||||
"kernel_constraint":
|
||||
tf.keras.constraints.serialize(self._kernel_constraint),
|
||||
"bias_constraint":
|
||||
tf.keras.constraints.serialize(self._bias_constraint)
|
||||
}
|
||||
base_config = super(MultiHeadAttention, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs):
|
||||
from_tensor = inputs[0]
|
||||
to_tensor = inputs[1]
|
||||
attention_mask = inputs[2] if len(inputs) == 3 else None
|
||||
|
||||
# Scalar dimensions referenced here:
|
||||
# B = batch size (number of sequences)
|
||||
# F = `from_tensor` sequence length
|
||||
# T = `to_tensor` sequence length
|
||||
# N = `num_attention_heads`
|
||||
# H = `size_per_head`
|
||||
# `query_tensor` = [B, F, N ,H]
|
||||
query_tensor = self._query_dense(from_tensor)
|
||||
|
||||
# `key_tensor` = [B, T, N, H]
|
||||
key_tensor = self._key_dense(to_tensor)
|
||||
|
||||
# `value_tensor` = [B, T, N, H]
|
||||
value_tensor = self._value_dense(to_tensor)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw
|
||||
# attention scores.
|
||||
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
|
||||
attention_scores = tf.multiply(attention_scores,
|
||||
1.0 / math.sqrt(float(self._head_size)))
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
# `attention_probs` = [B, N, F, T]
|
||||
attention_probs = self._masked_softmax([attention_scores, attention_mask])
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self._dropout(attention_probs)
|
||||
|
||||
# `context_layer` = [B, F, N, H]
|
||||
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package="Text")
|
||||
class CachedAttention(MultiHeadAttention):
|
||||
"""Attention layer with cache used for auto-agressive decoding.
|
||||
|
||||
Arguments:
|
||||
num_heads: Number of attention heads.
|
||||
head_size: Size of each attention head.
|
||||
**kwargs: Other keyword arguments inherit from `Attention` class.
|
||||
"""
|
||||
|
||||
def __init__(self, num_heads, head_size, **kwargs):
|
||||
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
|
||||
|
||||
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
|
||||
"""Updates cache states and gets full-length key/value tensors."""
|
||||
# Combines cached keys and values with new keys and values.
|
||||
if decode_loop_step is not None:
|
||||
# TPU special case.
|
||||
key_seq_dim = cache["key"].shape.as_list()[1]
|
||||
indices = tf.reshape(
|
||||
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
|
||||
[1, key_seq_dim, 1, 1])
|
||||
key_tensor = cache["key"] + key_tensor * indices
|
||||
value_seq_dim = cache["value"].shape.as_list()[1]
|
||||
indices = tf.reshape(
|
||||
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
|
||||
[1, value_seq_dim, 1, 1])
|
||||
value_tensor = cache["value"] + value_tensor * indices
|
||||
else:
|
||||
key_tensor = tf.concat(
|
||||
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
|
||||
value_tensor = tf.concat(
|
||||
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
|
||||
|
||||
# Update cache
|
||||
cache["key"] = key_tensor
|
||||
cache["value"] = value_tensor
|
||||
|
||||
return key_tensor, value_tensor
|
||||
|
||||
def call(self, inputs, decode_loop_step=None):
|
||||
from_tensor = inputs[0]
|
||||
to_tensor = inputs[1]
|
||||
attention_mask = inputs[2] if len(inputs) >= 3 else None
|
||||
cache = inputs[3] if len(inputs) >= 4 else None
|
||||
# Scalar dimensions referenced here:
|
||||
# B = batch size (number of sequences)
|
||||
# F = `from_tensor` sequence length
|
||||
# T = `to_tensor` sequence length
|
||||
# N = `num_attention_heads`
|
||||
# H = `size_per_head`
|
||||
# `query_tensor` = [B, F, N ,H]
|
||||
query_tensor = self._query_dense(from_tensor)
|
||||
|
||||
# `key_tensor` = [B, T, N, H]
|
||||
key_tensor = self._key_dense(to_tensor)
|
||||
|
||||
# `value_tensor` = [B, T, N, H]
|
||||
value_tensor = self._value_dense(to_tensor)
|
||||
|
||||
if cache:
|
||||
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
|
||||
cache, decode_loop_step)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw
|
||||
# attention scores.
|
||||
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
|
||||
attention_scores = tf.multiply(attention_scores,
|
||||
1.0 / math.sqrt(float(self._head_size)))
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
# `attention_probs` = [B, N, F, T]
|
||||
attention_probs = self._masked_softmax([attention_scores, attention_mask])
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self._dropout(attention_probs)
|
||||
|
||||
# `context_layer` = [B, F, N, H]
|
||||
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
|
||||
+157
@@ -0,0 +1,157 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the attention layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from official.nlp.modeling.layers import attention
|
||||
|
||||
|
||||
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
|
||||
# guarantees forward compatibility of this code for the V2 switchover.
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class MultiHeadAttentionTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_non_masked_attention(self):
|
||||
"""Test that the attention layer can be created without a mask tensor."""
|
||||
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
from_tensor = tf.keras.Input(shape=(40, 80))
|
||||
to_tensor = tf.keras.Input(shape=(20, 80))
|
||||
output = test_layer([from_tensor, to_tensor])
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
|
||||
|
||||
def test_non_masked_self_attention(self):
|
||||
"""Test with one input (self-attenntion) and no mask tensor."""
|
||||
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
from_tensor = tf.keras.Input(shape=(40, 80))
|
||||
output = test_layer([from_tensor, from_tensor])
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
|
||||
|
||||
def test_masked_attention(self):
|
||||
"""Test with a mask tensor."""
|
||||
test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
from_tensor = tf.keras.Input(shape=(4, 8))
|
||||
to_tensor = tf.keras.Input(shape=(2, 8))
|
||||
mask_tensor = tf.keras.Input(shape=(4, 2))
|
||||
output = test_layer([from_tensor, to_tensor, mask_tensor])
|
||||
|
||||
# Create a model containing the test layer.
|
||||
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = 10 * np.random.random_sample((3, 4, 8))
|
||||
to_data = 10 * np.random.random_sample((3, 2, 8))
|
||||
|
||||
# Invoke the data with a random set of mask data. This should mask at least
|
||||
# one element.
|
||||
mask_data = np.random.randint(2, size=(3, 4, 2))
|
||||
masked_output_data = model.predict([from_data, to_data, mask_data])
|
||||
|
||||
# Invoke the same data, but with a null mask (where no elements are masked).
|
||||
null_mask_data = np.ones((3, 4, 2))
|
||||
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
|
||||
|
||||
# Because one data is masked and one is not, the outputs should not be the
|
||||
# same.
|
||||
self.assertNotAllClose(masked_output_data, unmasked_output_data)
|
||||
|
||||
def test_initializer(self):
|
||||
"""Test with a specified initializer."""
|
||||
test_layer = attention.MultiHeadAttention(
|
||||
num_heads=12,
|
||||
head_size=64,
|
||||
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
from_tensor = tf.keras.Input(shape=(40, 80))
|
||||
output = test_layer([from_tensor, from_tensor])
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
|
||||
|
||||
|
||||
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
|
||||
return {
|
||||
"key":
|
||||
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
|
||||
dtype=tf.float32),
|
||||
"value":
|
||||
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
|
||||
dtype=tf.float32)
|
||||
}
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CachedAttentionTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_masked_attention(self):
|
||||
"""Test with a mask tensor."""
|
||||
num_heads, head_size = 2, 2
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
from_seq_length = 4
|
||||
batch_size = 3
|
||||
# GPU/CPU case.
|
||||
init_decode_length = 0
|
||||
# Directly tests the keras layer.
|
||||
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
|
||||
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
|
||||
# Invoke the data with a random set of mask data. This should mask at least
|
||||
# one element.
|
||||
mask_data = np.random.randint(
|
||||
2, size=(batch_size, from_seq_length, from_seq_length))
|
||||
masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
|
||||
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
|
||||
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
|
||||
|
||||
# Tests inputs without cache.
|
||||
masked_output_data, cache = layer([from_data, from_data, mask_data])
|
||||
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
|
||||
self.assertIsNone(cache)
|
||||
|
||||
def test_padded_decode(self):
|
||||
"""Test with a mask tensor."""
|
||||
num_heads, head_size = 2, 2
|
||||
from_seq_length = 4
|
||||
# TPU decoding should pre-allocate the entire sequence.
|
||||
batch_size = 3
|
||||
init_decode_length = from_seq_length
|
||||
|
||||
# Directly tests the keras layer.
|
||||
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
|
||||
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
|
||||
decode_loop_step = 2
|
||||
mask_data = np.random.randint(
|
||||
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
|
||||
# Testing the invocation directly as Keras cannot consume inputs correctly.
|
||||
masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
|
||||
decode_loop_step=decode_loop_step)
|
||||
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
|
||||
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+180
@@ -0,0 +1,180 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras-based einsum layer."""
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package="Text")
|
||||
class DenseEinsum(tf.keras.layers.Layer):
|
||||
"""A densely connected layer that uses tf.einsum as the backing computation.
|
||||
|
||||
This layer can perform einsum calculations of arbitrary dimensionality.
|
||||
|
||||
Arguments:
|
||||
output_shape: Positive integer or tuple, dimensionality of the output space.
|
||||
num_summed_dimensions: The number of dimensions to sum over. Standard 2D
|
||||
matmul should use 1, 3D matmul should use 2, and so forth.
|
||||
activation: Activation function to use. If you don't specify anything, no
|
||||
activation is applied
|
||||
(ie. "linear" activation: `a(x) = x`).
|
||||
use_bias: Boolean, whether the layer uses a bias vector.
|
||||
kernel_initializer: Initializer for the `kernel` weights matrix.
|
||||
bias_initializer: Initializer for the bias vector.
|
||||
kernel_regularizer: Regularizer function applied to the `kernel` weights
|
||||
matrix.
|
||||
bias_regularizer: Regularizer function applied to the bias vector.
|
||||
activity_regularizer: Regularizer function applied to the output of the
|
||||
layer (its "activation")..
|
||||
kernel_constraint: Constraint function applied to the `kernel` weights
|
||||
matrix.
|
||||
bias_constraint: Constraint function applied to the bias vector.
|
||||
Input shape:
|
||||
N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
|
||||
situation would be a 2D input with shape `(batch_size, input_dim)`.
|
||||
Output shape:
|
||||
N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
|
||||
input with shape `(batch_size, input_dim)`, the output would have shape
|
||||
`(batch_size, units)`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
output_shape,
|
||||
num_summed_dimensions=1,
|
||||
activation=None,
|
||||
use_bias=True,
|
||||
kernel_initializer="glorot_uniform",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(DenseEinsum, self).__init__(**kwargs)
|
||||
self._output_shape = output_shape if isinstance(
|
||||
output_shape, (list, tuple)) else (output_shape,)
|
||||
self._activation = tf.keras.activations.get(activation)
|
||||
self._use_bias = use_bias
|
||||
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
|
||||
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
|
||||
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
|
||||
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
|
||||
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
|
||||
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
|
||||
self._num_summed_dimensions = num_summed_dimensions
|
||||
self._einsum_string = None
|
||||
|
||||
def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
|
||||
input_str = ""
|
||||
kernel_str = ""
|
||||
output_str = ""
|
||||
letter_offset = 0
|
||||
for i in range(free_input_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
input_str += char
|
||||
output_str += char
|
||||
|
||||
letter_offset += free_input_dims
|
||||
for i in range(bound_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
input_str += char
|
||||
kernel_str += char
|
||||
|
||||
letter_offset += bound_dims
|
||||
for i in range(output_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
kernel_str += char
|
||||
output_str += char
|
||||
|
||||
return input_str + "," + kernel_str + "->" + output_str
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tf.TensorShape(input_shape)
|
||||
input_rank = input_shape.rank
|
||||
free_input_dims = input_rank - self._num_summed_dimensions
|
||||
output_dims = len(self._output_shape)
|
||||
|
||||
self._einsum_string = self._build_einsum_string(free_input_dims,
|
||||
self._num_summed_dimensions,
|
||||
output_dims)
|
||||
|
||||
# This is only saved for testing purposes.
|
||||
self._kernel_shape = (
|
||||
input_shape[free_input_dims:].concatenate(self._output_shape))
|
||||
|
||||
self._kernel = self.add_weight(
|
||||
"kernel",
|
||||
shape=self._kernel_shape,
|
||||
initializer=self._kernel_initializer,
|
||||
regularizer=self._kernel_regularizer,
|
||||
constraint=self._kernel_constraint,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
if self._use_bias:
|
||||
self._bias = self.add_weight(
|
||||
"bias",
|
||||
shape=self._output_shape,
|
||||
initializer=self._bias_initializer,
|
||||
regularizer=self._bias_regularizer,
|
||||
constraint=self._bias_constraint,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
else:
|
||||
self._bias = None
|
||||
super(DenseEinsum, self).build(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"output_shape":
|
||||
self._output_shape,
|
||||
"num_summed_dimensions":
|
||||
self._num_summed_dimensions,
|
||||
"activation":
|
||||
tf.keras.activations.serialize(self._activation),
|
||||
"use_bias":
|
||||
self._use_bias,
|
||||
"kernel_initializer":
|
||||
tf.keras.initializers.serialize(self._kernel_initializer),
|
||||
"bias_initializer":
|
||||
tf.keras.initializers.serialize(self._bias_initializer),
|
||||
"kernel_regularizer":
|
||||
tf.keras.regularizers.serialize(self._kernel_regularizer),
|
||||
"bias_regularizer":
|
||||
tf.keras.regularizers.serialize(self._bias_regularizer),
|
||||
"activity_regularizer":
|
||||
tf.keras.regularizers.serialize(self._activity_regularizer),
|
||||
"kernel_constraint":
|
||||
tf.keras.constraints.serialize(self._kernel_constraint),
|
||||
"bias_constraint":
|
||||
tf.keras.constraints.serialize(self._bias_constraint)
|
||||
}
|
||||
base_config = super(DenseEinsum, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs):
|
||||
ret = tf.einsum(self._einsum_string, inputs, self._kernel)
|
||||
if self._use_bias:
|
||||
ret += self._bias
|
||||
if self._activation is not None:
|
||||
ret = self._activation(ret)
|
||||
return ret
|
||||
+123
@@ -0,0 +1,123 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Keras-based einsum layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from official.nlp.modeling.layers import dense_einsum
|
||||
|
||||
|
||||
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
|
||||
# guarantees forward compatibility of this code for the V2 switchover.
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class DenseEinsumLayer(keras_parameterized.TestCase):
|
||||
|
||||
def test_3D_einsum_with_two_bound_dimensions(self):
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=(64,), num_summed_dimensions=2)
|
||||
# Create a 4-dimensional input (the first dimension is implicit).
|
||||
input_tensor = tf.keras.Input(shape=(None, 40, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
|
||||
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
|
||||
|
||||
def test_3D_einsum_with_one_bound_dimensions(self):
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=(64, 32), num_summed_dimensions=1)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(test_layer._einsum_string, "abc,cde->abde")
|
||||
self.assertEqual(test_layer._kernel_shape, (80, 64, 32))
|
||||
|
||||
def test_2D_einsum_with_one_bound_dimensions(self):
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=(64,), num_summed_dimensions=1)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
|
||||
self.assertEqual(test_layer._kernel_shape, (80, 64))
|
||||
|
||||
def test_bias_term_can_be_disabled(self):
|
||||
# A layer created using the bias should have two weights.
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=64, num_summed_dimensions=1, use_bias=True)
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(2, len(test_layer.get_weights()))
|
||||
|
||||
# A layer created without the bias should have only one weight.
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=64, num_summed_dimensions=1, use_bias=False)
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(1, len(test_layer.get_weights()))
|
||||
|
||||
def test_activation(self):
|
||||
# Create a model that does not use an activation.
|
||||
no_activation_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=64, num_summed_dimensions=1, activation=None)
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
output_tensor = no_activation_layer(input_tensor)
|
||||
no_activation_model = tf.keras.Model(input_tensor, output_tensor)
|
||||
|
||||
# Create a model that uses a softmax activation.
|
||||
activation_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=64, num_summed_dimensions=1, activation="softmax")
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
output_tensor = activation_layer(input_tensor)
|
||||
activation_model = tf.keras.Model(input_tensor, output_tensor)
|
||||
|
||||
# Make sure the models' weights are identical.
|
||||
activation_model.set_weights(no_activation_model.get_weights())
|
||||
|
||||
# Predict using each model on the same input data. The output should be
|
||||
# different, since one is using a softmax - even though the models' weights
|
||||
# are the same.
|
||||
input_values = 10 * np.random.random_sample((10, 4, 80))
|
||||
non_activated_data = no_activation_model.predict(input_values)
|
||||
activated_data = activation_model.predict(input_values)
|
||||
self.assertNotAllClose(activated_data, non_activated_data)
|
||||
|
||||
def test_non_iterable_output_shape(self):
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=64, num_summed_dimensions=1)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
input_tensor = tf.keras.Input(shape=(None, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
|
||||
self.assertEqual(test_layer._kernel_shape, (80, 64))
|
||||
|
||||
def test_with_explicit_initializer(self):
|
||||
test_layer = dense_einsum.DenseEinsum(
|
||||
output_shape=(64,),
|
||||
num_summed_dimensions=2,
|
||||
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
|
||||
# Create a 4-dimensional input (the first dimension is implicit).
|
||||
input_tensor = tf.keras.Input(shape=(None, 40, 80))
|
||||
_ = test_layer(input_tensor)
|
||||
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
|
||||
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras-based softmax layer with optional masking."""
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
# from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Text')
|
||||
class MaskedSoftmax(tf.keras.layers.Layer):
|
||||
"""Performs a softmax with optional masking on a tensor.
|
||||
|
||||
Arguments:
|
||||
mask_expansion_axes: Any axes that should be padded on the mask tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, mask_expansion_axes=None, **kwargs):
|
||||
self._mask_expansion_axes = mask_expansion_axes
|
||||
super(MaskedSoftmax, self).__init__(**kwargs)
|
||||
|
||||
def call(self, inputs):
|
||||
if isinstance(inputs, list) and len(inputs) == 2:
|
||||
scores, mask = inputs
|
||||
else:
|
||||
scores, mask = (inputs, None)
|
||||
|
||||
if mask is not None:
|
||||
if self._mask_expansion_axes is not None:
|
||||
mask = tf.expand_dims(mask, axis=self._mask_expansion_axes)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
adder = (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
|
||||
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
scores += adder
|
||||
|
||||
return tf.nn.softmax(scores)
|
||||
|
||||
def get_config(self):
|
||||
config = {'mask_expansion_axes': self._mask_expansion_axes}
|
||||
base_config = super(MaskedSoftmax, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
+88
@@ -0,0 +1,88 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Keras-based masked softmax layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from official.nlp.modeling.layers import masked_softmax
|
||||
|
||||
|
||||
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
|
||||
# guarantees forward compatibility of this code for the V2 switchover.
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_non_masked_softmax(self):
|
||||
test_layer = masked_softmax.MaskedSoftmax()
|
||||
input_tensor = tf.keras.Input(shape=(4, 8))
|
||||
output = test_layer(input_tensor)
|
||||
model = tf.keras.Model(input_tensor, output)
|
||||
|
||||
input_data = 10 * np.random.random_sample((3, 4, 8))
|
||||
output_data = model.predict(input_data)
|
||||
expected_data = tf.nn.softmax(input_data)
|
||||
self.assertAllClose(expected_data, output_data)
|
||||
|
||||
def test_masked_softmax(self):
|
||||
test_layer = masked_softmax.MaskedSoftmax()
|
||||
input_tensor = tf.keras.Input(shape=(4, 8))
|
||||
mask_tensor = tf.keras.Input(shape=(4, 8))
|
||||
output = test_layer([input_tensor, mask_tensor])
|
||||
model = tf.keras.Model([input_tensor, mask_tensor], output)
|
||||
|
||||
input_data = 10 * np.random.random_sample((3, 4, 8))
|
||||
mask_data = np.random.randint(2, size=(3, 4, 8))
|
||||
|
||||
output_data = model.predict([input_data, mask_data])
|
||||
expected_zeros = np.greater(mask_data, 0)
|
||||
is_zeros = np.greater(output_data, 0)
|
||||
self.assertAllEqual(expected_zeros, is_zeros)
|
||||
|
||||
def test_masked_softmax_with_none_mask(self):
|
||||
test_layer = masked_softmax.MaskedSoftmax()
|
||||
input_tensor = tf.keras.Input(shape=(4, 8))
|
||||
output = test_layer([input_tensor, None])
|
||||
model = tf.keras.Model(input_tensor, output)
|
||||
|
||||
input_data = 10 * np.random.random_sample((3, 4, 8))
|
||||
output_data = model.predict(input_data)
|
||||
expected_data = tf.nn.softmax(input_data)
|
||||
self.assertAllClose(expected_data, output_data)
|
||||
|
||||
def test_softmax_with_axes_expansion(self):
|
||||
test_layer = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
|
||||
input_tensor = tf.keras.Input(shape=(4, 8))
|
||||
mask_tensor = tf.keras.Input(shape=(8))
|
||||
output = test_layer([input_tensor, mask_tensor])
|
||||
model = tf.keras.Model([input_tensor, mask_tensor], output)
|
||||
|
||||
input_data = 10 * np.random.random_sample((3, 4, 8))
|
||||
mask_data = np.random.randint(2, size=(3, 8))
|
||||
|
||||
output_data = model.predict([input_data, mask_data])
|
||||
expanded_mask = np.expand_dims(mask_data, axis=1) * np.ones_like(input_data)
|
||||
expected_zeros = np.greater(expanded_mask, 0)
|
||||
is_zeros = np.greater(output_data, 0)
|
||||
self.assertAllEqual(expected_zeros, is_zeros)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user