[add]上传训练benchmark by z00560161

This commit is contained in:
liang_chaoming@huawei.com
2020-10-19 20:22:23 +08:00
parent 22b83024f5
commit 82522e2f61
1225 changed files with 345421 additions and 0 deletions
@@ -0,0 +1,37 @@
# ResNet101_tensorflow训练说明
### 1. 模型训练参数配置
在train/yaml/ResNet101.yaml中修改相应配置, 配置项含义:
```
tensorflow_config:
# 基本参数
data_url: /home/imagenet_TF/
# 1p/8p,epoches设为150
epoches: 1
epochs_between_evals: 1
max_train_steps: 1000
batch_size: 128
# 仅多机执行需要配置: ip1:卡数量1,ip2:卡数量2
mpirun_ip: 90.90.176.152:8,90.90.176.154:8
# docker 镜像名称:版本号
docker_image: c73:b02
# 指定 device id, 多个 id 使用空格分隔, 数量需与 rank_size 相同
device_group_1p: 0
device_group_2p: 0 1
device_group_4p: 0 1 2 3
```
------
@@ -0,0 +1,203 @@
Copyright 2015 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2015, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@@ -0,0 +1,151 @@
![Logo](https://storage.googleapis.com/model_garden_artifacts/TF_Model_Garden.png)
# TensorFlow Official Models
The TensorFlow official models are a collection of models
that use TensorFlows high-level APIs.
They are intended to be well-maintained, tested, and kept up to date
with the latest TensorFlow API.
They should also be reasonably optimized for fast performance while still
being easy to read.
These models are used as end-to-end tests, ensuring that the models run
with the same or improved speed and performance with each new TensorFlow build.
## Model Implementations
### Natural Language Processing
| Model | Description | Reference |
| ----- | ----------- | --------- |
| [ALBERT](nlp/albert) | A Lite BERT for Self-supervised Learning of Language Representations | [arXiv:1909.11942](https://arxiv.org/abs/1909.11942) |
| [BERT](nlp/bert) | A powerful pre-trained language representation model: BERT (Bidirectional Encoder Representations from Transformers) | [arXiv:1810.04805](https://arxiv.org/abs/1810.04805) |
| [NHNet](nlp/nhnet) | A transformer-based multi-sequence to sequence model: Generating Representative Headlines for News Stories | [arXiv:2001.09386](https://arxiv.org/abs/2001.09386) |
| [Transformer](nlp/transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
| [XLNet](nlp/xlnet) | XLNet: Generalized Autoregressive Pretraining for Language Understanding | [arXiv:1906.08237](https://arxiv.org/abs/1906.08237) |
### Computer Vision
| Model | Description | Reference |
| ----- | ----------- | --------- |
| [MNIST](vision/image_classification) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
| [ResNet](vision/image_classification) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
| [RetinaNet](vision/detection) | A fast and powerful object detector | [arXiv:1708.02002](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | An object detection and instance segmentation model | [arXiv:1703.06870](https://arxiv.org/abs/1703.06870) |
### Other models
| Model | Description | Reference |
| ----- | ----------- | --------- |
| [NCF](recommendation) | Neural Collaborative Filtering model for recommendation tasks | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
---
## How to get started with the Model Garden official models
* The models in the master branch are developed using TensorFlow 2,
and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
built from the
[master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
* The stable versions targeting releases of TensorFlow are available
as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases).
* Model repository version numbers match the target TensorFlow release,
such that
[release v2.1.0](https://github.com/tensorflow/models/releases/tag/v2.1.0)
are compatible with
[TensorFlow v2.1.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0).
Please follow the below steps before running models in this repository.
### Requirements
* The latest TensorFlow Model Garden release and TensorFlow 2
* If you are on a version of TensorFlow earlier than 2.1, please
upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
```shell
pip3 install tf-nightly
```
### Installation
#### Method 1: Install the TensorFlow Model Garden pip package
**tf-models-nightly** is the nightly Model Garden package
created daily automatically. pip will install all models
and dependencies automatically.
```shell
pip install tf-models-nightly
```
Please check out our [example](colab/bert.ipynb)
to learn how to use a PIP package.
#### Method 2: Clone the source
1. Clone the GitHub repository:
```shell
git clone https://github.com/tensorflow/models.git
```
2. Add the top-level ***/models*** folder to the Python path.
```shell
export PYTHONPATH=$PYTHONPATH:/path/to/models
```
If you are using a Colab notebook, please set the Python path with os.environ.
```python
import os
os.environ['PYTHONPATH'] += ":/path/to/models"
```
3. Install other dependencies
```shell
pip3 install --user -r official/requirements.txt
```
---
## More models to come!
The team is actively developing new models.
In the near future, we will add:
- State-of-the-art language understanding models:
More members in Transformer family
- Start-of-the-art image classification models:
EfficientNet, MnasNet and variants.
- A set of excellent objection detection models.
If you would like to make any fixes or improvements to the models, please
[submit a pull request](https://github.com/tensorflow/models/compare).
---
## Contributions
Every model should follow our guidelines to uphold our objectives of readable,
usable, and maintainable code.
### General Guidelines
- Code should be well documented and tested.
- Runnable from a blank environment with ease.
- Trainable on: single GPU/CPU (baseline), multiple GPUs & TPUs
- Compatible with Python 3 (using [six](https://pythonhosted.org/six/)
when being compatible with Python 2 is necessary)
- Conform to
[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)
### Implementation Guidelines
These guidelines are to ensure consistent model implementations for
better readability and maintainability.
- Use [common utility functions](utils)
- Export SavedModel at the end of the training.
- Consistent flags and flag-parsing library ([read more here](utils/flags/guidelines.md))
@@ -0,0 +1,25 @@
# Offically Supported TensorFlow 2.1+ Models on Cloud TPU
## Natural Language Processing
* [bert](nlp/bert): A powerful pre-trained language representation model:
BERT, which stands for Bidirectional Encoder Representations from
Transformers.
[BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
* [transformer](nlp/transformer): A transformer model to translate the WMT
English to German dataset.
[Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
## Computer Vision
* [efficientnet](vision/image_classification): A family of convolutional
neural networks that scale by balancing network depth, width, and
resolution and can be used to classify ImageNet's dataset of 1000 classes.
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
* [mnist](vision/image_classification): A basic model to classify digits
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
* [resnet](vision/image_classification): A deep residual network that can
be used to classify ImageNet's dataset of 1000 classes.
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
@@ -0,0 +1,79 @@
# ResNet in TensorFlow On NPU
---
# Classification Model
## Overview
1. This is an implementation of the ResNet101 model as described in the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) paper.
2. Current implementation is based on the code from the TensorFlow Official implementation in the [tensorflow/models Repo](https://github.com/tensorflow/models).
## Introduction
1. ResNet is a relatively good network for classification problems in the ImageNet competition. It introduces the concept of residual learning. It protects the integrity of information by adding direct connection channels, solves problems such as information loss, gradient disappearance, and gradient explosion. The network can also be trained. ResNet has different network layers, commonly used are 18-layer, 34-layer, 50-layer, 101-layer, 152-layer.
2. Ascend provides the V1.5 version of the 50-layer ResNet network this time. The difference between the V1.5 version of the ResNet network and the V1 version is that in the bottleneck module, the V1 version is set stride=2 in the first 1x1 convolutional layer, and V1.5 sets stride=2 in the 3x3 convolutional layer.
## Dataset
We have used the [ImageNet](http://www.image-net.org/)dataset as an example here, you can use mnist or your own dataset to modify and adapt.
We use [build_imagenet_data](https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/research/slim/datasets/build_imagenet_data.py) to build record for training.
## Running Code
### Config the env paramater
check if path '/usr/local/HiAI' or ''/usr/local/Ascend' is existed or not.
modify '/usr/local/HiAI' to the actual path in scripts/run.sh
### Train and evaluate model
[imagenet_main.py](official/r1/resnet/imagenet_main.py) is the Entry Python script.
[resnet_run_loop.py](official/r1/resnet/resnet_run_loop.py) is the Main Python script.
### Check your rank_table
default rank_table setting in [configs](official/r1/resnet/configs) is usrd for X86.
if you use aach64, please modify board_id from "0x0000" -->
To train and evaluate the model, issue the following command:
```
# for single training
bash ./scripts/train_1p.sh
# for multi training
bash ./scripts/train_8p.sh
```
Default Args:
- Batch size: 128
- Momentum: 0.9
- LR scheduler: cosine
- Learning rate(LR): 0.064
- loss scale: 512
- Weight decay: 0.0001
- Label smoothing: 0.1
- train epoch: 90
There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
### Train and evaluate result
- 1 NPU
- Train performance109ms/step1170images/sec.
- 8 NPU
- Train performance109ms/step9390images/sec.
- best result
- Accuracy(Top1): 79.03
- Accuracy(Top5): 94.53
### More
#### modify file
- The npu modify file list as follows:
- DaVinci npu platform adaptation code,including
1.official/r1/resnet/imagenet_main.py
2.official/r1/resnet/resnet_model.py
3.official/r1/resnet/resnet_run_loop.py
4.official/utils/flags/_base.py
#### FileTree Intro
- Main Dir
- ./official/r1/resnet
- Single NPU Training Shell
- npu_train_1p_test.sh
- Multi NPU(8p) Training Shell
- npu_train_8p_test.sh
- Log Info
- STDOUT nohup.out
- Performance perf.log
@@ -0,0 +1,157 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to upload benchmark generated by BenchmarkLogger to remote repo.
This library require google cloud bigquery lib as dependency, which can be
installed with:
> pip install --upgrade google-cloud-bigquery
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from google.cloud import bigquery
from google.cloud import exceptions
import tensorflow as tf
class BigQueryUploader(object):
"""Upload the benchmark and metric info from JSON input to BigQuery. """
def __init__(self, gcp_project=None, credentials=None):
"""Initialized BigQueryUploader with proper setting.
Args:
gcp_project: string, the name of the GCP project that the log will be
uploaded to. The default project name will be detected from local
environment if no value is provided.
credentials: google.auth.credentials. The credential to access the
BigQuery service. The default service account credential will be
detected from local environment if no value is provided. Please use
google.oauth2.service_account.Credentials to load credential from local
file for the case that the test is run out side of GCP.
"""
self._bq_client = bigquery.Client(
project=gcp_project, credentials=credentials)
def upload_benchmark_run_json(
self, dataset_name, table_name, run_id, run_json):
"""Upload benchmark run information to Bigquery.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the data will be uploaded.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format.
run_json: dict, the JSON data that contains the benchmark run info.
"""
run_json["model_id"] = run_id
self._upload_json(dataset_name, table_name, [run_json])
def upload_benchmark_metric_json(
self, dataset_name, table_name, run_id, metric_json_list):
"""Upload metric information to Bigquery.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the metric data will be uploaded. This is different from the
benchmark_run table.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format. This should be the same as the benchmark run_id.
metric_json_list: list, a list of JSON object that record the metric info.
"""
for m in metric_json_list:
m["run_id"] = run_id
self._upload_json(dataset_name, table_name, metric_json_list)
def upload_benchmark_run_file(
self, dataset_name, table_name, run_id, run_json_file):
"""Upload benchmark run information to Bigquery from input json file.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the data will be uploaded.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format.
run_json_file: string, the file path that contains the run JSON data.
"""
with tf.io.gfile.GFile(run_json_file) as f:
benchmark_json = json.load(f)
self.upload_benchmark_run_json(
dataset_name, table_name, run_id, benchmark_json)
def upload_metric_file(
self, dataset_name, table_name, run_id, metric_json_file):
"""Upload metric information to Bigquery from input json file.
Args:
dataset_name: string, the name of bigquery dataset where the data will be
uploaded.
table_name: string, the name of bigquery table under the dataset where
the metric data will be uploaded. This is different from the
benchmark_run table.
run_id: string, a unique ID that will be attached to the data, usually
this is a UUID4 format. This should be the same as the benchmark run_id.
metric_json_file: string, the file path that contains the metric JSON
data.
"""
with tf.io.gfile.GFile(metric_json_file) as f:
metrics = []
for line in f:
metrics.append(json.loads(line.strip()))
self.upload_benchmark_metric_json(
dataset_name, table_name, run_id, metrics)
def _upload_json(self, dataset_name, table_name, json_list):
# Find the unique table reference based on dataset and table name, so that
# the data can be inserted to it.
table_ref = self._bq_client.dataset(dataset_name).table(table_name)
errors = self._bq_client.insert_rows_json(table_ref, json_list)
if errors:
tf.logging.error(
"Failed to upload benchmark info to bigquery: {}".format(errors))
def insert_run_status(self, dataset_name, table_name, run_id, run_status):
"""Insert the run status in to Bigquery run status table."""
query = ("INSERT {ds}.{tb} "
"(run_id, status) "
"VALUES('{rid}', '{status}')").format(
ds=dataset_name, tb=table_name, rid=run_id, status=run_status)
try:
self._bq_client.query(query=query).result()
except exceptions.GoogleCloudError as e:
tf.logging.error("Failed to insert run status: %s", e)
def update_run_status(self, dataset_name, table_name, run_id, run_status):
"""Update the run status in in Bigquery run status table."""
query = ("UPDATE {ds}.{tb} "
"SET status = '{status}' "
"WHERE run_id = '{rid}'").format(
ds=dataset_name, tb=table_name, status=run_status, rid=run_id)
try:
self._bq_client.query(query=query).result()
except exceptions.GoogleCloudError as e:
tf.logging.error("Failed to update run status: %s", e)
@@ -0,0 +1,66 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Binary to upload benchmark generated by BenchmarkLogger to remote repo.
This library require google cloud bigquery lib as dependency, which can be
installed with:
> pip install --upgrade google-cloud-bigquery
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import uuid
from absl import app as absl_app
from absl import flags
from official.benchmark import benchmark_uploader
from official.utils.flags import core as flags_core
from official.utils.logs import logger
def main(_):
if not flags.FLAGS.benchmark_log_dir:
print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
sys.exit(1)
uploader = benchmark_uploader.BigQueryUploader(
gcp_project=flags.FLAGS.gcp_project)
run_id = str(uuid.uuid4())
run_json_file = os.path.join(
flags.FLAGS.benchmark_log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
metric_json_file = os.path.join(
flags.FLAGS.benchmark_log_dir, logger.METRIC_LOG_FILE_NAME)
uploader.upload_benchmark_run_file(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id,
run_json_file)
uploader.upload_metric_file(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
metric_json_file)
# Assume the run finished successfully before user invoke the upload script.
uploader.insert_run_status(
flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_status_table,
run_id, logger.RUN_STATUS_SUCCESS)
if __name__ == "__main__":
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
absl_app.run(main=main)
@@ -0,0 +1,123 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for benchmark_uploader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import unittest
from mock import MagicMock
from mock import patch
import tensorflow as tf # pylint: disable=g-bad-import-order
try:
from google.cloud import bigquery
from official.benchmark import benchmark_uploader
except ImportError:
bigquery = None
benchmark_uploader = None
@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BigQueryUploaderTest(tf.test.TestCase):
@patch.object(bigquery, "Client")
def setUp(self, mock_bigquery):
self.mock_client = mock_bigquery.return_value
self.mock_dataset = MagicMock(name="dataset")
self.mock_table = MagicMock(name="table")
self.mock_client.dataset.return_value = self.mock_dataset
self.mock_dataset.table.return_value = self.mock_table
self.mock_client.insert_rows_json.return_value = []
self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
self.benchmark_uploader._bq_client = self.mock_client
self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
json.dump({"name": "accuracy", "value": 1.0}, f)
f.write("\n")
json.dump({"name": "loss", "value": 0.5}, f)
f.write("\n")
with open(os.path.join(self.log_dir, "run.log"), "w") as f:
json.dump({"model_name": "value"}, f)
def tearDown(self):
tf.io.gfile.rmtree(self.get_temp_dir())
def test_upload_benchmark_run_json(self):
self.benchmark_uploader.upload_benchmark_run_json(
"dataset", "table", "run_id", {"model_name": "value"})
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
def test_upload_benchmark_metric_json(self):
metric_json_list = [
{"name": "accuracy", "value": 1.0},
{"name": "loss", "value": 0.5}
]
expected_params = [
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
{"run_id": "run_id", "name": "loss", "value": 0.5}
]
self.benchmark_uploader.upload_benchmark_metric_json(
"dataset", "table", "run_id", metric_json_list)
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
def test_upload_benchmark_run_file(self):
self.benchmark_uploader.upload_benchmark_run_file(
"dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, [{"model_name": "value", "model_id": "run_id"}])
def test_upload_metric_file(self):
self.benchmark_uploader.upload_metric_file(
"dataset", "table", "run_id",
os.path.join(self.log_dir, "metric.log"))
expected_params = [
{"run_id": "run_id", "name": "accuracy", "value": 1.0},
{"run_id": "run_id", "name": "loss", "value": 0.5}
]
self.mock_client.insert_rows_json.assert_called_once_with(
self.mock_table, expected_params)
def test_insert_run_status(self):
self.benchmark_uploader.insert_run_status(
"dataset", "table", "run_id", "status")
expected_query = ("INSERT dataset.table "
"(run_id, status) "
"VALUES('run_id', 'status')")
self.mock_client.query.assert_called_once_with(query=expected_query)
def test_update_run_status(self):
self.benchmark_uploader.update_run_status(
"dataset", "table", "run_id", "status")
expected_query = ("UPDATE dataset.table "
"SET status = 'status' "
"WHERE run_id = 'run_id'")
self.mock_client.query.assert_called_once_with(query=expected_query)
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,97 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to annotate and trace benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl import logging
from absl.testing import flagsaver
FLAGS = flags.FLAGS
flags.DEFINE_multi_string(
'benchmark_method_flags', None,
'Optional list of runtime flags of the form key=value. Specify '
'multiple times to specify different flags. These will override the FLAGS '
'object directly after hardcoded settings in individual benchmark methods '
'before they call _run_and_report benchmark. Example if we set '
'--benchmark_method_flags=train_steps=10 and a benchmark method hardcodes '
'FLAGS.train_steps=10000 and later calls _run_and_report_benchmark, '
'it\'ll only run for 10 steps. This is useful for '
'debugging/profiling workflows.')
def enable_runtime_flags(decorated_func):
"""Sets attributes from --benchmark_method_flags for method execution.
@enable_runtime_flags decorator temporarily adds flags passed in via
--benchmark_method_flags and runs the decorated function in that context.
A user can set --benchmark_method_flags=train_steps=5 to run the benchmark
method in the snippet below with FLAGS.train_steps=5 for debugging (without
modifying the benchmark code).
class ModelBenchmark():
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
# run benchmark ...
# report benchmark results ...
def benchmark_method(self):
FLAGS.train_steps = 1000
...
self._run_and_report_benchmark()
Args:
decorated_func: The method that runs the benchmark after previous setup
execution that set some flags.
Returns:
new_func: The same method which executes in a temporary context where flag
overrides from --benchmark_method_flags are active.
"""
def runner(*args, **kwargs):
"""Creates a temporary context to activate --benchmark_method_flags."""
if FLAGS.benchmark_method_flags:
saved_flag_values = flagsaver.save_flag_values()
for key_value in FLAGS.benchmark_method_flags:
key, value = key_value.split('=', 1)
try:
numeric_float = float(value)
numeric_int = int(numeric_float)
if abs(numeric_int) == abs(numeric_float):
flag_value = numeric_int
else:
flag_value = numeric_float
except ValueError:
flag_value = value
logging.info('Setting --%s=%s', key, flag_value)
setattr(FLAGS, key, flag_value)
else:
saved_flag_values = None
try:
result = decorated_func(*args, **kwargs)
return result
finally:
if saved_flag_values:
flagsaver.restore_flag_values(saved_flag_values)
return runner
@@ -0,0 +1,354 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.bert import configs
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None):
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None
self.num_steps_per_epoch = None
self.tpu = tpu
FLAGS.steps_per_loop = 50
@flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None, use_ds=True):
"""Starts BERT classification task."""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
if self.num_steps_per_epoch:
steps_per_epoch = self.num_steps_per_epoch
else:
train_data_size = input_meta_data['train_data_size']
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * steps_per_epoch * 0.1)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if self.tpu:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
else:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = run_classifier.get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
eval_input_fn = run_classifier.get_dataset_fn(
FLAGS.eval_data_path,
max_seq_length,
FLAGS.eval_batch_size,
is_training=False)
run_classifier.run_bert_classifier(
strategy,
bert_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
FLAGS.steps_per_loop,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
train_input_fn,
eval_input_fn,
custom_callbacks=callbacks)
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
"""Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU, TPU configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
`benchmark_(topology)_tpu_(dataset type)` for TPUs.
"""
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertClassifyBenchmarkReal, self).__init__(
output_dir=output_dir, tpu=tpu)
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
self.bert_config_file = MODEL_CONFIG_FILE_PATH
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
# Since we only care about performance metrics, we limit
# the number of training steps and epochs to prevent unnecessarily
# long tests.
self.num_steps_per_epoch = 100
self.num_epochs = 1
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0,
max_accuracy=1,
use_ds=True):
"""Starts BERT performance benchmark test."""
start_time_sec = time.time()
self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds)
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
# Since we do not load from any pretrained checkpoints, we ignore all
# accuracy metrics.
summary.pop('eval_metrics', None)
summary['start_time_sec'] = start_time_sec
super(BertClassifyBenchmarkReal, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def benchmark_1_gpu_mrpc(self):
"""Test BERT model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_xla(self):
"""Test BERT model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_mrpc_no_dist_strat(self):
"""Test BERT model performance with 1 GPU, no distribution strategy."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_8_gpu_mrpc(self):
"""Test BERT model performance with 8 GPUs."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
"""Performance for 1 GPU no DS with automatic mixed precision."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_amp_mrpc_no_dist_strat')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 4
FLAGS.eval_batch_size = 4
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_8_gpu_amp_mrpc(self):
"""Test BERT model performance with 8 GPUs with automatic mixed precision.
"""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 32
FLAGS.eval_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2x2_tpu_mrpc(self):
"""Test BERT model performance with 2x2 TPU."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 32
FLAGS.eval_batch_size = 32
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Short accuracy test for BERT model.
Tests BERT classification task model accuracy. The naming
convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=TMP_DIR, **kwargs):
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
self.bert_config_file = MODEL_CONFIG_FILE_PATH
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.84,
max_accuracy=0.88):
"""Starts BERT accuracy benchmark test."""
start_time_sec = time.time()
self._run_bert_classifier(callbacks=[self.timer_callback])
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(BertClassifyAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def _setup(self):
super(BertClassifyAccuracy, self)._setup()
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
def benchmark_8_gpu_mrpc(self):
"""Run BERT model accuracy test with 8 GPUs.
Due to comparatively small cardinality of MRPC dataset, training
accuracy metric has high variance between trainings. As so, we
set the wide range of allowed accuracy (84% to 88%).
"""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu_mrpc_xla(self):
"""Run BERT model accuracy test with 8 GPUs with XLA."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
FLAGS.enable_xla = True
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,126 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
FLAGS = flags.FLAGS
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
"""Callback that records time it takes to run each batch."""
def __init__(self, num_batches_to_skip=10):
super(BenchmarkTimerCallback, self).__init__()
self.batch_start_times = {}
self.batch_stop_times = {}
def on_batch_begin(self, batch, logs=None):
self.batch_start_times[batch] = time.time()
def on_batch_end(self, batch, logs=None):
# If there are multiple steps_per_loop, the end batch index will not be the
# same as the starting index. Use the last starting index instead.
if batch not in self.batch_start_times:
batch = max(self.batch_start_times.keys())
self.batch_stop_times[batch] = time.time()
def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
batch_durations = []
for batch in self.batch_start_times:
if batch in self.batch_stop_times and batch >= num_batches_to_skip:
batch_durations.append(self.batch_stop_times[batch] -
self.batch_start_times[batch])
return batch_size / np.mean(batch_durations)
def get_startup_time(self, program_start_time):
return self.batch_start_times[0] - program_start_time
class BertBenchmarkBase(PerfZeroBenchmark):
"""Base class to hold methods common to test classes."""
local_flags = None
def __init__(self, output_dir=None):
super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
self.num_gpus = 8
self.timer_callback = None
def _setup(self):
"""Sets up and resets flags before each test."""
super(BertBenchmarkBase, self)._setup()
self.timer_callback = BenchmarkTimerCallback()
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
metrics = [{
'name': 'training_loss',
'value': stats['train_loss'],
}]
if self.timer_callback:
metrics.append({
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
FLAGS.steps_per_loop)
})
else:
metrics.append({
'name': 'exp_per_second',
'value': 0.0,
})
if self.timer_callback and 'start_time_sec' in stats:
metrics.append({
'name': 'startup_time',
'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
})
if 'eval_metrics' in stats:
metrics.append({
'name': 'eval_accuracy',
'value': stats['eval_metrics'],
'min_value': min_accuracy,
'max_value': max_accuracy,
})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(
iters=stats['total_training_steps'],
wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
@@ -0,0 +1,654 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
SQUAD_LONG_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_long_meta_data'
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, tpu=None):
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir)
self.tpu = tpu
def _read_training_summary_from_file(self):
"""Reads the training summary from a file."""
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
return json.loads(reader.read().decode('utf-8'))
def _read_input_meta_data_from_file(self):
"""Reads the input metadata from a file."""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
return json.loads(reader.read().decode('utf-8'))
def _get_distribution_strategy(self, ds_type='mirrored'):
"""Gets the distribution strategy.
Args:
ds_type: String, the distribution strategy type to be used. Can be
'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
Returns:
A `tf.distribute.DistibutionStrategy` object.
"""
if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribution_utils.get_distribution_strategy(
distribution_strategy=ds_type,
num_gpus=self.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg)
def _init_gpu_and_data_threads(self):
"""Set env variables before any TF calls."""
if FLAGS.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=FLAGS.per_gpu_thread_count,
gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
num_gpus=self.num_gpus,
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
@flagsaver.flagsaver
def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
"""Runs BERT SQuAD training. Uses mirrored strategy by default."""
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type)
run_squad.train_squad(
strategy=strategy,
input_meta_data=input_meta_data,
run_eagerly=run_eagerly,
custom_callbacks=[self.timer_callback])
@flagsaver.flagsaver
def _evaluate_squad(self, ds_type='mirrored'):
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type)
if input_meta_data.get('version_2_with_negative', False):
logging.error('In memory evaluation result for SQuAD v2 is not accurate')
eval_metrics = run_squad.eval_squad(strategy=strategy,
input_meta_data=input_meta_data)
# Use F1 score as reported evaluation metric.
self.eval_metrics = eval_metrics['final_f1']
class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
"""Short benchmark performance tests for BERT SQuAD model.
Tests BERT SQuAD performance in different GPU configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format for GPUs and
`benchmark_(topology)_tpu` format for TPUs.
"""
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
super(BertSquadBenchmarkReal, self)._setup()
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
run_eagerly=False,
ds_type='mirrored'):
"""Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4 or run_eagerly:
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
summary['start_time_sec'] = start_time_sec
super(BertSquadBenchmarkReal, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=0,
max_accuracy=1)
def benchmark_1_gpu(self):
"""Tests BERT SQuAD model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark()
def benchmark_1_gpu_eager(self):
"""Tests BERT SQuAD model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
FLAGS.train_batch_size = 2
self._run_and_report_benchmark(run_eagerly=True)
def benchmark_1_gpu_xla(self):
"""Tests BERT SQuAD model performance with 1 GPU with XLA."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
# XLA runs out of memory when running with batch size 4.
FLAGS.train_batch_size = 3
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat(self):
"""Tests BERT SQuAD model performance with 1 GPU without DS."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(ds_type='off')
def benchmark_1_gpu_eager_no_dist_strat(self):
"""Tests BERT SQuAD model performance with 1 GPU with eager execution."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_eager_no_dist_strat_squad')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
def benchmark_2_gpu(self):
"""Tests BERT SQuAD model performance with 2 GPUs."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
FLAGS.train_batch_size = 8
self._run_and_report_benchmark()
def benchmark_4_gpu(self):
"""Tests BERT SQuAD model performance with 4 GPUs."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
FLAGS.train_batch_size = 16
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
"""Tests BERT SQuAD model performance with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 24
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_1_gpu_fp16_eager(self):
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16_eager')
FLAGS.train_batch_size = 4
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark(run_eagerly=True)
def benchmark_1_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
FLAGS.train_batch_size = 4
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_1_gpu_xla_fp16(self):
"""Tests BERT SQuAD model performance with 1 GPU with XLA and FP16."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad_fp16')
FLAGS.train_batch_size = 4
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_2_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 2 GPUs and FP16."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad_fp16')
FLAGS.train_batch_size = 8
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_4_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 4 GPUs and FP16."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad_fp16')
FLAGS.train_batch_size = 16
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Tests BERT SQuAD model performance with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_8_gpu_xla_fp16(self):
"""Tests BERT SQuAD model performance with 8 GPUs with XLA."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
FLAGS.train_batch_size = 32
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
self._run_and_report_benchmark()
def benchmark_1_gpu_amp(self):
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
FLAGS.train_batch_size = 4
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
self._run_and_report_benchmark()
def benchmark_4_gpu_amp(self):
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_amp_squad')
FLAGS.train_batch_size = 16
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_2x2_tpu(self):
"""Tests BERT SQuAD model performance with 2x2 TPU."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.train_batch_size = 48
self._run_and_report_benchmark()
class BertSquadAccuracy(BertSquadBenchmarkBase):
"""Short accuracy test for BERT SQuAD model.
Tests BERT SQuAD accuracy. The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format for GPUs and
`benchmark_(topology)_tpu` format for TPUs.
"""
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
super(BertSquadAccuracy, self)._setup()
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
run_eagerly=False,
ds_type='mirrored'):
"""Runs the benchmark and reports various metrics."""
start_time_sec = time.time()
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
self._evaluate_squad(ds_type=ds_type)
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
summary['eval_metrics'] = self.eval_metrics
summary['start_time_sec'] = start_time_sec
super(BertSquadAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=0.900,
max_accuracy=0.920)
def benchmark_1_gpu_eager(self):
"""Tests BERT SQuAD model accuracy with 1 GPU with eager execution."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
def benchmark_8_gpu(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 24
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_8_gpu_xla(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
FLAGS.train_batch_size = 32
FLAGS.enable_xla = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
def benchmark_2x2_tpu(self):
"""Tests BERT SQuAD model accuracy with 2x2 TPU."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.train_batch_size = 48
self._run_and_report_benchmark()
class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
"""BERT SQuAD distributed accuracy tests with multiple workers."""
def __init__(self, output_dir=None, tpu=None, **kwargs):
super(BertSquadMultiWorkerAccuracy, self).__init__(
output_dir=output_dir, tpu=tpu)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
super(BertSquadMultiWorkerAccuracy, self)._setup()
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
"""Runs the benchmark and reports various metrics."""
start_time_sec = time.time()
self._train_squad(run_eagerly=run_eagerly,
ds_type='multi_worker_mirrored')
self._evaluate_squad(ds_type='multi_worker_mirrored')
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
summary['eval_metrics'] = self.eval_metrics
super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=0.900,
max_accuracy=0.920)
def _benchmark_common(self, num_workers, all_reduce_alg):
"""Common to all benchmarks in this class."""
self._setup()
num_gpus = 8
FLAGS.num_gpus = num_gpus
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = False
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
num_workers, all_reduce_alg))
FLAGS.train_batch_size = 4 * num_gpus * num_workers
FLAGS.all_reduce_alg = all_reduce_alg
self._run_and_report_benchmark()
def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
"""BERT SQuAD distributed benchmark tests with multiple workers."""
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertSquadMultiWorkerBenchmark, self).__init__(
output_dir=output_dir, tpu=tpu)
def _setup(self):
"""Sets up the benchmark and SQuAD flags."""
super(BertSquadMultiWorkerBenchmark, self)._setup()
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
"""Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4 * 8:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(run_eagerly=run_eagerly,
ds_type='multi_worker_mirrored')
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
summary['start_time_sec'] = start_time_sec
super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=0,
max_accuracy=1)
def _benchmark_common(self, num_workers, all_reduce_alg):
"""Common to all benchmarks in this class."""
self._setup()
num_gpus = 8
FLAGS.num_gpus = num_gpus
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = False
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
num_workers, all_reduce_alg))
FLAGS.train_batch_size = 4 * num_gpus * num_workers
FLAGS.all_reduce_alg = all_reduce_alg
self._run_and_report_benchmark()
def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
"""8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
self._benchmark_common(num_workers=1, all_reduce_alg='ring')
def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
"""8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
self._benchmark_common(num_workers=1, all_reduce_alg='nccl')
def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,56 @@
[
{
"description": "The ID of the benchmark run, where this metric should tie to.",
"mode": "REQUIRED",
"name": "run_id",
"type": "STRING"
},
{
"description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The unit of the metric. E.g. MB per sec.",
"mode": "NULLABLE",
"name": "unit",
"type": "STRING"
},
{
"description": "The value of the metric.",
"mode": "NULLABLE",
"name": "value",
"type": "FLOAT"
},
{
"description": "The timestamp when the metric is recorded.",
"mode": "REQUIRED",
"name": "timestamp",
"type": "TIMESTAMP"
},
{
"description": "The global step when this metric is recorded.",
"mode": "NULLABLE",
"name": "global_step",
"type": "INTEGER"
},
{
"description": "Free format metadata for the extra information about the metric.",
"mode": "REPEATED",
"name": "extras",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "name",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
]
}
]
@@ -0,0 +1,368 @@
[
{
"description": "The UUID of the run for the benchmark.",
"mode": "REQUIRED",
"name": "model_id",
"type": "STRING"
},
{
"description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
"mode": "REQUIRED",
"name": "model_name",
"type": "STRING"
},
{
"description": "The date when the test of the model is started",
"mode": "REQUIRED",
"name": "run_date",
"type": "TIMESTAMP"
},
{
"description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
"mode": "NULLABLE",
"name": "test_id",
"type": "STRING"
},
{
"description": "The tensorflow version information.",
"fields": [
{
"description": "Version of the tensorflow. E.g. 1.7.0-rc0",
"mode": "REQUIRED",
"name": "version",
"type": "STRING"
},
{
"description": "Git Hash of the tensorflow",
"mode": "NULLABLE",
"name": "git_hash",
"type": "STRING"
},
{
"description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
"mode": "NULLABLE",
"name": "channel",
"type": "STRING"
},
{
"description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
"mode": "NULLABLE",
"name": "build_type",
"type": "STRING"
}
],
"mode": "REQUIRED",
"name": "tensorflow_version",
"type": "RECORD"
},
{
"description": "The arbitrary attribute of the model.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
},
{
"description": "Environment variables when the benchmark run is executed.",
"fields": [
{
"description": "The name of the variable.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the variable.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "environment_variable",
"type": "RECORD"
},
{
"description": "TF Environment variables when the benchmark run is executed.",
"fields": [
{
"description": "The name of the variable.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the variable.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "tensorflow_environment_variables",
"type": "RECORD"
},
{
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
"fields": [
{
"description": "The name of the parameter.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The string value of the parameter.",
"mode": "NULLABLE",
"name": "string_value",
"type": "STRING"
},
{
"description": "The bool value of the parameter.",
"mode": "NULLABLE",
"name": "bool_value",
"type": "STRING"
},
{
"description": "The int/long value of the parameter.",
"mode": "NULLABLE",
"name": "long_value",
"type": "INTEGER"
},
{
"description": "The double/float value of parameter.",
"mode": "NULLABLE",
"name": "float_value",
"type": "FLOAT"
}
],
"mode": "REPEATED",
"name": "run_parameters",
"type": "RECORD"
},
{
"description": "The dataset that run with the benchmark.",
"mode": "NULLABLE",
"name": "dataset",
"type": "RECORD",
"fields": [
{
"description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The arbitrary attribute of the dataset.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
}
]
},
{
"description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
"mode": "NULLABLE",
"name": "test_environment",
"type": "STRING"
},
{
"description": "The machine configuration of the benchmark run.",
"mode": "NULLABLE",
"name": "machine_config",
"type": "RECORD",
"fields": [
{
"description": "The platform information of the benchmark run.",
"mode": "NULLABLE",
"name": "platform_info",
"type": "RECORD",
"fields": [
{
"description": "Eg: 64bit.",
"mode": "NULLABLE",
"name": "bits",
"type": "STRING"
},
{
"description": "Eg: ELF.",
"mode": "NULLABLE",
"name": "linkage",
"type": "STRING"
},
{
"description": "Eg: i386.",
"mode": "NULLABLE",
"name": "machine",
"type": "STRING"
},
{
"description": "Eg: 3.13.0-76-generic.",
"mode": "NULLABLE",
"name": "release",
"type": "STRING"
},
{
"description": "Eg: Linux.",
"mode": "NULLABLE",
"name": "system",
"type": "STRING"
},
{
"description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
"mode": "NULLABLE",
"name": "version",
"type": "STRING"
}
]
},
{
"description": "The CPU information of the benchmark run.",
"mode": "NULLABLE",
"name": "cpu_info",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "num_cores",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "num_cores_allowed",
"type": "INTEGER"
},
{
"description" : "How fast are those CPUs.",
"mode": "NULLABLE",
"name": "mhz_per_cpu",
"type": "FLOAT"
},
{
"description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
"mode": "NULLABLE",
"name": "cpu_info",
"type": "STRING"
},
{
"description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
"mode": "NULLABLE",
"name": "cpu_governor",
"type": "STRING"
},
{
"description": "Cache size of the CPUs.",
"mode": "NULLABLE",
"name": "cache_size",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "level",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "size",
"type": "INTEGER"
}
]
}
]
},
{
"mode": "NULLABLE",
"name": "gpu_info",
"type": "RECORD",
"fields": [
{
"mode": "NULLABLE",
"name": "count",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "model",
"type": "STRING"
},
{
"mode": "NULLABLE",
"name": "cuda_version",
"type": "STRING"
}
]
},
{
"description": "The cloud instance inforation if the benchmark run is executed on cloud",
"mode": "NULLABLE",
"name": "cloud_info",
"type": "RECORD",
"fields": [
{
"description": "The instance type, E.g. n1-standard-4.",
"mode": "NULLABLE",
"name": "instance_type",
"type": "STRING"
},
{
"description": "The arbitrary attribute of the cloud info.",
"fields": [
{
"description": "The name of the attribute.",
"mode": "REQUIRED",
"name": "name",
"type": "STRING"
},
{
"description": "The value of the attribute.",
"mode": "NULLABLE",
"name": "value",
"type": "STRING"
}
],
"mode": "REPEATED",
"name": "attribute",
"type": "RECORD"
}
]
},
{
"mode": "NULLABLE",
"name": "memory_total",
"type": "INTEGER"
},
{
"mode": "NULLABLE",
"name": "memory_available",
"type": "STRING"
}
]
}
]
@@ -0,0 +1,14 @@
[
{
"description": "The UUID of the run for the benchmark.",
"mode": "REQUIRED",
"name": "run_id",
"type": "STRING"
},
{
"description": "The status of the run for the benchmark. Eg, running, failed, success",
"mode": "REQUIRED",
"name": "status",
"type": "STRING"
}
]
@@ -0,0 +1,98 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.utils.flags import core as flags_core
class KerasBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing."""
def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
tpu=None):
super(KerasBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _report_benchmark(self,
stats,
wall_time_sec,
top_1_max=None,
top_1_min=None,
log_steps=None,
total_batch_size=None,
warmup=1,
start_time_sec=None):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from keras models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
top_1_max: highest passing level for top_1 accuracy.
top_1_min: lowest passing level for top_1 accuracy.
log_steps: How often the log was created for stats['step_timestamp_log'].
total_batch_size: Global batch-size.
warmup: number of entries in stats['step_timestamp_log'] to ignore.
start_time_sec: the start time of the program in seconds since epoch
"""
metrics = []
if 'accuracy_top_1' in stats:
metrics.append({'name': 'accuracy_top_1',
'value': stats['accuracy_top_1'],
'min_value': top_1_min,
'max_value': top_1_max})
metrics.append({'name': 'top_1_train_accuracy',
'value': stats['training_accuracy_top_1']})
if (warmup and 'step_timestamp_log' in stats and
len(stats['step_timestamp_log']) > warmup):
# first entry in the time_log is start of step 1. The rest of the
# entries are the end of each step recorded
time_log = stats['step_timestamp_log']
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
num_examples = (
total_batch_size * log_steps * (len(time_log) - warmup - 1))
examples_per_sec = num_examples / elapsed
metrics.append({'name': 'exp_per_second',
'value': examples_per_sec})
if 'avg_exp_per_second' in stats:
metrics.append({'name': 'avg_exp_per_second',
'value': stats['avg_exp_per_second']})
if start_time_sec and 'step_timestamp_log' in stats:
time_log = stats['step_timestamp_log']
# time_log[0] is recorded at the beginning of the first step.
startup_time = time_log[0].timestamp - start_time_sec
metrics.append({'name': 'startup_time', 'value': startup_time})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(
iters=-1,
wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
@@ -0,0 +1,402 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.benchmark import keras_benchmark
from official.benchmark import benchmark_wrappers
from official.benchmark.models import resnet_cifar_main
MIN_TOP_1_ACCURACY = 0.929
MAX_TOP_1_ACCURACY = 0.938
FLAGS = flags.FLAGS
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
"""Accuracy tests for ResNet56 Keras CIFAR-10."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""A benchmark class.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def _setup(self):
super(Resnet56KerasAccuracy, self)._setup()
FLAGS.use_tensor_lr = False
def benchmark_graph_1_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
"""Test keras based model with eager and distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
self._run_and_report_benchmark()
def benchmark_cpu(self):
"""Test keras based model on CPU."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_cpu_no_dist_strat(self):
"""Test keras based model on CPU without distribution strategies."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'off'
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_cpu_no_dist_strat_run_eagerly(self):
"""Test keras based model on CPU w/forced eager and no dist_strat."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir(
'benchmark_cpu_no_dist_strat_run_eagerly')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat(self):
"""Test keras based model with eager and no dist strat."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
"""Test keras based model w/forced eager and no dist_strat."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_no_dist_strat_run_eagerly')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self):
"""Test keras based model with Keras fit but not distribution strategies."""
self._setup()
FLAGS.distribution_strategy = 'off'
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_2_gpu(self):
"""Test keras based model with eager and distribution strategies."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
self._run_and_report_benchmark()
def benchmark_graph_2_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasAccuracy, self)._report_benchmark(
stats,
wall_time_sec,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY,
total_batch_size=FLAGS.batch_size,
log_steps=100)
class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Short performance tests for ResNet56 via Keras and CIFAR-10."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [resnet_cifar_main.define_cifar_flags]
super(Resnet56KerasBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu(self):
"""Test 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_xla(self):
"""Test 1 gpu with xla enabled."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.run_eagerly = False
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu(self):
"""Test 1 gpu graph."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.run_eagerly = False
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat(self):
"""Test 1 gpu without distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self):
"""Test 1 gpu graph mode without distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
"""Test 1 gpu without distribution strategy and forced eager."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_no_dist_strat_run_eagerly')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_2_gpu(self):
"""Test 2 gpu."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.enable_eager = True
FLAGS.run_eagerly = False
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
FLAGS.batch_size = 128 * 2 # 2 GPUs
self._run_and_report_benchmark()
def benchmark_graph_2_gpu(self):
"""Test 2 gpu graph mode."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.enable_eager = False
FLAGS.run_eagerly = False
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
FLAGS.batch_size = 128 * 2 # 2 GPUs
self._run_and_report_benchmark()
def benchmark_cpu(self):
"""Test cpu."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.enable_eager = True
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
FLAGS.batch_size = 128
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_graph_cpu(self):
"""Test cpu graph mode."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.enable_eager = False
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu')
FLAGS.batch_size = 128
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_cpu_no_dist_strat_run_eagerly(self):
"""Test cpu without distribution strategy and forced eager."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.distribution_strategy = 'off'
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.model_dir = self._get_model_dir(
'benchmark_cpu_no_dist_strat_run_eagerly')
FLAGS.batch_size = 128
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_cpu_no_dist_strat(self):
"""Test cpu without distribution strategies."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
FLAGS.batch_size = 128
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
def benchmark_graph_cpu_no_dist_strat(self):
"""Test cpu graph mode without distribution strategies."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.enable_eager = False
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu_no_dist_strat')
FLAGS.batch_size = 128
FLAGS.data_format = 'channels_last'
self._run_and_report_benchmark()
class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
"""Synthetic benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
default_flags = {}
default_flags['skip_eval'] = True
default_flags['use_synthetic_data'] = True
default_flags['train_steps'] = 110
default_flags['log_steps'] = 10
default_flags['use_tensor_lr'] = False
super(Resnet56KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=default_flags)
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
"""Real data benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
default_flags = {}
default_flags['skip_eval'] = True
default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
default_flags['train_steps'] = 110
default_flags['log_steps'] = 10
default_flags['use_tensor_lr'] = False
super(Resnet56KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=default_flags)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,287 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from official.benchmark.models import resnet_cifar_model
from official.benchmark.models import synthetic_util
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification.resnet import cifar_preprocessing
from official.vision.image_classification.resnet import common
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(0.1, 91), (0.01, 136), (0.001, 182)
]
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch sized.
Returns:
Adjusted learning rate.
"""
del current_batch, batches_per_epoch # not used
initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
learning_rate = initial_learning_rate * mult
else:
break
return learning_rate
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only support Keras optimizers, not TF optimizers.
Attributes:
schedule: a function that takes an epoch index and a batch index as input
(both integer, indexed from 0) and returns a new learning rate as
output (float).
"""
def __init__(self, schedule, batch_size, steps_per_epoch):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.steps_per_epoch = steps_per_epoch
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.steps_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
logging.debug(
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
Raises:
ValueError: If fp16 is passed as it is not currently supported.
Returns:
Dictionary of training and eval stats.
"""
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
num_gpus=flags_obj.num_gpus,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
if strategy:
# flags_obj.enable_get_next_as_optional controls whether enabling
# get_next_as_optional behavior in DistributedIterator. If true, last
# partial batch can be supported.
strategy.extended.experimental_enable_get_next_as_optional = (
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data:
synthetic_util.set_up_synthetic_data()
input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS,
num_classes=cifar_preprocessing.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj),
drop_remainder=True)
else:
synthetic_util.undo_set_up_synthetic_data()
input_fn = cifar_preprocessing.input_fn
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=cifar_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype,
# Setting drop_remainder to avoid the partial batch logic in normalization
# layer, which triggers tf.where and leads to extra memory copy of input
# sizes between host and GPU.
drop_remainder=(not flags_obj.enable_get_next_as_optional))
eval_input_dataset = None
if not flags_obj.skip_eval:
eval_input_dataset = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=cifar_preprocessing.parse_record)
steps_per_epoch = (
cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
values=[initial_learning_rate] +
list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
with strategy_scope:
optimizer = common.get_optimizer(lr_schedule)
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=(['sparse_categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
train_epochs = flags_obj.train_epochs
callbacks = common.get_callbacks(steps_per_epoch)
if not flags_obj.use_tensor_lr:
lr_callback = LearningRateBatchScheduler(
schedule=learning_rate_schedule,
batch_size=flags_obj.batch_size,
steps_per_epoch=steps_per_epoch)
callbacks.append(lr_callback)
# if mutliple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
train_epochs = 1
num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
if flags_obj.set_learning_phase_to_train:
# TODO(haoyuzhang): Understand slowdown of setting learning phase when
# not using distribution strategy.
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
if not strategy and flags_obj.explicit_gpu_placement:
# TODO(b/135607227): Add device scope automatically in Keras training loop
# when not using distribition strategy.
no_dist_strat_device = tf.device('/device:GPU:0')
no_dist_strat_device.__enter__()
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=validation_data,
validation_freq=flags_obj.epochs_between_evals,
verbose=2)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=2)
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_cifar_flags():
common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model',
epochs_between_evals=10,
batch_size=128)
def main(_):
with logger.benchmark_context(flags.FLAGS):
return run(flags.FLAGS)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_cifar_flags()
app.run(main)
@@ -0,0 +1,262 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.
# Reference:
- [Deep Residual Learning for Image Recognition](
https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4
def identity_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
training=None):
"""The identity block is the block that has no conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: current block label, used for generating layer names
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
"""
filters1, filters2 = filters
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x, training=training)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x, training=training)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
return x
def conv_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
strides=(2, 2),
training=None):
"""A block that has a conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: current block label, used for generating layer names
strides: Strides for the first conv layer in the block.
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
Note that from stage 3,
the first conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x, training=training)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x, training=training)
shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(
axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut, training=training)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x
def resnet_block(input_tensor,
size,
kernel_size,
filters,
stage,
conv_strides=(2, 2),
training=None):
"""A block which applies conv followed by multiple identity blocks.
Arguments:
input_tensor: input tensor
size: integer, number of constituent conv/identity building blocks.
A conv block is applied once, followed by (size - 1) identity blocks.
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
conv_strides: Strides for the first conv layer in the block.
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor after applying conv and identity blocks.
"""
x = conv_building_block(input_tensor, kernel_size, filters, stage=stage,
strides=conv_strides, block='block_0',
training=training)
for i in range(size - 1):
x = identity_building_block(x, kernel_size, filters, stage=stage,
block='block_%d' % (i + 1), training=training)
return x
def resnet(num_blocks, classes=10, training=None):
"""Instantiates the ResNet architecture.
Arguments:
num_blocks: integer, the number of conv/identity blocks in each block.
The ResNet contains 3 blocks with each block containing one conv block
followed by (layers_per_block - 1) number of idenity blocks. Each
conv/idenity block has 2 convolutional layers. With the input
convolutional layer and the pooling layer towards the end, this brings
the total size of the network to (6*num_blocks + 2)
classes: optional number of classes to classify images into
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
A Keras model instance.
"""
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channel_last
x = img_input
bn_axis = 3
x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1',)(x, training=training)
x = layers.Activation('relu')(x)
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
stage=2, conv_strides=(1, 1), training=training)
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32],
stage=3, conv_strides=(2, 2), training=training)
x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
stage=4, conv_strides=(2, 2), training=training)
rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
x = layers.Dense(classes,
activation='softmax',
kernel_initializer=initializers.RandomNormal(stddev=0.01),
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x)
inputs = img_input
# Create model.
model = tf.keras.models.Model(inputs, x, name='resnet56')
return model
resnet20 = functools.partial(resnet, num_blocks=3)
resnet32 = functools.partial(resnet, num_blocks=5)
resnet56 = functools.partial(resnet, num_blocks=9)
resnet10 = functools.partial(resnet, num_blocks=110)
@@ -0,0 +1,187 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the keras ResNet model with Cifar data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.benchmark.models import resnet_cifar_main
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification.resnet import cifar_preprocessing
class KerasCifarTest(googletest.TestCase):
"""Unit tests for Keras ResNet with Cifar."""
_extra_flags = [
"-batch_size", "4",
"-train_steps", "1",
"-use_synthetic_data", "true"
]
_tempdir = None
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass()
resnet_cifar_main.define_cifar_flags()
def setUp(self):
super(KerasCifarTest, self).setUp()
cifar_preprocessing.NUM_IMAGES["validation"] = 4
def tearDown(self):
super(KerasCifarTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-distribution_strategy", "off",
"-model_dir", "keras_cifar_no_dist_strat",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_graph_no_dist_strat(self):
"""Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
extra_flags = [
"-enable_eager", "false",
"-distribution_strategy", "off",
"-model_dir", "keras_cifar_graph_no_dist_strat",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_1_gpu(self):
"""Test Keras model with 1 GPU."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 1:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(1, context.num_gpus()))
extra_flags = [
"-num_gpus", "1",
"-distribution_strategy", "mirrored",
"-model_dir", "keras_cifar_1_gpu",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_graph_1_gpu(self):
"""Test Keras model in legacy graph mode with 1 GPU."""
if context.num_gpus() < 1:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(1, context.num_gpus()))
extra_flags = [
"-num_gpus", "1",
"-noenable_eager",
"-distribution_strategy", "mirrored",
"-model_dir", "keras_cifar_graph_1_gpu",
"-data_format", "channels_last",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_2_gpu(self):
"""Test Keras model with 2 GPUs."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
if context.num_gpus() < 2:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(2, context.num_gpus()))
extra_flags = [
"-num_gpus", "2",
"-distribution_strategy", "mirrored",
"-model_dir", "keras_cifar_2_gpu",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
def test_end_to_end_graph_2_gpu(self):
"""Test Keras model in legacy graph mode with 2 GPUs."""
if context.num_gpus() < 2:
self.skipTest(
"{} GPUs are not available for this test. {} GPUs are available".
format(2, context.num_gpus()))
extra_flags = [
"-num_gpus", "2",
"-enable_eager", "false",
"-distribution_strategy", "mirrored",
"-model_dir", "keras_cifar_graph_2_gpu",
]
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
if __name__ == "__main__":
googletest.main()
@@ -0,0 +1,31 @@
# Shakespeare character LSTM model
This is an implemention of a simple character LSTM used to generate text.
## Instructions
First download the source data:
```
wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
```
Note that files other than shakepeare.txt can also be used to train the model to generater other text.
Then train the model:
```python
python3 shakespeare_main.py --training_data shakespeare.txt \
--model_dir /tmp/shakespeare
```
This will place model checkpoints in `/tmp/shakespeare`, so that we can use them to make predictions.
Then generate predictions:
```python
python3 shakespeare_main.py --training_data shakespeare.txt \
--model_dir /tmp/shakespeare --notrain --predict_context=ROMEO:
```
Change `--predict_context` and `--predict_length` to suit your needs.
@@ -0,0 +1,316 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a character LSTM model trained on Shakespeare."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
# pylint: disable=wrong-import-order
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
# pylint: enable=wrong-import-order
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
EMBEDDING_DIM = 256
RNN_UNITS = 1024
SEQ_LENGTH = 100
# Calculated by running batch_size=1
BATCHES_PER_EPOCH = 11043
def define_flags():
"""Define the flags for the Shakespeare character LSTM."""
flags_core.define_base(data_dir=False,
clean=False,
train_epochs=True,
epochs_between_evals=False,
stop_threshold=False,
num_gpu=True,
hooks=False,
export_dir=False,
run_eagerly=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=True,
loss_scale=True,
enable_xla=True)
flags_core.set_defaults(train_epochs=43,
batch_size=64)
flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
flags.DEFINE_boolean(
name='train', default=True,
help='If true trains the model.')
flags.DEFINE_string(
name='predict_context', default=None,
help='If set, makes a prediction with the given context.')
flags.DEFINE_integer(
name='predict_length', default=1000,
help='Length of the predicted text including the context.')
flags.DEFINE_integer(name='train_steps', default=None,
help='Overrides train_steps per epoch if not None.')
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
'examples per second.')
flags.DEFINE_string(
name='training_data', default=None,
help='Path to file containing the training data.')
flags.DEFINE_boolean(name='cudnn', default=True, help='Use CuDNN LSTM.')
def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
"""Creates a dataset from a given text file.
Args:
path_to_file: The path to the training data.
batch_size: Batch size to use.
seq_length: The length of the LSTM sequence.
Returns:
A tuple, consisting of the Dataset and the class to character mapping
and character to class mapping.
"""
with tf.io.gfile.GFile(path_to_file, 'rb') as train_data:
text = train_data.read().decode(encoding='utf-8')
# Create vocab
vocab = sorted(set(text))
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
# Split text into sequence length + 1 chucks to create examples
text_as_int = np.array([char2idx[c] for c in text])
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
def split_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
return input_text, tf.one_hot(target_text, len(vocab))
dataset = sequences.map(split_input_target)
dataset = dataset.shuffle(10000).repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset, idx2char, char2idx
def build_model(vocab_size,
embedding_dim=EMBEDDING_DIM,
rnn_units=RNN_UNITS,
batch_size=None,
stateful=False,
use_cudnn=True):
"""Builds the Shakespeare model.
Args:
vocab_size: The number of character classes in the input.
embedding_dim: The dimension of the embedding space for each class.
rnn_units: The number of RNN units in the layer.
batch_size: When predicting, the batch size of the predictions.
stateful: If true, the LSTM is stateful.
Returns:
A Keras Model.
"""
assert keras_utils.is_v2_0()
LSTM = functools.partial(tf.keras.layers.LSTM, implementation=2)
# By indirecting the activation through a lambda layer, the logic to dispatch
# to CuDNN in V2 doesn't trigger and we force the LSTM to run in non-CuDNN
# mode.
lstm_activation = ('tanh' if use_cudnn else
lambda x: tf.math.tanh(x))
batch_shape = [batch_size if stateful else None, None]
return tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_dim,
batch_input_shape=batch_shape),
LSTM(rnn_units,
activation=lstm_activation,
return_sequences=True,
stateful=stateful,
recurrent_initializer='glorot_uniform'),
tf.keras.layers.Dense(vocab_size),
tf.keras.layers.Softmax(dtype=tf.float32)])
def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
"""Trains a Shakespeare model.
Args:
flags_obj: An object containing parsed flag values.s
dataset: the training data set.
vocab_size: the number of unique character classes.
strategy: distribution strategy to use.
checkpoint_dir: if not None, the directory in which to make checkpoints.
Returns:
The training history and callbacks.
"""
if flags_obj.train_steps:
train_steps = flags_obj.train_steps
else:
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
strategy_scope = distribution_utils.get_strategy_scope(strategy)
with strategy_scope:
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
use_cudnn=flags_obj.cudnn)
# When keras_use_ctl is False, Model.fit() automatically applies
# loss scaling so we don't need to create a LossScaleOptimizer.
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
tf.keras.metrics.Recall(top_k=5, name='RecallAt5')],
run_eagerly=flags_obj.run_eagerly)
callbacks = []
if checkpoint_dir:
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_prefix,
save_weights_only=True)
callbacks.append(checkpoint_callback)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
flags_obj.log_steps)
callbacks.append(time_callback)
history = model.fit(dataset,
epochs=flags_obj.train_epochs,
steps_per_epoch=train_steps,
callbacks=callbacks,
verbose=2)
return history, callbacks
def make_prediction(checkpoint_dir, length, context, idx2char, char2idx):
"""Make predictions from a Shakespeare model.
Args:
checkpoint_dir: the directory from which to load checkpoints
length: the total length of the generated text (including the context).
context: the initial text with which the LSTM is primed.
idx2char: the character class to character mapping.
char2idx: the character to character class mapping.
Returns:
A generated string of text of the given length.
"""
prediction_model = build_model(
vocab_size=len(idx2char), batch_size=1, stateful=True)
prediction_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
prediction_model.build(tf.TensorShape([1, None]))
input_eval = [char2idx[s] for s in context]
input_eval = tf.expand_dims(input_eval, 0)
text_generated = []
prediction_model.reset_states()
for _ in range(length - len(context)):
predictions = prediction_model(input_eval)
predictions = tf.squeeze(predictions, 0)
# We applied a softmax to the output of the model so that
# tf.keras.metrics.Recall would work. We need logits for
# tf.random.categorical, so we convert the probabilities back to log odds
predictions = tf.math.log(predictions / (1 - predictions))
random_output = tf.random.categorical(predictions, num_samples=1)
selected_id = random_output[-1, 0].numpy()
input_eval = tf.expand_dims([selected_id], 0)
text_generated.append(idx2char[selected_id])
return context + ''.join(text_generated)
def run(flags_obj):
"""Run Shakespeare training and predict.
Args:
flags_obj: An object containing parsed flag values.
Returns:
Dictionary with status from the run.
"""
if not flags_obj.training_data:
raise ValueError(
'Must set the path to a training data file. e.g download the following '
'https://storage.googleapis.com/download.tensorflow.org/data/'
'shakespeare.txt')
if flags_obj.dtype == 'fp16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16',
loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16='dynamic'))
tf.keras.mixed_precision.experimental.set_policy(policy)
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
dataset, idx2char, char2idx = get_dataset(flags_obj.training_data,
batch_size=flags_obj.batch_size)
stats = {}
if flags_obj.train:
history, callbacks = train_model(flags_obj, dataset,
len(idx2char), strategy,
checkpoint_dir=flags_obj.model_dir)
stats['history'] = history.history
stats['callbacks'] = callbacks
if flags_obj.predict_context:
if not flags_obj.model_dir:
raise ValueError('Must set model_dir to get predictions.')
print(make_prediction(flags_obj.model_dir,
flags_obj.predict_length,
flags_obj.predict_context,
idx2char,
char2idx))
return stats
def main(_):
flags_obj = flags.FLAGS
run(flags_obj)
if __name__ == '__main__':
define_flags()
app.run(main)
@@ -0,0 +1,129 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import string
from absl import logging
import tensorflow as tf
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
# dataset.take(1) doesn't have GPU kernel.
with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor)
variable_data = []
initializers = []
for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t)
variable_data.append(v)
initializers.append(v.initializer)
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self):
return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self):
if tf.executing_eagerly():
return tf.no_op()
else:
return self._initializers
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset):
logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync)
else:
return SyntheticDataset(dataset)
def make_iterator(self, dataset):
dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.make_dataset_iterator = strategy.orig_distribute_dataset
def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
@@ -0,0 +1,457 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.benchmark import benchmark_wrappers
from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main
from official.utils.flags import core
FLAGS = flags.FLAGS
NCF_DATA_DIR_NAME = 'movielens_data'
NCF_TF_DATA_1M_BATCH_DIR_NAME = 'gs://tf-perfzero-data/movielens_data/ncf_8gpu_1M_batch'
class NCFKerasBenchmarkBase(tf.test.Benchmark):
"""Base class for NCF model benchmark."""
local_flags = None
def __init__(self,
output_dir=None,
default_flags=None,
**kwargs):
self.output_dir = output_dir
self.default_flags = default_flags or {}
# Run all benchmarks with ml_perf flag.
self.default_flags['ml_perf'] = True
def _setup(self):
"""Sets up and resets flags before each test."""
logging.set_verbosity(logging.INFO)
if NCFKerasBenchmarkBase.local_flags is None:
ncf_common.define_ncf_flags()
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
core.set_defaults(**self.default_flags)
saved_flag_values = flagsaver.save_flag_values()
NCFKerasBenchmarkBase.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
start_time_sec = time.time()
stats = ncf_keras_main.run_ncf(FLAGS)
wall_time_sec = time.time() - start_time_sec
metrics = []
metrics.append({'name': 'exp_per_second',
'value': stats['avg_exp_per_second']})
if hr_at_10_min > 0:
metrics.append({'name': 'hr_at_10',
'value': stats['eval_hit_rate'],
'min_value': hr_at_10_min,
'max_value': hr_at_10_max})
metrics.append({'name': 'train_loss',
'value': stats['loss']})
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
class NCFKerasAccuracy(NCFKerasBenchmarkBase):
"""Benchmark NCF model using real data."""
def __init__(self,
output_dir=None,
root_data_dir=None,
default_flags=None,
**kwargs):
root_data_dir = root_data_dir if root_data_dir else ''
default_flags = {}
default_flags['dataset'] = 'ml-20m'
default_flags['num_gpus'] = 1
default_flags['train_epochs'] = 10
default_flags['clean'] = True
default_flags['batch_size'] = 99000
default_flags['learning_rate'] = 0.00382059
default_flags['beta1'] = 0.783529
default_flags['beta2'] = 0.909003
default_flags['epsilon'] = 1.45439e-07
default_flags['layers'] = [256, 256, 128, 64]
default_flags['num_factors'] = 64
default_flags['hr_threshold'] = 0.635
default_flags['ml_perf'] = True
default_flags['use_synthetic_data'] = False
default_flags['data_dir'] = os.path.join(root_data_dir, NCF_DATA_DIR_NAME)
super(NCFKerasAccuracy, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
**kwargs)
def _run_and_report_benchmark_mlperf_like(self):
"""Run test and report results.
Note: MLPerf like tests are not tuned to hit a specific hr@10 value, but
we want it recorded.
"""
self._run_and_report_benchmark(hr_at_10_min=0.61)
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
"""Run test and report results.
Note: Target is 0.635, but some runs are below that level. Until we have
multi-run tests, we have to accept a lower target.
Args:
hr_at_10_min: Minimum acceptable hr@10 value.
hr_at_10_max: Maximum acceptable hr@10 value.
"""
super(NCFKerasAccuracy, self)._run_and_report_benchmark(
hr_at_10_min=hr_at_10_min,
hr_at_10_max=hr_at_10_max)
def benchmark_1_gpu_early_stop(self):
self._setup()
FLAGS.early_stopping = True
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_early_stop(self):
self._setup()
FLAGS.distribution_strategy = 'off'
FLAGS.early_stopping = True
self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_run_eagerly_early_stop(self):
self._setup()
FLAGS.distribution_strategy = 'off'
FLAGS.early_stopping = True
FLAGS.run_eagerly = True
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_early_stop(self):
self._setup()
FLAGS.early_stopping = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_1_gpu_ctl_early_stop(self):
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.early_stopping = True
self._run_and_report_benchmark()
def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.early_stopping = True
FLAGS.run_eagerly = True
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_ctl_early_stop(self):
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.early_stopping = True
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_2_gpus_early_stop(self):
self._setup()
FLAGS.early_stopping = True
FLAGS.num_gpus = 2
FLAGS.eval_batch_size = 160000
self._run_and_report_benchmark()
def benchmark_2_gpus_ctl_early_stop(self):
"""NCF with custom training loop. Works only in TF 2.0."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.early_stopping = True
FLAGS.num_gpus = 2
FLAGS.eval_batch_size = 160000
self._run_and_report_benchmark()
#############################################
# Tests below with mlperf in the test name are of two types:
# 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
# 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
#
# The purpose of both is to get a number to compare to existing results. To do
# this the number of epochs is held constant rather than a race to a given
# accuracy. The accuracy validation is done by the "early_stop" tests.
#############################################
def benchmark_1_gpu_mlperf_like(self):
"""1 GPU using keras fit/compile."""
self._setup()
FLAGS.train_epochs = 7
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_no_dist_strat_mlperf_like(self):
"""1 GPU using compile/fit without dist_strat."""
self._setup()
FLAGS.train_epochs = 7
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_no_dist_strat_run_eagerly_mlperf_like(self):
self._setup()
FLAGS.train_epochs = 7
FLAGS.distribution_strategy = 'off'
FLAGS.run_eagerly = True
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_1_gpu_mlperf_like(self):
"""1 GPU using compile/fit with XLA."""
self._setup()
FLAGS.train_epochs = 7
FLAGS.enable_xla = True
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_mlperf_like(self):
"""1 GPU using CTL."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.train_epochs = 7
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
"""1 GPU using CTL and FP16."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_fp16_mlperf_like(self):
"""1 GPU using FP16."""
self._setup()
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_fp16_graph_rewrite_mlperf_like(self):
"""1 GPU using CTL and FP16 graph rewrite."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_fp16_graph_rewrite_mlperf_like(self):
"""1 GPU using FP16 graph rewrite."""
self._setup()
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
"""1 GPU using CTL with eager and distribution strategy."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.run_eagerly = True
FLAGS.train_epochs = 7
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_ctl_mlperf_like(self):
"""1 GPU using CTL with XLA."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.enable_xla = True
FLAGS.train_epochs = 7
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_1_gpu_fp16_mlperf_like(self):
"""1 GPU using with XLA and FP16."""
self._setup()
FLAGS.enable_xla = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
"""1 GPU using CTL with XLA and FP16."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.enable_xla = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_mlperf_like(self):
"""8 GPU using keras fit/compile."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 160000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_ctl_mlperf_like(self):
"""8 GPU using CTL."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 160000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_tf_data_ctl_mlperf_like(self):
"""8 GPU using CTL."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 1048000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_tf_data_fp16_mlperf_like(self):
"""8 GPU FP16"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 1048000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_tf_data_ctl_fp16_mlperf_like(self):
"""8 GPU FP16 using CTL"""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 1048000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "training_cycle_*/*")
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "eval_data/*")
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME, "meta_data.json")
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_tf_data_ctl_fp16_graph_rewrite_mlperf_like(self):
"""8 GPU FP16 graph rewrite using CTL."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 1048000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.loss_scale = 8192
FLAGS.train_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
'training_cycle_*/*')
FLAGS.eval_dataset_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
'eval_data/*')
FLAGS.input_meta_data_path = os.path.join(NCF_TF_DATA_1M_BATCH_DIR_NAME,
'meta_data.json')
self._run_and_report_benchmark_mlperf_like()
class NCFKerasSynth(NCFKerasBenchmarkBase):
"""Benchmark NCF model using synthetic data."""
def __init__(self,
output_dir=None,
default_flags=None,
**kwargs):
default_flags = {}
default_flags['dataset'] = 'ml-20m'
default_flags['num_gpus'] = 1
default_flags['train_epochs'] = 8
default_flags['batch_size'] = 99000
default_flags['eval_batch_size'] = 160000
default_flags['learning_rate'] = 0.00382059
default_flags['beta1'] = 0.783529
default_flags['beta2'] = 0.909003
default_flags['epsilon'] = 1.45439e-07
default_flags['layers'] = [256, 256, 128, 64]
default_flags['num_factors'] = 64
default_flags['hr_threshold'] = 0.635
default_flags['use_synthetic_data'] = True
super(NCFKerasSynth, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
**kwargs)
def benchmark_1_gpu(self):
self._setup()
self._run_and_report_benchmark()
def benchmark_2_gpus(self):
self._setup()
FLAGS.num_gpus = 2
self._run_and_report_benchmark()
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,91 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for creating PerfZero benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
FLAGS = flags.FLAGS
class PerfZeroBenchmark(tf.test.Benchmark):
"""Common methods used in PerfZero Benchmarks.
Handles the resetting of flags between tests, loading of default_flags,
overriding of defaults. PerfZero (OSS) runs each test in a separate
process reducing some need to reset the flags.
"""
local_flags = None
def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
tpu=None):
"""Initialize class.
Args:
output_dir: Base directory to store all output for the test.
default_flags: Set of flags to pass to model.
flag_methods: Set of flag methods to run during setup.
tpu: (optional) TPU name to use in a TPU benchmark.
"""
if os.getenv('BENCHMARK_OUTPUT_DIR'):
self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
elif output_dir:
self.output_dir = output_dir
else:
self.output_dir = '/tmp'
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
if os.getenv('BENCHMARK_TPU'):
resolved_tpu = os.getenv('BENCHMARK_TPU')
elif tpu:
resolved_tpu = tpu
else:
resolved_tpu = None
if resolved_tpu:
# TPU models are expected to accept a --tpu=name flag. PerfZero creates
# the TPU at runtime and passes the TPU's name to this flag.
self.default_flags['tpu'] = resolved_tpu
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
logging.set_verbosity(logging.INFO)
if PerfZeroBenchmark.local_flags is None:
for flag_method in self.flag_methods:
flag_method()
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
# Overrides flag values with defaults for the class of tests.
for k, v in self.default_flags.items():
setattr(FLAGS, k, v)
saved_flag_values = flagsaver.save_flag_values()
PerfZeroBenchmark.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(PerfZeroBenchmark.local_flags)
@@ -0,0 +1,412 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes CTL benchmarks and accuracy tests."""
# pylint: disable=line-too-long,g-bad-import-order
from __future__ import print_function
import os
import time
from absl import flags
import tensorflow as tf
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.benchmark import benchmark_wrappers
from official.utils.flags import core as flags_core
MIN_TOP_1_ACCURACY = 0.76
MAX_TOP_1_ACCURACY = 0.77
FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
self.output_dir = output_dir
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__(
output_dir=self.output_dir,
default_flags=self.default_flags,
flag_methods=self.flag_methods)
def _report_benchmark(self,
stats,
wall_time_sec,
top_1_max=None,
top_1_min=None,
total_batch_size=None,
log_steps=None,
warmup=1,
start_time_sec=None):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from keras models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
top_1_max: highest passing level for top_1 accuracy.
top_1_min: lowest passing level for top_1 accuracy.
total_batch_size: Global batch-size.
log_steps: How often the log was created for stats['step_timestamp_log'].
warmup: number of entries in stats['step_timestamp_log'] to ignore.
start_time_sec: the start time of the program in seconds since epoch.
"""
metrics = []
if 'eval_acc' in stats:
metrics.append({
'name': 'accuracy_top_1',
'value': stats['eval_acc'],
'min_value': top_1_min,
'max_value': top_1_max
})
metrics.append({'name': 'eval_loss', 'value': stats['eval_loss']})
metrics.append({
'name': 'top_1_train_accuracy',
'value': stats['train_acc']
})
metrics.append({'name': 'train_loss', 'value': stats['train_loss']})
if (warmup and 'step_timestamp_log' in stats and
len(stats['step_timestamp_log']) > warmup + 1):
# first entry in the time_log is start of step 0. The rest of the
# entries are the end of each step recorded
time_log = stats['step_timestamp_log']
steps_elapsed = time_log[-1].batch_index - time_log[warmup].batch_index
time_elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
examples_per_sec = total_batch_size * (steps_elapsed / time_elapsed)
metrics.append({'name': 'exp_per_second', 'value': examples_per_sec})
if 'avg_exp_per_second' in stats:
metrics.append({
'name': 'avg_exp_per_second',
'value': stats['avg_exp_per_second']
})
if start_time_sec and 'step_timestamp_log' in stats:
time_log = stats['step_timestamp_log']
# time_log[0] is recorded at the beginning of the first step.
startup_time = time_log[0].timestamp - start_time_sec
metrics.append({'name': 'startup_time', 'value': startup_time})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(
iters=-1,
wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
class Resnet50CtlAccuracy(CtlBenchmark):
"""Benchmark accuracy tests for ResNet50 in CTL."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""A benchmark class.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more named
arguments before updating the constructor.
"""
flag_methods = [common.define_keras_flags]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
super(Resnet50CtlAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def benchmark_8_gpu(self):
"""Test Keras model with eager, dist_strat and 8 GPUs."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Test Keras model with eager, 8 GPUs with tf.keras mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 256 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with 8 GPUs and mixed precision via graph rewrite."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 256 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet50CtlAccuracy, self)._report_benchmark(
stats,
wall_time_sec,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY,
total_batch_size=FLAGS.batch_size,
log_steps=100,
start_time_sec=start_time_sec)
def _get_model_dir(self, folder_name):
return os.path.join(self.output_dir, folder_name)
class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
# Number of logged step time entries that are excluded in performance
# report. We keep results from last 100 batches in this case.
warmup = (FLAGS.train_steps - 100) // FLAGS.log_steps
super(Resnet50CtlBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
warmup=warmup,
start_time_sec=start_time_sec)
def benchmark_1_gpu_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
"""Test Keras model with 1 GPU."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_fp16(self):
"""Test Keras model with 1 GPU with tf.keras mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
FLAGS.batch_size = 256
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark()
def benchmark_1_gpu_amp(self):
"""Test Keras model with 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
FLAGS.batch_size = 256
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_amp(self):
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
FLAGS.batch_size = 256
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_1_gpu_eager(self):
"""Test Keras model with 1 GPU in pure eager mode."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
FLAGS.batch_size = 120
FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark()
def benchmark_1_gpu_fp16_eager(self):
"""Test Keras model with 1 GPU with fp16 and pure eager mode."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'one_device'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
FLAGS.batch_size = 240
FLAGS.dtype = 'fp16'
FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
"""Test Keras model with 8 GPUs."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
FLAGS.batch_size = 128 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
"""Test Keras model with 8 GPUs with tf.keras mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark()
def benchmark_8_gpu_eager(self):
"""Test Keras model with 8 GPUs, eager, fp32."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_8_gpu_eager_fp16(self):
"""Test Keras model with 8 GPUs, eager, fp16."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
"""Test Keras model with 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self):
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def fill_report_object(self, stats):
super(Resnet50CtlBenchmarkBase, self).fill_report_object(
stats, total_batch_size=FLAGS.batch_size, log_steps=FLAGS.log_steps)
class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
"""Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
"""Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,294 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes RetinaNet benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
import copy
import json
import os
import time
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
from official.benchmark import benchmark_wrappers
from official.vision.detection import main as detection
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
# pylint: disable=line-too-long
COCO_TRAIN_DATA = 'gs://tf-perfzero-data/coco/train*'
COCO_EVAL_DATA = 'gs://tf-perfzero-data/coco/val*'
COCO_EVAL_JSON = 'gs://tf-perfzero-data/coco/instances_val2017.json'
RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoint-2018-02-07'
# pylint: enable=line-too-long
class DetectionBenchmarkBase(tf.test.Benchmark):
"""Base class to hold methods common to test classes."""
local_flags = None
def __init__(self, output_dir=None):
self.num_gpus = 8
if not output_dir:
output_dir = '/tmp'
self.output_dir = output_dir
self.timer_callback = None
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
self.timer_callback = benchmark_utils.BenchmarkTimerCallback()
if DetectionBenchmarkBase.local_flags is None:
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
DetectionBenchmarkBase.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(DetectionBenchmarkBase.local_flags)
def _report_benchmark(self,
stats,
wall_time_sec,
min_ap,
max_ap,
train_batch_size=None):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from Detection models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_ap: Minimum detection AP constraint to verify correctness of the
model.
max_ap: Maximum detection AP accuracy constraint to verify correctness of
the model.
train_batch_size: Train batch size. It is needed for computing
exp_per_second.
"""
metrics = [{
'name': 'total_loss',
'value': stats['total_loss'],
}]
if self.timer_callback:
metrics.append({
'name': 'exp_per_second',
'value': self.timer_callback.get_examples_per_sec(train_batch_size)
})
else:
metrics.append({
'name': 'exp_per_second',
'value': 0.0,
})
if 'eval_metrics' in stats:
metrics.append({
'name': 'AP',
'value': stats['AP'],
'min_value': min_ap,
'max_value': max_ap,
})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(
iters=stats['total_steps'],
wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
class RetinanetBenchmarkBase(DetectionBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None, **kwargs):
self.train_data_path = COCO_TRAIN_DATA
self.eval_data_path = COCO_EVAL_DATA
self.eval_json_path = COCO_EVAL_JSON
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
super(RetinanetBenchmarkBase, self).__init__(output_dir=output_dir)
def _run_detection_main(self):
"""Starts detection job."""
if self.timer_callback:
return detection.run(callbacks=[self.timer_callback])
else:
return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase):
"""Accuracy test for RetinaNet model.
Tests RetinaNet detection task model accuracy. The naming
convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
"""Starts RetinaNet accuracy benchmark test."""
start_time_sec = time.time()
FLAGS.mode = 'train'
summary, _ = self._run_detection_main()
wall_time_sec = time.time() - start_time_sec
FLAGS.mode = 'eval'
eval_metrics = self._run_detection_main()
summary.update(eval_metrics)
summary['train_batch_size'] = self.params_override['train']['batch_size']
summary['total_steps'] = self.params_override['train']['total_steps']
super(RetinanetAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_ap=min_ap,
max_ap=max_ap,
train_batch_size=self.params_override['train']['batch_size'])
def _setup(self):
super(RetinanetAccuracy, self)._setup()
FLAGS.strategy_type = 'mirrored'
FLAGS.model = 'retinanet'
self.params_override = {
'train': {
'batch_size': 64,
'iterations_per_loop': 100,
'total_steps': 22500,
'train_file_pattern': self.train_data_path,
'checkpoint': {
'path': self.resnet_checkpoint_path,
'prefix': 'resnet50/'
},
},
'eval': {
'batch_size': 8,
'eval_samples': 5000,
'val_json_file': self.eval_json_path,
'eval_file_pattern': self.eval_data_path,
},
}
@flagsaver.flagsaver
def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs."""
self._setup()
params = copy.deepcopy(self.params_override)
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_coco')
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
self._run_and_report_benchmark()
class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Short benchmark performance tests for RetinaNet model.
Tests RetinaNet performance in different GPU configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
"""
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(output_dir=output_dir)
@flagsaver.flagsaver
def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs."""
self.num_gpus = 8
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['total_steps'] = 1875 # One epoch.
# The iterations_per_loop must be one, otherwise the number of examples per
# second would be wrong. Currently only support calling callback per batch
# when each loop only runs on one batch, i.e. host loop for one step. The
# performance of this situation might be lower than the case of
# iterations_per_loop > 1.
# Related bug: b/135933080
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
FLAGS.num_gpus = self.num_gpus
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
self._run_and_report_benchmark()
@flagsaver.flagsaver
def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU."""
self.num_gpus = 1
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['batch_size'] = 8
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
FLAGS.num_gpus = self.num_gpus
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
FLAGS.strategy_type = 'one_device'
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
self._run_and_report_benchmark()
@flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
self.num_gpus = 1
self._setup()
params = copy.deepcopy(self.params_override)
params['train']['batch_size'] = 8
params['train']['total_steps'] = 200
params['train']['iterations_per_loop'] = 1
params['eval']['eval_samples'] = 8
FLAGS.num_gpus = self.num_gpus
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_1_gpu_coco')
FLAGS.strategy_type = 'one_device'
FLAGS.enable_xla = True
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
self._run_and_report_benchmark()
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,359 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Shakespeare (LSTM) benchmark and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.benchmark.models.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.benchmark import benchmark_wrappers
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS
class ShakespeareBenchmarkBase(PerfZeroBenchmark):
"""Base class for Shakespeare (LSTM) benchmark and accuracy tests."""
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None):
super(ShakespeareBenchmarkBase, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=[shakespeare_main.define_flags])
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_train_min=0.91,
top_1_train_max=0.94,
warmup=1,
log_steps=100):
"""Report benchmark results by writing to local protobuf file.
Average epoch time is calculated by skipping the first epoch. This average
ignores time spent between epoch and is recorded by begin and end epoch. To
skip accuracy check set `top_1_train_min=None`.
Args:
top_1_train_min: lowest passing value.
top_1_train_max: highest passing value.
warmup: number of entries in `timestamp_log` to ignore.
log_steps: How often the log was created for `timestamp_log`.
"""
total_batch_size = FLAGS.batch_size
metrics = []
start_time_sec = time.time()
stats = shakespeare_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
if top_1_train_min:
metrics.append({'name': 'accuracy_top_1_train',
'value': stats['history']['RecallAt1'][-1],
'min_value': top_1_train_min,
'max_value': top_1_train_max})
# Look for the time history callback which was used during keras.fit
for callback in stats['callbacks']:
if isinstance(callback, keras_utils.TimeHistory):
epoch_timings = callback.epoch_runtime_log
if len(epoch_timings) > 1:
average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
metrics.append({'name': 'avg_epoch_time',
'value': average_time})
# First entry in timestamp_log is the start of step 1. The rest of the
# entries are the end of each step recorded.
time_log = callback.timestamp_log
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
num_examples = (
total_batch_size * log_steps * (len(time_log) - warmup - 1))
if elapsed > 0:
examples_per_sec = num_examples / elapsed
metrics.append({'name': 'exp_per_second',
'value': examples_per_sec})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(iters=-1, wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
class ShakespeareAccuracy(ShakespeareBenchmarkBase):
"""Shakespeare accuracy tests.
This is not an ideal test. The best we can use for the accuracy check is to
validate top_1 of the training set. At batch size 64 the top_1 training
stabilizes to ~0.92 around 40-45 epochs.
"""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""Shakespeare accuracy tests.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
super(ShakespeareAccuracy, self).__init__(
output_dir=output_dir, root_data_dir=root_data_dir)
def benchmark_cpu(self):
"""Benchmark cpu."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
self._run_and_report_benchmark()
def benchmark_cpu_no_ds_run_eagerly(self):
"""Benchmark cpu without distribution strategies and run eagerly."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
"""Benchmark 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
self._run_and_report_benchmark()
def benchmark_1_gpu_no_ds(self):
"""Benchmark 1 gpu without distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_1_gpu_no_ds_run_eagerly(self):
"""Benchmark 1 gpu without distribution strategies and run eagerly."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_xla_1_gpu(self):
"""Benchmark 1 gpu w/xla."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
"""Benchmark 8 gpu.
This is test is for accuracy not scaling. The batch-size is not scaled to
the number of gpus.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
self._run_and_report_benchmark()
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Benchmark accuracy tests."""
def __init__(self, output_dir=None, root_data_dir=TMP_DIR, **kwargs):
"""Benchmark tests w/Keras.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
self.train_data = os.path.join(root_data_dir, SHAKESPEARE_TRAIN_DATA)
def_flags = {}
def_flags['training_data'] = self.train_data
def_flags['model_dir'] = ''
def_flags['train_epochs'] = 4
def_flags['log_steps'] = 50
super(ShakespeareKerasBenchmarkReal, self).__init__(
output_dir=output_dir,
root_data_dir=root_data_dir,
default_flags=def_flags)
def benchmark_cpu(self):
"""Benchmark cpu."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.batch_size = 64
self._run_and_report_benchmark()
def benchmark_cpu_no_ds_run_eagerly(self):
"""Benchmark cpu without distribution strategy and run eagerly."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.batch_size = 64
FLAGS.distribution_strategy = 'off'
FLAGS.run_eagerly = True
self._run_and_report_benchmark()
def benchmark_cpu_no_ds(self):
"""Benchmark cpu without distribution strategy."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.batch_size = 64
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_cpu_no_ds_force_v2(self):
"""Benchmark cpu no ds, and force v2."""
self._setup()
FLAGS.num_gpus = 0
FLAGS.batch_size = 64
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
"""Benchmark 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
self._run_and_report_benchmark()
def benchmark_1_gpu_no_cudnn(self):
"""Benchmark 1 gpu with CuDNN disabled."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
FLAGS.cudnn = False
FLAGS.enable_eager = keras_utils.is_v2_0()
self._run_and_report_benchmark()
def benchmark_1_gpu_no_ds(self):
"""Benchmark 1 gpu without distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_1_gpu_no_ds_run_eagerly(self):
"""Benchmark 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_xla_1_gpu(self):
"""Benchmark 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_xla_1_gpu_no_cudnn(self):
"""Benchmark 1 gpu w/xla and CuDNN disabled."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64
FLAGS.cudnn = False
FLAGS.enable_eager = keras_utils.is_v2_0()
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
"""Benchmark 8 gpu."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = 64 * 8
FLAGS.log_steps = 10
self._run_and_report_benchmark()
def benchmark_8_gpu_no_cudnn(self):
"""Benchmark 8 gpu with CuDNN disabled."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = 64 * 8
FLAGS.log_steps = 10
FLAGS.cudnn = False
FLAGS.enable_eager = keras_utils.is_v2_0()
self._run_and_report_benchmark()
def benchmark_xla_8_gpu(self):
"""Benchmark 8 gpu w/xla."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = 64 * 8
FLAGS.log_steps = 10
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_no_cudnn(self):
"""Benchmark 8 gpu w/xla and CuDNN disabled."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = 64 * 8
FLAGS.log_steps = 10
FLAGS.cudnn = False
FLAGS.enable_eager = keras_utils.is_v2_0()
FLAGS.enable_xla = True
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
"""Run and report benchmark."""
super(ShakespeareKerasBenchmarkReal, self)._run_and_report_benchmark(
top_1_train_min=None, log_steps=FLAGS.log_steps)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a memory usage benchmark for a Tensorflow Hub model.
Loads a SavedModel and records memory usage.
"""
import functools
import time
from absl import flags
import tensorflow as tf
import tensorflow_hub as hub
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
FLAGS = flags.FLAGS
class TfHubMemoryUsageBenchmark(PerfZeroBenchmark):
"""A benchmark measuring memory usage for a given TF Hub SavedModel."""
def __init__(self,
hub_model_handle_list=None,
output_dir=None,
default_flags=None,
root_data_dir=None,
**kwargs):
super(TfHubMemoryUsageBenchmark, self).__init__(
output_dir=output_dir, default_flags=default_flags, **kwargs)
if hub_model_handle_list:
for hub_model_handle in hub_model_handle_list.split(';'):
# Converts a model handle of the form
# https://tfhub.dev/google/nnlm-en-dim128/1 to valid python method name
# like google_nnlm_en_dim128_1.
hub_model_method_name = hub_model_handle.replace(
'https://tfhub.dev',
'').replace('/', '_').replace('-', '_').strip('_')
setattr(
self, 'benchmark_' + hub_model_method_name,
functools.partial(self.benchmark_memory_usage, hub_model_handle))
def benchmark_memory_usage(
self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'):
start_time_sec = time.time()
self.load_model(hub_model_handle)
wall_time_sec = time.time() - start_time_sec
metrics = []
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
def load_model(self, hub_model_handle):
"""Loads a TF Hub module."""
hub.load(hub_model_handle)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,681 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Transformer w/Keras benchmark and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import tensorflow as tf
from official.benchmark import benchmark_wrappers
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS
TMP_DIR = os.getenv('TMPDIR')
class TransformerBenchmark(PerfZeroBenchmark):
"""Methods common to executing transformer w/keras tests.
Code under test for the Transformer Keras models report the same data and
require the same FLAG setup.
"""
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None):
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de')
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
bleu_max=None,
bleu_min=None,
log_steps=None,
total_batch_size=None,
warmup=1):
"""Report benchmark results by writing to local protobuf file.
Args:
bleu_max: highest passing level for bleu score.
bleu_min: lowest passing level for bleu score.
log_steps: How often the log was created for stats['step_timestamp_log'].
total_batch_size: Global batch-size.
warmup: number of entries in stats['step_timestamp_log'] to ignore.
"""
start_time_sec = time.time()
task = transformer_main.TransformerTask(FLAGS)
stats = task.train()
wall_time_sec = time.time() - start_time_sec
metrics = []
if 'bleu_uncased' in stats:
if 'bleu_uncased_history' in stats:
bleu_uncased_best = max(stats['bleu_uncased_history'],
key=lambda x: x[1])
metrics.append({'name': 'bleu_uncased',
'value': bleu_uncased_best[1],
'min_value': bleu_min,
'max_value': bleu_max})
metrics.append({'name': 'bleu_best_score_iteration',
'value': bleu_uncased_best[0]})
metrics.append({'name': 'bleu_uncased_last',
'value': stats['bleu_uncased']})
else:
metrics.append({'name': 'bleu_uncased',
'value': stats['bleu_uncased'],
'min_value': bleu_min,
'max_value': bleu_max})
if (warmup and 'step_timestamp_log' in stats and
len(stats['step_timestamp_log']) > warmup):
# first entry in the time_log is start of step 1. The rest of the
# entries are the end of each step recorded
time_log = stats['step_timestamp_log']
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
num_examples = (
total_batch_size * log_steps * (len(time_log) - warmup - 1))
examples_per_sec = num_examples / elapsed
metrics.append({'name': 'exp_per_second',
'value': examples_per_sec})
if 'avg_exp_per_second' in stats:
metrics.append({'name': 'avg_exp_per_second',
'value': stats['avg_exp_per_second']})
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics,
extras={'flags': flags_str})
class TransformerBaseKerasAccuracy(TransformerBenchmark):
"""Benchmark accuracy tests for Transformer Base model w/ Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""Benchmark accuracy tests for Transformer Base model w/ Keras.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
flag_methods = [misc.define_transformer_flags]
super(TransformerBaseKerasAccuracy, self).__init__(
output_dir=output_dir, root_data_dir=root_data_dir,
flag_methods=flag_methods)
def benchmark_1_gpu(self):
"""Benchmark 1 gpu.
The paper uses 8 GPUs and a much larger effective batch size, this is will
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 2048
FLAGS.train_steps = 1000
FLAGS.steps_between_evals = 500
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
# These bleu scores are based on test runs after at this limited
# number of steps and batch size after verifying SOTA at 8xV100s.
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=25.3,
bleu_max=26)
def benchmark_1_gpu_static_batch(self):
"""Benchmark 1 gpu with static_batch.
The paper uses 8 GPUs and a much larger effective batch size, this is will
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 5000
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
# These bleu scores are based on test runs after at this limited
# number of steps and batch size after verifying SOTA at 8xV100s.
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=25.3,
bleu_max=26)
def benchmark_8_gpu(self):
"""Benchmark 8 gpu.
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=27,
bleu_max=28)
def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu.
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.steps_between_evals = 5000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=27,
bleu_max=28)
class TransformerBigKerasAccuracy(TransformerBenchmark):
"""Benchmark accuracy tests for Transformer Big model w/ Keras."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
"""Benchmark accuracy tests for Transformer Big model w/ Keras.
Args:
output_dir: directory where to output e.g. log files
root_data_dir: directory under which to look for dataset
**kwargs: arbitrary named arguments. This is needed to make the
constructor forward compatible in case PerfZero provides more
named arguments before updating the constructor.
"""
flag_methods = [misc.define_transformer_flags]
super(TransformerBigKerasAccuracy, self).__init__(
output_dir=output_dir, root_data_dir=root_data_dir,
flag_methods=flag_methods)
def benchmark_8_gpu(self):
"""Benchmark 8 gpu.
Over 6 runs with eval every 20K steps the average highest value was 28.195
(bleu uncased). 28.424 was the highest and 27.96 the lowest. The values are
the highest value seen during a run and occurred at a median of iteration 9.
Iterations are not epochs, an iteration is a number of steps between evals.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=27.9,
bleu_max=29.2)
def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29.2)
def benchmark_8_gpu_fp16(self):
"""Benchmark 8 gpu with dynamic batch and fp16.
Over 6 runs with eval every 20K steps the average highest value was 28.247
(bleu uncased). 28.424 was the highest and 28.09 the lowest. The values are
the highest value seen during a run and occurred at a median of iteration
11. While this could be interpreted as worse than FP32, if looking at the
first iteration at which 28 is passed FP16 performs equal and possibly
better. Although not part of the initial test runs, the highest value
recorded with the arguments below was 28.9 at iteration 12. Iterations are
not epochs, an iteration is a number of steps between evals.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29.2)
def benchmark_8_gpu_fp16_amp(self):
"""Benchmark 8 gpu with dynamic batch and fp16 with automatic mixed precision.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_amp')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.train_steps = 400000
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29.2)
def benchmark_xla_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch, XLA, and FP16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.train_steps = 400000
FLAGS.steps_between_evals = 20000
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_static_batch_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29.2)
class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmarks for Transformer (Base and Big) using Keras."""
def __init__(self, output_dir=None, default_flags=None,
root_data_dir=None, batch_per_gpu=4096):
"""Initialize.
Args:
output_dir: Based directory for saving artifacts, e.g. checkpoints.
default_flags: default flags to use for all tests.
root_data_dir: root directory for data, e.g. training.
batch_per_gpu: batch size to use per gpu.
"""
flag_methods = [misc.define_transformer_flags]
self.batch_per_gpu = batch_per_gpu
super(TransformerKerasBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
root_data_dir=root_data_dir,
flag_methods=flag_methods)
def benchmark_1_gpu_no_dist_strat(self):
"""Benchmark 1 gpu without distribution strategy."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'off'
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_no_dist_strat_static_batch(self):
"""Benchmark 1 gpu without distribution strategy with static batch."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'off'
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_ds_sb')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu(self):
"""Benchmark 1 gpu."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_fp16(self):
"""Benchmark 1 gpu FP16."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_1_gpu(self):
"""Benchmark 1 gpu w/xla."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
FLAGS.enable_xla = True
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_1_gpu_fp16(self):
"""Benchmark 1 gpu w/xla and FP16."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_static_batch(self):
"""Benchmark 1 gpu with static batch."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_1_gpu_static_batch(self):
"""Benchmark 1 gpu with static batch w/xla."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_static_batch')
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.enable_xla = True
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_static_batch_fp16(self):
"""Benchmark 1 gpu with static batch FP16."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_static_batch_fp16')
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_1_gpu_static_batch_fp16(self):
"""Benchmark 1 gpu with static batch w/xla and FP16."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_1_gpu_static_batch_fp16')
FLAGS.static_batch = True
FLAGS.max_length = 64
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_8_gpu(self):
"""Benchmark 8 gpu."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_8_gpu_fp16(self):
"""Benchmark 8 gpu FP16."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_8_gpu(self):
"""Benchmark 8 gpu w/xla."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_xla = True
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_8_gpu_fp16(self):
"""Benchmark 8 gpu w/xla and FP16."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu with static batch."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch FP16."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir(
'benchmark_8_gpu_static_batch_fp16')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_8_gpu_static_batch(self):
"""Benchmark 8 gpu with static batch w/xla."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_xla = True
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_static_batch')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_xla_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch w/xla and FP16."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_xla = True
FLAGS.dtype = 'fp16'
FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_static_batch_fp16')
FLAGS.static_batch = True
FLAGS.max_length = 64
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {}
def_flags['param_set'] = 'base'
def_flags['train_steps'] = 50
def_flags['log_steps'] = 10
super(TransformerBaseKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags,
root_data_dir=root_data_dir, batch_per_gpu=4096)
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs):
def_flags = {}
def_flags['param_set'] = 'big'
def_flags['train_steps'] = 50
def_flags['log_steps'] = 10
super(TransformerBigKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags,
root_data_dir=root_data_dir, batch_per_gpu=3072)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,216 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes XLNet benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
SQUAD_DATA_PATH = 'gs://tf-perfzero-data/xlnet/squadv2_cased/'
# pylint: enable=line-too-long
FLAGS = flags.FLAGS
class XLNetBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None):
super(XLNetBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None
self.num_steps_per_epoch = None
@flagsaver.flagsaver
def _run_xlnet_classifier(self):
"""Starts XLNet classification task."""
run_classifier.main(unused_argv=None)
@flagsaver.flagsaver
def _run_xlnet_squad(self):
"""Starts XLNet classification task."""
run_squad.main(unused_argv=None)
class XLNetClassifyAccuracy(XLNetBenchmarkBase):
"""Short accuracy test for XLNet classifier model.
Tests XLNet classification task model accuracy. The naming
convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=None, **kwargs):
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.95,
max_accuracy=0.97):
"""Starts XLNet accuracy benchmark test."""
start_time_sec = time.time()
self._run_xlnet_classifier()
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(XLNetClassifyAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def _setup(self):
super(XLNetClassifyAccuracy, self)._setup()
FLAGS.test_data_size = 25024
FLAGS.train_batch_size = 16
FLAGS.seq_len = 512
FLAGS.mem_len = 0
FLAGS.n_layer = 24
FLAGS.d_model = 1024
FLAGS.d_embed = 1024
FLAGS.n_head = 16
FLAGS.d_head = 64
FLAGS.d_inner = 4096
FLAGS.untie_r = True
FLAGS.n_class = 2
FLAGS.ff_activation = 'gelu'
FLAGS.strategy_type = 'mirror'
FLAGS.learning_rate = 2e-5
FLAGS.train_steps = 4000
FLAGS.warmup_steps = 500
FLAGS.iterations = 200
FLAGS.bi_data = False
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
FLAGS.train_tfrecord_path = self.train_data_path
FLAGS.test_tfrecord_path = self.eval_data_path
def benchmark_8_gpu_imdb(self):
"""Run XLNet model accuracy test with 8 GPUs."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_imdb')
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
class XLNetSquadAccuracy(XLNetBenchmarkBase):
"""Short accuracy test for XLNet squad model.
Tests XLNet squad task model accuracy. The naming
convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=None, **kwargs):
self.train_data_path = SQUAD_DATA_PATH
self.predict_file = os.path.join(SQUAD_DATA_PATH, "dev-v2.0.json")
self.test_data_path = os.path.join(SQUAD_DATA_PATH, "12048.eval.tf_record")
self.spiece_model_file = os.path.join(SQUAD_DATA_PATH, "spiece.cased.model")
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=87.0,
max_accuracy=89.0):
"""Starts XLNet accuracy benchmark test."""
start_time_sec = time.time()
self._run_xlnet_squad()
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(XLNetSquadAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def _setup(self):
super(XLNetSquadAccuracy, self)._setup()
FLAGS.train_batch_size = 16
FLAGS.seq_len = 512
FLAGS.mem_len = 0
FLAGS.n_layer = 24
FLAGS.d_model = 1024
FLAGS.d_embed = 1024
FLAGS.n_head = 16
FLAGS.d_head = 64
FLAGS.d_inner = 4096
FLAGS.untie_r = True
FLAGS.ff_activation = 'gelu'
FLAGS.strategy_type = 'mirror'
FLAGS.learning_rate = 3e-5
FLAGS.train_steps = 8000
FLAGS.warmup_steps = 1000
FLAGS.iterations = 1000
FLAGS.bi_data = False
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
FLAGS.train_tfrecord_path = self.train_data_path
FLAGS.test_tfrecord_path = self.test_data_path
FLAGS.spiece_model_file = self.spiece_model_file
FLAGS.predict_file = self.predict_file
FLAGS.adam_epsilon=1e-6
FLAGS.lr_layer_decay_rate=0.75
def benchmark_8_gpu_squadv2(self):
"""Run XLNet model squad v2 accuracy test with 8 GPUs."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
FLAGS.predict_dir = FLAGS.model_dir
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,383 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "How-to Guide: Using a PIP package for fine-tuning a BERT model",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "5T_-iFRIqliG",
"colab_type": "text"
},
"source": [
"## How-to Guide: Using a PIP package for fine-tuning a BERT model\n",
"\n",
"Author: [Chen Chen](https://github.com/chenGitHuber)\n",
"\n",
"In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mY1vX5VAq4SS",
"colab_type": "text"
},
"source": [
"## License\n",
"\n",
"Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XV3k63bt0ihl",
"colab_type": "text"
},
"source": [
"## Learning objectives\n",
"\n",
"In this Colab notebook, you will learn how to fine-tune a BERT model using the TensorFlow Model Garden PIP package."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MHA-RWherfG4",
"colab_type": "text"
},
"source": [
"## Enable the GPU acceleration\n",
"Please enable GPU for better performance.\n",
"* Navigate to Edit 🡒 Notebook settings\n",
"* Select GPU from the \"Hardware Accelerator\" drop-down list\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l2B4N5Djrs2l",
"colab_type": "text"
},
"source": [
"## Install the Model Garden PIP package\n",
"\n",
"Install the Model Garden PIP package (tf-models-nightly) and other necessary PIP packages."
]
},
{
"cell_type": "code",
"metadata": {
"id": "VMYZDly6rx97",
"colab_type": "code",
"outputId": "146956ab-4568-4de6-c78e-cf25f115a5a8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"pip install tf-models-nightly"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting tf-models-nightly\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/7c/1390d4e05d4d370e91d32dd9700d3a462dbc560c7f4e95a6477592b17def/tf_models_nightly-2.2.0.dev20200326-py2.py3-none-any.whl (710kB)\n",
"\u001b[K |████████████████████████████████| 716kB 2.8MB/s \n",
"\u001b[?25hRequirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.4.1)\n",
"Collecting opencv-python-headless\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/0b/23/5f10b30a48b218a4884bc84188c14381ac71288b210f6f8079a54f7a05e8/opencv_python_headless-4.2.0.32-cp36-cp36m-manylinux1_x86_64.whl (21.6MB)\n",
"\u001b[K |████████████████████████████████| 21.6MB 1.3MB/s \n",
"\u001b[?25hRequirement already satisfied: tensorflow-hub>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7.0)\n",
"Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.6.6)\n",
"Collecting tensorflow-model-optimization>=0.2.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8f/c4/4c3d011e432bd9c19f0323f7da7d3f783402615e4c3b5a98416c7da9cb05/tensorflow_model_optimization-0.2.1-py2.py3-none-any.whl (93kB)\n",
"\u001b[K |████████████████████████████████| 102kB 10.2MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.18.2)\n",
"Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.25.3)\n",
"Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.3.0)\n",
"Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.21.0)\n",
"Collecting mlperf-compliance==0.0.10\n",
" Downloading https://files.pythonhosted.org/packages/f4/08/f2febd8cbd5c9371f7dab311e90400d83238447ba7609b3bf0145b4cb2a2/mlperf_compliance-0.0.10-py3-none-any.whl\n",
"Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.5.6)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.12.0)\n",
"Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.7.12)\n",
"Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.8.3)\n",
"Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (5.4.8)\n",
"Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (2.1.0)\n",
"Collecting sentencepiece\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
"\u001b[K |████████████████████████████████| 1.0MB 58.3MB/s \n",
"\u001b[?25hCollecting tf-nightly\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/39/1c/4408b4c4b0d8008a7de62162e35089d59d19cc7543cfd1b23a70121f3086/tf_nightly-2.2.0.dev20200325-cp36-cp36m-manylinux2010_x86_64.whl (516.1MB)\n",
"\u001b[K |████████████████████████████████| 516.1MB 21kB/s \n",
"\u001b[?25hCollecting py-cpuinfo>=3.3.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/60/63f28a5401da733043abe7053e7d9591491b4784c4f87c339bf51215aa0a/py-cpuinfo-5.0.0.tar.gz (82kB)\n",
"\u001b[K |████████████████████████████████| 92kB 13.7MB/s \n",
"\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (7.0.0)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.13)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.2.1)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7)\n",
"Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.29.15)\n",
"Requirement already satisfied: oauth2client>=4.1.2 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (4.1.3)\n",
"Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.6.0->tf-models-nightly) (3.10.0)\n",
"Collecting enum34~=1.1\n",
" Downloading https://files.pythonhosted.org/packages/63/f6/ccb1c83687756aeabbf3ca0f213508fcfb03883ff200d201b3a4c60cedcc/enum34-1.1.10-py3-none-any.whl\n",
"Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2018.9)\n",
"Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (0.4.1)\n",
"Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.0.3)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (1.24.3)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2019.11.28)\n",
"Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.0.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.38.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2.21.0)\n",
"Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.0.3)\n",
"Requirement already satisfied: httplib2<1dev,>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.17.0)\n",
"Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (3.0.1)\n",
"Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (1.7.2)\n",
"Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly) (2.7.1)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.16.0)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.9.0)\n",
"Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.12.1)\n",
"Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (2.3)\n",
"Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (19.3.0)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.1.0)\n",
"Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.21.1)\n",
"Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.3.1.1)\n",
"Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.3.3)\n",
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.34.2)\n",
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.27.2)\n",
"Collecting tb-nightly<2.3.0a0,>=2.2.0a0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/52/b6/aa559ea9edbc6129a64ff752dbf6567bcd62fba34566defca00fdff4345e/tb_nightly-2.2.0a20200324-py3-none-any.whl (2.8MB)\n",
"\u001b[K |████████████████████████████████| 2.8MB 51.9MB/s \n",
"\u001b[?25hRequirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (2.10.0)\n",
"Collecting tf-estimator-nightly\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/bf/eb/0d6c06d1181cd9b52d7c82535073887d68d224f3dbeeb00adefd04762a9c/tf_estimator_nightly-2.3.0.dev2020032501-py2.py3-none-any.whl (455kB)\n",
"\u001b[K |████████████████████████████████| 460kB 53.5MB/s \n",
"\u001b[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.2.0)\n",
"Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.6.3)\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (3.2.0)\n",
"Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.1.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (2.4.6)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (0.10.0)\n",
"Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.2.8)\n",
"Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (4.0)\n",
"Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.4.8)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.4.0->tensorflow-hub>=0.6.0->tf-models-nightly) (46.0.0)\n",
"Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.16.0)\n",
"Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly) (1.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (2.8)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (3.0.4)\n",
"Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly) (3.1.1)\n",
"Requirement already satisfied: googleapis-common-protos in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets->tf-models-nightly) (1.51.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.2.1)\n",
"Collecting tensorboard-plugin-wit>=1.6.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/41/ec/3da49289b93963bd8b32d29ed108f1809436ff3d9cd4e29c90bac4a7292f/tensorboard_plugin_wit-1.6.0.post2-py3-none-any.whl (775kB)\n",
"\u001b[K |████████████████████████████████| 778kB 43.8MB/s \n",
"\u001b[?25hRequirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (0.4.1)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.0.0)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.3.0)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.1.0)\n",
"Building wheels for collected packages: py-cpuinfo\n",
" Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for py-cpuinfo: filename=py_cpuinfo-5.0.0-cp36-none-any.whl size=18684 sha256=093853bc49757be8f9facd8d38eab95b3cc2de13514dafa221ad7898725491ff\n",
" Stored in directory: /root/.cache/pip/wheels/01/7e/a9/b982d0fea22b7e4ae5619de949570cde5ad55420cec16e86a5\n",
"Successfully built py-cpuinfo\n",
"Installing collected packages: opencv-python-headless, enum34, tensorflow-model-optimization, mlperf-compliance, sentencepiece, tensorboard-plugin-wit, tb-nightly, tf-estimator-nightly, tf-nightly, py-cpuinfo, tf-models-nightly\n",
"Successfully installed enum34-1.1.10 mlperf-compliance-0.0.10 opencv-python-headless-4.2.0.32 py-cpuinfo-5.0.0 sentencepiece-0.1.85 tb-nightly-2.2.0a20200324 tensorboard-plugin-wit-1.6.0.post2 tensorflow-model-optimization-0.2.1 tf-estimator-nightly-2.3.0.dev2020032501 tf-models-nightly-2.2.0.dev20200326 tf-nightly-2.2.0.dev20200325\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.colab-display-data+json": {
"pip_warning": {
"packages": [
"enum"
]
}
}
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w4oyRCTji-aa",
"colab_type": "text"
},
"source": [
"## BERT Fine-tuning\n",
"\n",
"The following code import necessary modules for fine-tuning a BERT model on a classification task.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rV8MIX7g078-",
"colab_type": "code",
"colab": {}
},
"source": [
"%tensorflow_version 2.x\n",
"import tensorflow as tf\n",
"\n",
"import json\n",
"import math\n",
"\n",
"from official.utils.misc import distribution_utils\n",
"from official.nlp import optimization\n",
"from official.nlp.bert import bert_models\n",
"from official.nlp.bert import configs as bert_configs\n",
"from official.nlp.bert import run_classifier\n",
"from official.modeling import activations\n",
"from official.nlp.modeling import networks\n",
"from official.nlp.modeling.models import bert_classifier"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DIuS8nYD08n3",
"colab_type": "text"
},
"source": [
"This section of code performs the following tasks:\n",
"* Load data for fine-tuning\n",
"* Fine-tune a BERT model\n",
"* Save the fine-tuned model to a TensorFlow SavedModel file\n",
"\n",
"Please check [create_finetuning_data.py](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) if you want to know how the train/eval data are created."
]
},
{
"cell_type": "code",
"metadata": {
"id": "PAby1RTCi_1e",
"colab_type": "code",
"outputId": "e663e830-cc9b-4b5d-99db-4504fd66d5f3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 258
}
},
"source": [
"\n",
"train_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_train.tf_record\"\n",
"eval_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_eval.tf_record\"\n",
"input_meta_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_meta_data\"\n",
"\n",
"bert_config_file = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json\"\n",
"ckpt_path = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
"\n",
"with tf.io.gfile.GFile(input_meta_path, 'rb') as reader:\n",
" input_meta_data = json.loads(reader.read().decode('utf-8'))\n",
"\n",
"max_seq_length = input_meta_data['max_seq_length']\n",
"num_classes = input_meta_data['num_labels']\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"train_input_fn = run_classifier.get_dataset_fn(train_data_path, max_seq_length, batch_size, is_training=True)\n",
"eval_input_fn = run_classifier.get_dataset_fn(eval_data_path, max_seq_length, eval_batch_size, is_training=False)\n",
"\n",
"strategy = distribution_utils.get_distribution_strategy(\n",
" distribution_strategy='one_device', num_gpus=1)\n",
"\n",
"with strategy.scope():\n",
" training_dataset = train_input_fn()\n",
" evaluation_dataset = eval_input_fn()\n",
" bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n",
" classifier_model, encoder = bert_models.classifier_model(\n",
" bert_config, num_classes, max_seq_length)\n",
"\n",
" checkpoint = tf.train.Checkpoint(model=encoder)\n",
" checkpoint.restore(ckpt_path).assert_consumed()\n",
"\n",
" epochs = 3\n",
" train_data_size = input_meta_data['train_data_size']\n",
" eval_data_size = input_meta_data['eval_data_size']\n",
" steps_per_epoch = int(train_data_size / batch_size)\n",
" warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
" optimizer = optimization.create_optimizer(\n",
" 2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)\n",
"\n",
" def metric_fn():\n",
" return tf.keras.metrics.SparseCategoricalAccuracy(\n",
" 'test_accuracy', dtype=tf.float32)\n",
"\n",
" classifier_model.compile(optimizer=optimizer,\n",
" loss=run_classifier.get_loss_fn(num_classes=2),\n",
" metrics=[metric_fn()])\n",
" classifier_model.fit(\n",
" x=training_dataset,\n",
" validation_data=evaluation_dataset,\n",
" steps_per_epoch=steps_per_epoch,\n",
" epochs=epochs,\n",
" validation_steps=int(eval_data_size / eval_batch_size))\n",
"\n",
" classifier_model.save('/tmp/saved_model', include_optimizer=False, save_format='tf')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_mask.\n",
"Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
"The tensor that caused the issue was: input_mask:0\n",
"WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_type_ids.\n",
"Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
"The tensor that caused the issue was: input_type_ids:0\n",
"Epoch 1/3\n",
"114/114 [==============================] - 96s 840ms/step - loss: 0.5932 - test_accuracy: 0.6960 - val_loss: 0.5083 - val_test_accuracy: 0.7604\n",
"Epoch 2/3\n",
"114/114 [==============================] - 100s 878ms/step - loss: 0.4225 - test_accuracy: 0.8183 - val_loss: 0.4020 - val_test_accuracy: 0.8438\n",
"Epoch 3/3\n",
"114/114 [==============================] - 100s 880ms/step - loss: 0.2482 - test_accuracy: 0.9134 - val_loss: 0.4065 - val_test_accuracy: 0.8151\n",
"INFO:tensorflow:Assets written to: /tmp/saved_model/assets\n"
],
"name": "stdout"
}
]
}
]
}
@@ -0,0 +1,19 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activations package definition."""
from official.modeling.activations.gelu import gelu
from official.modeling.activations.swish import hard_swish
from official.modeling.activations.swish import identity
from official.modeling.activations.swish import simple_swish
@@ -0,0 +1,40 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gaussian error linear unit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
def gelu(x):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
@@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Gaussian error linear unit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class GeluTest(keras_parameterized.TestCase):
def test_gelu(self):
expected_data = [[0.14967535, 0., -0.10032465],
[-0.15880796, -0.04540223, 2.9963627]]
gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
self.assertAllClose(expected_data, gelu_data)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,75 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Customized Swish activation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
def simple_swish(features):
"""Computes the Swish activation function.
The tf.nn.swish operation uses a custom gradient to reduce memory usage.
Since saving custom gradients in SavedModel is currently not supported, and
one would not be able to use an exported TF-Hub module for fine-tuning, we
provide this wrapper that can allow to select whether to use the native
TensorFlow swish operation, or whether to use a customized operation that
has uses default TensorFlow gradient computation.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return features * tf.nn.sigmoid(features)
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_swish(features):
"""Computes a hard version of the swish function.
This operation can be used to reduce computational cost and improve
quantization for edge devices.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
@tf.keras.utils.register_keras_serializable(package='Text')
def identity(features):
"""Computes the identity function.
Useful for helping in quantization.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return tf.identity(features)
@@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the customized Swish activation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase):
def _hard_swish_np(self, x):
x = np.float32(x)
return x * np.clip(x + 3, 0, 6) / 6
def test_simple_swish(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_swish_data = activations.simple_swish(features)
swish_data = tf.nn.swish(features)
self.assertAllClose(customized_swish_data, swish_data)
def test_hard_swish(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_swish_data = activations.hard_swish(features)
swish_data = self._hard_swish_np(features)
self.assertAllClose(customized_swish_data, swish_data)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,323 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy
import functools
from typing import Any, List, Mapping, Optional, Type
import dataclasses
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
"""The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* It converts dict to Config even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
"""
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES = (str, int, float, bool, type(None))
# It's safe to add set, frozenset and other collections here.
SEQUENCE_TYPES = (list, tuple)
default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
restrictions: dataclasses.InitVar[Optional[List[str]]] = None
@classmethod
def _isvalidsequence(cls, v):
"""Check if the input values are valid sequences.
Args:
v: Input sequence.
Returns:
True if the sequence is valid. Valid sequence includes the sequence
type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
is dict or ParamsDict.
"""
if not isinstance(v, cls.SEQUENCE_TYPES):
return False
return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
all(isinstance(e, dict) for e in v) or
all(isinstance(e, params_dict.ParamsDict) for e in v))
@classmethod
def _import_config(cls, v, subconfig_type):
"""Returns v with dicts converted to Configs, recursively."""
if not issubclass(subconfig_type, params_dict.ParamsDict):
raise TypeError(
'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
subconfig_type))
if isinstance(v, cls.IMMUTABLE_TYPES):
return v
elif isinstance(v, cls.SEQUENCE_TYPES):
# Only support one layer of sequence.
if not cls._isvalidsequence(v):
raise TypeError(
'Invalid sequence: only supports single level {!r} of {!r} or '
'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
cls.IMMUTABLE_TYPES, v))
import_fn = functools.partial(
cls._import_config, subconfig_type=subconfig_type)
return type(v)(map(import_fn, v))
elif isinstance(v, params_dict.ParamsDict):
# Deepcopy here is a temporary solution for preserving type in nested
# Config object.
return copy.deepcopy(v)
elif isinstance(v, dict):
return subconfig_type(v)
else:
raise TypeError('Unknown type: {!r}'.format(type(v)))
@classmethod
def _export_config(cls, v):
"""Returns v with Configs converted to dicts, recursively."""
if isinstance(v, cls.IMMUTABLE_TYPES):
return v
elif isinstance(v, cls.SEQUENCE_TYPES):
return type(v)(map(cls._export_config, v))
elif isinstance(v, params_dict.ParamsDict):
return v.as_dict()
elif isinstance(v, dict):
raise TypeError('dict value not supported in converting.')
else:
raise TypeError('Unknown type: {!r}'.format(type(v)))
@classmethod
def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
"""Get element type by the field name.
Args:
k: the key/name of the field.
Returns:
Config as default. If a type annotation is found for `k`,
1) returns the type of the annotation if it is subtype of ParamsDict;
2) returns the element type if the annotation of `k` is List[SubType]
or Tuple[SubType].
"""
subconfig_type = Config
if k in cls.__annotations__:
# Directly Config subtype.
type_annotation = cls.__annotations__[k]
if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k]
else:
# Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None))
if (isinstance(field_type, type) and
issubclass(field_type, cls.SEQUENCE_TYPES)):
element_type = getattr(type_annotation, '__args__', [type(None)])[0]
subconfig_type = (
element_type if issubclass(element_type, params_dict.ParamsDict)
else subconfig_type)
return subconfig_type
def __post_init__(self, default_params, restrictions, *args, **kwargs):
super().__init__(default_params=default_params,
restrictions=restrictions,
*args,
**kwargs)
def _set(self, k, v):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
k: key to set.
v: value.
Raises:
RuntimeError
"""
subconfig_type = self._get_subconfig_type(k)
if isinstance(v, dict):
if k not in self.__dict__ or not self.__dict__[k]:
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self.__dict__[k] = subconfig_type(v)
else:
self.__dict__[k].override(v)
else:
self.__dict__[k] = self._import_config(v, subconfig_type)
def __setattr__(self, k, v):
if k not in self.RESERVED_ATTR:
if getattr(self, '_locked', False):
raise ValueError('The Config has been locked. ' 'No change is allowed.')
self._set(k, v)
def _override(self, override_dict, is_strict=True):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
override_dict: dictionary to write to .
is_strict: If True, not allows to add new keys.
Raises:
KeyError: overriding reserved keys or keys not exist (is_strict=True).
"""
for k, v in sorted(override_dict.items()):
if k in self.RESERVED_ATTR:
raise KeyError('The key {!r} is internally reserved. '
'Can not be overridden.'.format(k))
if k not in self.__dict__:
if is_strict:
raise KeyError('The key {!r} does not exist in {!r}. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'.format(
k, type(self)))
else:
self._set(k, v)
else:
if isinstance(v, dict) and self.__dict__[k]:
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
else:
self._set(k, v)
def as_dict(self):
"""Returns a dict representation of params_dict.ParamsDict.
For the nested params_dict.ParamsDict, a nested dict will be returned.
"""
return {
k: self._export_config(v)
for k, v in self.__dict__.items()
if k not in self.RESERVED_ATTR
}
def replace(self, **kwargs):
"""Like `override`, but returns a copy with the current config unchanged."""
params = self.__class__(self)
params.override(kwargs, is_strict=True)
return params
@classmethod
def from_yaml(cls, file_path: str):
# Note: This only works if the Config has all default values.
with tf.io.gfile.GFile(file_path, 'r') as f:
loaded = yaml.load(f)
config = cls()
config.override(loaded)
return config
@classmethod
def from_json(cls, file_path: str):
"""Wrapper for `from_yaml`."""
return cls.from_yaml(file_path)
@classmethod
def from_args(cls, *args, **kwargs):
"""Builds a config from the given list of arguments."""
attributes = list(cls.__annotations__.keys())
default_params = {a: p for a, p in zip(attributes, args)}
default_params.update(kwargs)
return cls(default_params)
@dataclasses.dataclass
class RuntimeConfig(Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
"""
distribution_strategy: str = 'mirrored'
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[str] = None
run_eagerly: bool = False
@dataclasses.dataclass
class TensorboardConfig(Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as
images in Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
enable_time_history: bool = True
@@ -0,0 +1,299 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pprint
from typing import List, Tuple
from absl.testing import parameterized
import dataclasses
import tensorflow as tf
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class DumpConfig1(base_config.Config):
a: int = 1
b: str = 'text'
@dataclasses.dataclass
class DumpConfig2(base_config.Config):
c: int = 2
d: str = 'text'
e: DumpConfig1 = DumpConfig1()
@dataclasses.dataclass
class DumpConfig3(DumpConfig2):
f: int = 2
g: str = 'text'
h: List[DumpConfig1] = dataclasses.field(
default_factory=lambda: [DumpConfig1(), DumpConfig1()])
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''):
"""Checks if a Config has the same structure as a given dict.
Args:
c: the Config object to be check.
d: the reference dict object.
msg: The error message to show when type mismatched.
"""
# Make sure d is not a Config. Assume d is either
# dictionary or primitive type and c is the Config or primitive types.
self.assertNotIsInstance(d, base_config.Config)
if isinstance(d, base_config.Config.IMMUTABLE_TYPES):
self.assertEqual(pprint.pformat(c), pprint.pformat(d), msg=msg)
elif isinstance(d, base_config.Config.SEQUENCE_TYPES):
self.assertEqual(type(c), type(d), msg=msg)
for i, v in enumerate(d):
self.assertHasSameTypes(c[i], v, msg='{}[{!r}]'.format(msg, i))
elif isinstance(d, dict):
self.assertIsInstance(c, base_config.Config, msg=msg)
for k, v in sorted(d.items()):
self.assertHasSameTypes(getattr(c, k), v, msg='{}[{!r}]'.format(msg, k))
else:
raise TypeError('Unknown type: %r' % type(d))
def assertImportExport(self, v):
config = base_config.Config({'key': v})
back = config.as_dict()['key']
self.assertEqual(pprint.pformat(back), pprint.pformat(v))
self.assertHasSameTypes(config.key, v, msg='=%s v' % pprint.pformat(v))
def test_invalid_keys(self):
params = base_config.Config()
with self.assertRaises(AttributeError):
_ = params.a
def test_nested_config_types(self):
config = DumpConfig3()
self.assertIsInstance(config.e, DumpConfig1)
self.assertIsInstance(config.h[0], DumpConfig1)
self.assertIsInstance(config.h[1], DumpConfig1)
self.assertIsInstance(config.g[0], DumpConfig1)
config.override({'e': {'a': 2, 'b': 'new text'}})
self.assertIsInstance(config.e, DumpConfig1)
self.assertEqual(config.e.a, 2)
self.assertEqual(config.e.b, 'new text')
config.override({'h': [{'a': 3, 'b': 'new text 2'}]})
self.assertIsInstance(config.h[0], DumpConfig1)
self.assertLen(config.h, 1)
self.assertEqual(config.h[0].a, 3)
self.assertEqual(config.h[0].b, 'new text 2')
config.override({'g': [{'a': 4, 'b': 'new text 3'}]})
self.assertIsInstance(config.g[0], DumpConfig1)
self.assertLen(config.g, 1)
self.assertEqual(config.g[0].a, 4)
self.assertEqual(config.g[0].b, 'new text 3')
@parameterized.parameters(
('_locked', "The key '_locked' is internally reserved."),
('_restrictions', "The key '_restrictions' is internally reserved."),
('aa', "The key 'aa' does not exist."),
)
def test_key_error(self, key, msg):
params = base_config.Config()
with self.assertRaisesRegex(KeyError, msg):
params.override({key: True})
@parameterized.parameters(
('str data',),
(123,),
(1.23,),
(None,),
(['str', 1, 2.3, None],),
(('str', 1, 2.3, None),),
)
def test_import_export_immutable_types(self, v):
self.assertImportExport(v)
out = base_config.Config({'key': v})
self.assertEqual(pprint.pformat(v), pprint.pformat(out.key))
def test_override_is_strict_true(self):
params = base_config.Config({
'a': 'aa',
'b': 2,
'c': {
'c1': 'cc',
'c2': 20
}
})
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
with self.assertRaises(KeyError):
params.override({'d': 'ddd'}, is_strict=True)
with self.assertRaises(KeyError):
params.override({'c': {'c3': 30}}, is_strict=True)
config = base_config.Config({'key': [{'a': 42}]})
config.override({'key': [{'b': 43}]})
self.assertEqual(config.key[0].b, 43)
with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
_ = config.key[0].a
@parameterized.parameters(
(lambda x: x, 'Unknown type'),
(object(), 'Unknown type'),
(set(), 'Unknown type'),
(frozenset(), 'Unknown type'),
)
def test_import_unsupport_types(self, v, msg):
with self.assertRaisesRegex(TypeError, msg):
_ = base_config.Config({'key': v})
@parameterized.parameters(
({
'a': [{
'b': 2,
}, {
'c': 3,
}]
},),
({
'c': [{
'f': 1.1,
}, {
'h': [1, 2],
}]
},),
(({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20,
}
},),),
)
def test_import_export_nested_structure(self, d):
self.assertImportExport(d)
@parameterized.parameters(
([{
'a': 42,
'b': 'hello',
'c': 1.2
}],),
(({
'a': 42,
'b': 'hello',
'c': 1.2
},),),
)
def test_import_export_nested_sequences(self, v):
self.assertImportExport(v)
@parameterized.parameters(
([([{}],)],),
([['str', 1, 2.3, None]],),
((('str', 1, 2.3, None),),),
([
('str', 1, 2.3, None),
],),
([
('str', 1, 2.3, None),
],),
([[{
'a': 42,
'b': 'hello',
'c': 1.2
}]],),
([[[{
'a': 42,
'b': 'hello',
'c': 1.2
}]]],),
((({
'a': 42,
'b': 'hello',
'c': 1.2
},),),),
(((({
'a': 42,
'b': 'hello',
'c': 1.2
},),),),),
([({
'a': 42,
'b': 'hello',
'c': 1.2
},)],),
(([{
'a': 42,
'b': 'hello',
'c': 1.2
}],),),
)
def test_import_export_unsupport_sequence(self, v):
with self.assertRaisesRegex(TypeError,
'Invalid sequence: only supports single level'):
_ = base_config.Config({'key': v})
def test_construct_subtype(self):
pass
def test_import_config(self):
params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
self.assertLen(params.a, 2)
self.assertEqual(params.a[0].b, 2)
self.assertEqual(type(params.a[0]), base_config.Config)
self.assertEqual(pprint.pformat(params.a[0].b), '2')
self.assertEqual(type(params.a[1]), base_config.Config)
self.assertEqual(type(params.a[1].c), base_config.Config)
self.assertEqual(pprint.pformat(params.a[1].c.d), '3')
def test_override(self):
params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
params.override({'a': [{'b': 4}, {'c': {'d': 5}}]}, is_strict=False)
self.assertEqual(type(params.a), list)
self.assertEqual(type(params.a[0]), base_config.Config)
self.assertEqual(pprint.pformat(params.a[0].b), '4')
self.assertEqual(type(params.a[1]), base_config.Config)
self.assertEqual(type(params.a[1].c), base_config.Config)
self.assertEqual(pprint.pformat(params.a[1].c.d), '5')
@parameterized.parameters(
([{}],),
(({},),),
)
def test_config_vs_params_dict(self, v):
d = {'key': v}
self.assertEqual(type(base_config.Config(d).key[0]), base_config.Config)
self.assertEqual(type(base_config.params_dict.ParamsDict(d).key[0]), dict)
def test_ppformat(self):
self.assertEqual(
pprint.pformat([
's', 1, 1.0, True, None, {}, [], (), {
(2,): (3, [4], {
6: 7,
}),
8: 9,
}
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,410 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports the nest structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import re
import six
import tensorflow as tf
import yaml
# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
\"(.*?)\" # double quote
|
[^,\[]* # single value
|
\[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE)
class ParamsDict(object):
"""A hyperparameter container class."""
RESERVED_ATTR = ['_locked', '_restrictions']
def __init__(self, default_params=None, restrictions=None):
"""Instantiate a ParamsDict.
Instantiate a ParamsDict given a set of default parameters and a list of
restrictions. Upon initialization, it validates itself by checking all the
defined restrictions, and raise error if it finds inconsistency.
Args:
default_params: a Python dict or another ParamsDict object including the
default parameters to initialize.
restrictions: a list of strings, which define a list of restrictions to
ensure the consistency of different parameters internally. Each
restriction string is defined as a binary relation with a set of
operators, including {'==', '!=', '<', '<=', '>', '>='}.
"""
self._locked = False
self._restrictions = []
if restrictions:
self._restrictions = restrictions
if default_params is None:
default_params = {}
self.override(default_params, is_strict=False)
self.validate()
def _set(self, k, v):
if isinstance(v, dict):
self.__dict__[k] = ParamsDict(v)
else:
self.__dict__[k] = copy.deepcopy(v)
def __setattr__(self, k, v):
"""Sets the value of the existing key.
Note that this does not allow directly defining a new key. Use the
`override` method with `is_strict=False` instead.
Args:
k: the key string.
v: the value to be used to set the key `k`.
Raises:
KeyError: if k is not defined in the ParamsDict.
"""
if k not in ParamsDict.RESERVED_ATTR:
if k not in self.__dict__.keys():
raise KeyError('The key `%{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = True.'.format(k))
if self._locked:
raise ValueError('The ParamsDict has been locked. '
'No change is allowed.')
self._set(k, v)
def __getattr__(self, k):
"""Gets the value of the existing key.
Args:
k: the key string.
Returns:
the value of the key.
Raises:
AttributeError: if k is not defined in the ParamsDict.
"""
if k not in self.__dict__.keys():
raise AttributeError('The key `{}` does not exist. '.format(k))
return self.__dict__[k]
def __contains__(self, key):
"""Implements the membership test operator."""
return key in self.__dict__
def get(self, key, value=None):
"""Accesses through built-in dictionary get method."""
return self.__dict__.get(key, value)
def override(self, override_params, is_strict=True):
"""Override the ParamsDict with a set of given params.
Args:
override_params: a dict or a ParamsDict specifying the parameters to
be overridden.
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict.
If False, keys in `override_params` can be different from what is
currently defined in the ParamsDict. In this case, the ParamsDict will
be extended to include the new keys.
"""
if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.')
if isinstance(override_params, ParamsDict):
override_params = override_params.as_dict()
self._override(override_params, is_strict) # pylint: disable=protected-access
def _override(self, override_dict, is_strict=True):
"""The implementation of `override`."""
for k, v in six.iteritems(override_dict):
if k in ParamsDict.RESERVED_ATTR:
raise KeyError('The key `%{}` is internally reserved. '
'Can not be overridden.')
if k not in self.__dict__.keys():
if is_strict:
raise KeyError('The key `{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'.format(k))
else:
self._set(k, v)
else:
if isinstance(v, dict):
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
elif isinstance(v, ParamsDict):
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
else:
self.__dict__[k] = copy.deepcopy(v)
def lock(self):
"""Makes the ParamsDict immutable."""
self._locked = True
def as_dict(self):
"""Returns a dict representation of ParamsDict.
For the nested ParamsDict, a nested dict will be returned.
"""
params_dict = {}
for k, v in six.iteritems(self.__dict__):
if k not in ParamsDict.RESERVED_ATTR:
if isinstance(v, ParamsDict):
params_dict[k] = v.as_dict()
else:
params_dict[k] = copy.deepcopy(v)
return params_dict
def validate(self):
"""Validate the parameters consistency based on the restrictions.
This method validates the internal consistency using the pre-defined list of
restrictions. A restriction is defined as a string which specfiies a binary
operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
'>='}. Note that the meaning of these operators are consistent with the
underlying Python immplementation. Users should make sure the define
restrictions on their type make sense.
For example, for a ParamsDict like the following
```
a:
a1: 1
a2: 2
b:
bb:
bb1: 10
bb2: 20
ccc:
a1: 1
a3: 3
```
one can define two restrictions like this
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
- a.a1 = 1 == b.ccc.a1 = 2
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
KeyError: if any of the following happens
(1) any of parameters in any of restrictions is not defined in
ParamsDict,
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
tokenized_params = dotted_string.split('.')
v = params_dict
for t in tokenized_params:
v = v[t]
return tokenized_params[-1], v
def _get_kvs(tokens, params_dict):
if len(tokens) != 2:
raise ValueError('Only support binary relation in restriction.')
stripped_tokens = [t.strip() for t in tokens]
left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
return left_k, left_v, right_k, right_v
params_dict = self.as_dict()
for restriction in self._restrictions:
if '==' in restriction:
tokens = restriction.split('==')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '!=' in restriction:
tokens = restriction.split('!=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v == right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '<=' in restriction:
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
elif '>=' in restriction:
tokens = restriction.split('>=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v < right_v:
raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
.format(tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f)
return ParamsDict(params_dict)
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
u'tag:yaml.org,2002:seq', data, flow_style=True)
yaml.add_representer(list, _my_list_rep)
yaml.dump(params.as_dict(), f, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
"""Converts a nested (using '.') comma-separated k=v string to a JSON string.
Converts a comma-separated string of key/value pairs that supports
nesting of keys to a JSON string. Nesting is implemented using
'.' between levels for a given key.
Spacing between commas and = is supported (e.g. there is no difference between
"a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
Note that this will only support values supported by CSV, meaning
values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
supported. Strings are supported as well, e.g. "a='hello'".
An example conversion would be:
"a=1, b=2, c.a=2, c.b=3, d.a.a=5"
to
"{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
Args:
csv_str: the comma separated string.
Returns:
the converted JSON string.
Raises:
ValueError: If csv_str is not in a comma separated string or
if the string is formatted incorrectly.
"""
if not csv_str:
return ''
formatted_entries = []
nested_map = collections.defaultdict(list)
pos = 0
while pos < len(csv_str):
m = _PARAM_RE.match(csv_str, pos)
if not m:
raise ValueError('Malformed hyperparameter value while parsing '
'CSV string: %s' % csv_str[pos:])
pos = m.end()
# Parse the values.
m_dict = m.groupdict()
name = m_dict['name']
v = m_dict['val']
# If a GCS path (e.g. gs://...) is provided, wrap this in quotes
# as yaml.load would otherwise throw an exception
if re.match(r'(?=[^\"\'])(?=[gs://])', v):
v = '\'{}\''.format(v)
name_nested = name.split('.')
if len(name_nested) > 1:
grouping = name_nested[0]
value = '.'.join(name_nested[1:]) + '=' + v
nested_map[grouping].append(value)
else:
formatted_entries.append('%s : %s' % (name, v))
for grouping, value in nested_map.items():
value = ','.join(value)
value = nested_csv_str_to_json_str(value)
formatted_entries.append('%s : %s' % (grouping, value))
return '{' + ', '.join(formatted_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
"""Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
The logic of the function is outlined below:
1. Test that the input is a dict. If not, proceed to 2.
2. Tests that the input is a string. If not, raise unknown ValueError
2.1. Test if the string is in a CSV format. If so, parse.
If not, proceed to 2.2.
2.2. Try loading the string as a YAML/JSON. If successful, parse to
dict and use it to override. If not, proceed to 2.3.
2.3. Try using the string as a file path and load the YAML file.
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
params: the overridden ParamsDict object.
Raises:
ValueError: if failed to override the parameters.
"""
if not dict_or_string_or_yaml_file:
return params
if isinstance(dict_or_string_or_yaml_file, dict):
params.override(dict_or_string_or_yaml_file, is_strict)
elif isinstance(dict_or_string_or_yaml_file, six.string_types):
try:
dict_or_string_or_yaml_file = (
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
params_dict = yaml.load(dict_or_string_or_yaml_file)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f), is_strict)
else:
raise ValueError('Unknown input type to parse.')
return params
@@ -0,0 +1,322 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.modeling.hyperparams.params_dict.py."""
import os
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
class ParamsDictTest(tf.test.TestCase):
def test_init_from_an_empty_dict(self):
params = params_dict.ParamsDict()
with self.assertRaises(AttributeError):
_ = params.a
with self.assertRaises(KeyError):
params.a = 'aa'
def test_init_from_a_dict(self):
params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
def test_init_from_a_param_dict(self):
params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
params = params_dict.ParamsDict(params_init)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
def test_lock(self):
params = params_dict.ParamsDict({'a': 1, 'b': 2})
params.lock()
with self.assertRaises(ValueError):
params.a = 10
with self.assertRaises(ValueError):
params.override({'b': 20})
def test_setattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.c = 'ccc'
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, 'ccc')
def test_getattr(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_contains(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa'}, is_strict=False)
self.assertIn('a', params)
self.assertNotIn('b', params)
def test_get(self):
params = params_dict.ParamsDict()
params.override(
{'a': 'aa'}, is_strict=False)
self.assertEqual(params.get('a'), 'aa')
self.assertEqual(params.get('b', 2), 2)
self.assertEqual(params.get('b'), None)
def test_override_is_strict_true(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
with self.assertRaises(KeyError):
params.override({'d': 'ddd'}, is_strict=True)
with self.assertRaises(KeyError):
params.override({'c': {'c3': 30}}, is_strict=True)
def test_override_is_strict_false(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c3, 3000)
params.override({'d': 'ddd'}, is_strict=False)
self.assertEqual(params.d, 'ddd')
params.override({'c': {'c4': 4444}}, is_strict=False)
self.assertEqual(params.c.c4, 4444)
def test_as_dict(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
params_d = params.as_dict()
self.assertEqual(params_d['a'], 'aa')
self.assertEqual(params_d['b'], 2)
self.assertEqual(params_d['c']['c1'], 10)
self.assertEqual(params_d['c']['c2'], 20)
def test_validate(self):
# Raise error due to the unknown parameter.
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 11}}, ['a == c'])
# OK to check equality of two nested dicts.
params = params_dict.ParamsDict(
{'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
# Raise error due to inconsistency
with self.assertRaises(KeyError):
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 10}}, ['a == c.a'])
# Valid rule.
params = params_dict.ParamsDict(
{'a': 1, 'c': {'a': 1}}, ['a == c.a'])
# Overridding violates the existing rule, raise error upon validate.
params.override({'a': 11})
with self.assertRaises(KeyError):
params.validate()
class ParamsDictIOTest(tf.test.TestCase):
def write_temp_file(self, filename, text):
temp_file = os.path.join(self.get_temp_dir(), filename)
with tf.io.gfile.GFile(temp_file, 'w') as writer:
writer.write(text)
return temp_file
def test_save_params_dict_to_yaml(self):
params = params_dict.ParamsDict(
{'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
params_d = yaml.load(f)
self.assertEqual(params.a, params_d['a'])
self.assertEqual(params.b, params_d['b'])
self.assertEqual(params.c.c1, params_d['c']['c1'])
self.assertEqual(params.c.c2, params_d['c']['c2'])
def test_read_yaml_to_params_dict(self):
input_yaml_file = self.write_temp_file(
'params.yaml', r"""
a: 'aa'
b: 2
c:
c1: 10
c2: 20
""")
params = params_dict.read_yaml_to_params_dict(input_yaml_file)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c.c1, 10)
self.assertEqual(params.c.c2, 20)
def test_override_params_dict_using_dict(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_dict = {'b': 5.2, 'c': [30, 40]}
params = params_dict.override_params_dict(
params, override_dict, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
def test_override_params_dict_using_yaml_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
params = params_dict.override_params_dict(
params, override_yaml_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
def test_override_params_dict_using_json_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params = params_dict.override_params_dict(
params, override_json_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(2, params.b.b1)
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi', params.d.d1.d2)
self.assertEqual(False, params.e)
def test_override_params_dict_using_csv_string(self):
params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
'd': {'d1': {'d2': 'hello'}}, 'e': False})
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(2, params.b.b1)
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e)
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
override_yaml_file = self.write_temp_file(
'params.yaml', r"""
b: 5.2
c: [30, 40]
""")
params = params_dict.override_params_dict(
params, override_yaml_file, is_strict=True)
self.assertEqual(1, params.a)
self.assertEqual(5.2, params.b)
self.assertEqual([30, 40], params.c)
self.assertEqual('hello', params.d)
self.assertEqual(False, params.e)
class IOTest(tf.test.TestCase):
def test_basic_csv_str_to_json_str(self):
csv_str = 'a=1,b=2,c=3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_csv_str_load(self):
csv_str = 'a=1,b=2,c=3'
expected_output = {'a': 1, 'b': 2, 'c': 3}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_basic_nested_csv_str_to_json_str(self):
csv_str = 'a=1,b.b1=2'
json_str = '{a : 1, b : {b1 : 2}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_nested_csv_str_load(self):
csv_str = 'a=1,b.b1=2,c.c1=3'
expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_complex_nested_csv_str_to_json_str(self):
csv_str = 'a.aa.aaa.aaaaa.a=1'
json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_complex_nested_csv_str_load(self):
csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_csv_str_load_supported_datatypes(self):
csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertEqual(converted_dict['a'], 1)
self.assertEqual(converted_dict['b'], 2.)
self.assertEqual(converted_dict['c'], [1, 2, 3])
self.assertEqual(converted_dict['d'], 'hello, there')
self.assertEqual(converted_dict['e'], 'Hi.')
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
self.assertRaises(ValueError,
params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
csv_str1 = 'a=1,b=2,c=3'
csv_str2 = 'a = 1, b = 2, c = 3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
self.assertEqual(converted_csv_str1, converted_csv_str2)
self.assertEqual(converted_csv_str1, json_str)
self.assertEqual(converted_csv_str2, json_str)
def test_gcs_added_quotes(self):
csv_str = 'a=gs://abc, b=gs://def'
expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, expected_output)
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,56 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale="dynamic"):
"""Configures optimizer object with performance options."""
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=loss_scale))
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.ckeras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
return optimizer
def set_mixed_precision_policy(dtype, loss_scale=None):
"""Sets mix precision policy."""
if dtype == tf.float16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.float32:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
raise ValueError("Unexpected dtype: %s" % dtype)
@@ -0,0 +1,175 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common TF utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from tensorflow.python.util import deprecation
from official.modeling import activations
@deprecation.deprecated(
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
def pack_inputs(inputs):
"""Pack a list of `inputs` tensors to a tuple.
Args:
inputs: a list of tensors.
Returns:
a tuple of tensors. if any input is None, replace it with a special constant
tensor.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if x is None:
outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
else:
outputs.append(x)
return tuple(outputs)
@deprecation.deprecated(
None,
"tf.keras.layers.Layer supports multiple positional args and kwargs as "
"input tensors. pack/unpack inputs to override __call__ is no longer "
"needed."
)
def unpack_inputs(inputs):
"""unpack a tuple of `inputs` tensors to a tuple.
Args:
inputs: a list of tensors.
Returns:
a tuple of tensors. if any input is a special constant tensor, replace it
with None.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if is_special_none_tensor(x):
outputs.append(None)
else:
outputs.append(x)
x = tuple(outputs)
# To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
# from triggering.
if len(x) == 1:
return x[0]
return tuple(outputs)
def is_special_none_tensor(tensor):
"""Checks if a tensor is a special None Tensor."""
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
# TODO(hongkuny): consider moving custom string-map lookup to keras api.
def get_activation(identifier):
"""Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
It checks string first and if it is one of customized activation not in TF,
the corresponding activation will be returned. For non-customized activation
names and callable identifiers, always fallback to tf.keras.activations.get.
Args:
identifier: String name of the activation function or callable.
Returns:
A Python function corresponding to the activation function.
"""
if isinstance(identifier, six.string_types):
name_to_fn = {
"gelu": activations.gelu,
"simple_swish": activations.simple_swish,
"hard_swish": activations.hard_swish,
"identity": activations.identity,
}
identifier = str(identifier).lower()
if identifier in name_to_fn:
return tf.keras.activations.get(name_to_fn[identifier])
return tf.keras.activations.get(identifier)
def get_shape_list(tensor, expected_rank=None, name=None):
"""Returns a list of the shape of tensor, preferring static dimensions.
Args:
tensor: A tf.Tensor object to find the shape of.
expected_rank: (optional) int. The expected rank of `tensor`. If this is
specified and the `tensor` has a different rank, and exception will be
thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if expected_rank is not None:
assert_rank(tensor, expected_rank, name)
shape = tensor.shape.as_list()
non_static_indexes = []
for (index, dim) in enumerate(shape):
if dim is None:
non_static_indexes.append(index)
if not non_static_indexes:
return shape
dyn_shape = tf.shape(tensor)
for index in non_static_indexes:
shape[index] = dyn_shape[index]
return shape
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types):
expected_rank_dict[expected_rank] = True
else:
for x in expected_rank:
expected_rank_dict[x] = True
actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
raise ValueError(
"For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
"equal to the expected tensor rank `%s`" %
(name, actual_rank, str(tensor.shape), str(expected_rank)))
@@ -0,0 +1,759 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom training loop for running TensorFlow 2.0 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from official.modeling.hyperparams import params_dict
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils import hyperparams_flags
FLAGS = flags.FLAGS
strategy_flags_dict = hyperparams_flags.strategy_flags_dict
hparam_flags_dict = hyperparams_flags.hparam_flags_dict
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to model_dir with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
def _steps_to_run(current_step, total_steps, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
raise ValueError('steps_per_loop should be positive integer.')
return min(total_steps - current_step, steps_per_loop)
def _no_metric():
return None
class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
writer: The tf.SummaryWriter.
"""
def __init__(self, model_dir: Text, name: Text):
"""Inits SummaryWriter with paths.
Arguments:
model_dir: the model folder path.
name: the summary subfolder name.
"""
self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
"""Write metrics to summary with the given writer.
Args:
metrics: a dictionary of metrics values. Prefer dictionary.
step: integer. The training step.
"""
if not isinstance(metrics, dict):
# Support scalar metric without name.
logging.warning('Warning: summary writer prefer metrics as dictionary.')
metrics = {'metric': metrics}
with self.writer.as_default():
for k, v in metrics.items():
tf.summary.scalar(k, v, step=step)
self.writer.flush()
class DistributedExecutor(object):
"""Interface to train and eval models with tf.distribute.Strategy.
Arguments:
strategy: an instance of tf.distribute.Strategy.
params: Model configuration needed to run distribution strategy.
model_fn: Keras model function. Signature:
(params: ParamsDict) -> tf.keras.models.Model.
loss_fn: loss function. Signature:
(y_true: Tensor, y_pred: Tensor) -> Tensor
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
is_multi_host: Set to True when using multi hosts for training, like multi
worker GPU or TPU pod (slice). Otherwise, False.
"""
def __init__(self,
strategy,
params,
model_fn,
loss_fn,
is_multi_host=False):
self._params = params
self._model_fn = model_fn
self._loss_fn = loss_fn
self._strategy = strategy
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
self.train_summary_writer = None
self.eval_summary_writer = None
self.global_train_step = None
@property
def checkpoint_name(self):
"""Returns default checkpoint name."""
return self._checkpoint_name
@checkpoint_name.setter
def checkpoint_name(self, name):
"""Sets default summary writer for the current thread."""
self._checkpoint_name = name
def loss_fn(self):
return self._loss_fn()
def model_fn(self, params):
return self._model_fn(params)
def _save_config(self, model_dir):
"""Save parameters to config files if model_dir is defined."""
logging.info('Save config to model_dir %s.', model_dir)
if model_dir:
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.makedirs(model_dir)
self._params.lock()
params_dict.save_params_dict_to_yaml(self._params,
model_dir + '/params.yaml')
else:
logging.warning('model_dir is empty, so skip the save config.')
def _get_input_iterator(
self, input_fn: Callable[..., tf.data.Dataset],
strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
"""Returns distributed dataset iterator.
Args:
input_fn: (params: dict) -> tf.data.Dataset.
strategy: an instance of tf.distribute.Strategy.
Returns:
An iterator that yields input tensors.
"""
if input_fn is None:
return None
# When training with multiple TPU workers, datasets needs to be cloned
# across workers. Since Dataset instance cannot be cloned in eager mode,
# we instead pass callable that returns a dataset.
if self._is_multi_host:
return iter(
strategy.experimental_distribute_datasets_from_function(input_fn))
else:
input_data = input_fn()
return iter(strategy.experimental_distribute_dataset(input_data))
def _create_replicated_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
return _replicated_step
def _create_train_step(self,
strategy,
model,
loss_fn,
optimizer,
metric=None):
"""Creates a distributed training step.
Args:
strategy: an instance of tf.distribute.Strategy.
model: (Tensor, bool) -> Tensor. model function.
loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
optimizer: tf.keras.optimizers.Optimizer.
iterator: an iterator that yields input tensors.
metric: tf.keras.metrics.Metric subclass.
Returns:
The training step callable.
"""
_replicated_step = self._create_replicated_step(strategy, model, loss_fn,
optimizer, metric)
@tf.function
def train_step(iterator, num_steps):
"""Performs a distributed training step.
Args:
iterator: an iterator that yields input tensors.
Returns:
The loss tensor.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError('steps should be an Tensor. Python object may cause '
'retracing.')
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
for _ in tf.range(num_steps - 1):
per_replica_losses = strategy.run(
_replicated_step, args=(next(iterator),))
# For reporting, we returns the mean of losses.
losses = tf.nest.map_structure(
lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
per_replica_losses)
return losses
return train_step
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
if not metric:
logging.info('Skip test_step because metric is None (%s)', metric)
return None, None
if not isinstance(metric, tf.keras.metrics.Metric):
raise ValueError(
'Metric must be an instance of tf.keras.metrics.Metric '
'for running in test_step. Actual {}'.format(metric))
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
return labels, model_outputs
return strategy.run(_test_step_fn, args=(next(iterator),))
return test_step
def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset] = None,
model_dir: Text = None,
total_steps: int = 1,
iterations_per_loop: int = 1,
train_metric_fn: Callable[[], Any] = None,
eval_metric_fn: Callable[[], Any] = None,
summary_writer_fn: Callable[[Text, Text],
SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
custom_callbacks: List[tf.keras.callbacks.Callback] = None,
save_config: bool = True):
"""Runs distributed training.
Args:
train_input_fn: (params: dict) -> tf.data.Dataset training data input
function.
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step.
model_dir: the folder path for model checkpoints.
total_steps: total training steps.
iterations_per_loop: train steps per loop. After each loop, this job will
update metrics like loss and save checkpoint.
train_metric_fn: metric_fn for evaluation in train_step.
eval_metric_fn: metric_fn for evaluation in test_step.
summary_writer_fn: function to create summary writer.
init_checkpoint: function to load checkpoint.
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
save_config: bool. Whether to save params to model_dir.
Returns:
The training loss and eval metrics.
"""
assert train_input_fn is not None
if train_metric_fn and not callable(train_metric_fn):
raise ValueError('if `train_metric_fn` is specified, '
'train_metric_fn must be a callable.')
if eval_metric_fn and not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
train_metric_fn = train_metric_fn or _no_metric
eval_metric_fn = eval_metric_fn or _no_metric
if custom_callbacks and iterations_per_loop != 1:
logging.warning(
'It is sematically wrong to run callbacks when '
'iterations_per_loop is not one (%s)', iterations_per_loop)
custom_callbacks = custom_callbacks or []
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
if callback:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
if callback:
callback.on_batch_end(batch)
if save_config:
self._save_config(model_dir)
if FLAGS.save_checkpoint_freq:
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
# input pipeline ops in worker task.
train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None
eval_metric_result = None
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
initial_step = 0
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
initial_step = optimizer.iterations.numpy()
logging.info('Loading from checkpoint file completed. Init step %d',
initial_step)
elif init_checkpoint:
logging.info('Restoring from init checkpoint function')
init_checkpoint(model)
logging.info('Loading from init checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = self.checkpoint_name
eval_metric = eval_metric_fn()
train_metric = train_metric_fn()
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
self.train_summary_writer = train_summary_writer.writer
test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
self.eval_summary_writer = test_summary_writer.writer
# Use training summary writer in TimeHistory if it's in use
for cb in custom_callbacks:
if isinstance(cb, keras_utils.TimeHistory):
cb.summary_writer = self.train_summary_writer
# Continue training loop.
train_step = self._create_train_step(
strategy=strategy,
model=model,
loss_fn=self.loss_fn(),
optimizer=optimizer,
metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
self.global_train_step = model.optimizer.iterations
test_step = self._create_test_step(strategy, model, metric=eval_metric)
# Step-0 operations
_save_checkpoint(
checkpoint, model_dir, checkpoint_name.format(step=current_step))
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(
test_step, current_step, eval_metric, eval_iterator)
logging.info(
'Step: %s evalation metric = %s.', current_step, eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
eval_metric.reset_states()
logging.info('Training started')
last_save_checkpoint_step = current_step
while current_step < total_steps:
num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
_run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator,
tf.convert_to_tensor(num_steps, dtype=tf.int32))
current_step += num_steps
train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
train_loss)
_run_callbacks_on_batch_end(current_step - 1)
if not isinstance(train_loss, dict):
train_loss = {'total_loss': train_loss}
if np.isnan(train_loss['total_loss']):
raise ValueError('total loss is NaN.')
if train_metric:
train_metric_result = train_metric.result()
if isinstance(train_metric, tf.keras.metrics.Metric):
train_metric_result = tf.nest.map_structure(
lambda x: x.numpy().astype(float), train_metric_result)
if not isinstance(train_metric_result, dict):
train_metric_result = {'metric': train_metric_result}
train_metric_result.update(train_loss)
else:
train_metric_result = train_loss
if callable(optimizer.lr):
train_metric_result.update(
{'learning_rate': optimizer.lr(current_step).numpy()})
else:
train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_steps, train_loss,
train_metric_result)
train_summary_writer(
metrics=train_metric_result, step=optimizer.iterations)
# Saves model checkpoints and run validation steps at every
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if save_freq > 0 and current_step < total_steps and (
current_step - last_save_checkpoint_step) >= save_freq:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evalation metric = %s.', current_step,
eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
# Re-initialize evaluation metric, except the last step.
if eval_metric and current_step < total_steps:
eval_metric.reset_states()
if train_metric and current_step < total_steps:
train_metric.reset_states()
# Reaches the end of training and saves the last checkpoint.
if last_save_checkpoint_step < total_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
logging.info('Running final evaluation after training is complete.')
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Final evaluation metric = %s.', eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
self.train_summary_writer.close()
self.eval_summary_writer.close()
return train_loss, eval_metric_result
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metrics (%s) must not be None.',
test_iterator, metric)
return None
logging.info('Running evaluation after step: %s.', current_training_step)
while True:
try:
test_step(test_iterator)
except (StopIteration, tf.errors.OutOfRangeError):
break
metric_result = metric.result()
if isinstance(metric, tf.keras.metrics.Metric):
metric_result = metric_result.numpy().astype(float)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric_result)
return metric_result
def evaluate_from_model_dir(
self,
model_dir: Text,
eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
total_steps: int = -1,
eval_timeout: int = None,
min_eval_interval: int = 180,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
"""Runs distributed evaluation on model folder.
Args:
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
model_dir: the folder for storing model checkpoints.
total_steps: total training steps. If the current step reaches the
total_steps, the evaluation loop will stop.
eval_timeout: The maximum number of seconds to wait between checkpoints.
If left as None, then the process will wait indefinitely. Used by
tf.train.checkpoints_iterator.
min_eval_interval: The minimum number of seconds between yielding
checkpoints. Used by tf.train.checkpoints_iterator.
summary_writer_fn: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not model_dir:
raise ValueError('model_dir must be set.')
def terminate_eval():
tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
eval_timeout)
return True
summary_writer = summary_writer_fn(model_dir, 'eval')
self.eval_summary_writer = summary_writer.writer
# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
for checkpoint_path in tf.train.checkpoints_iterator(
model_dir,
min_interval_secs=min_eval_interval,
timeout=eval_timeout,
timeout_fn=terminate_eval):
eval_metric_result, current_step = self.evaluate_checkpoint(
checkpoint_path=checkpoint_path,
eval_input_fn=eval_input_fn,
eval_metric_fn=eval_metric_fn,
summary_writer=summary_writer)
if total_steps > 0 and current_step >= total_steps:
logging.info('Evaluation finished after training step %d', current_step)
break
return eval_metric_result
def evaluate_checkpoint(self,
checkpoint_path: Text,
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
summary_writer: SummaryWriter = None):
"""Runs distributed evaluation on the one checkpoint.
Args:
eval_input_fn: (Optional) same type as train_input_fn. If not None, will
trigger evaluting metric on eval data. If None, will not run eval step.
eval_metric_fn: metric_fn for evaluation in test_step.
checkpoint_path: the checkpoint to evaluate.
summary_writer_fn: function to create summary writer.
Returns:
Eval metrics dictionary of the last checkpoint.
"""
if not callable(eval_metric_fn):
raise ValueError('if `eval_metric_fn` is specified, '
'eval_metric_fn must be a callable.')
params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
# input pipeline ops in worker task.
with strategy.scope():
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model = self.model_fn(params.as_dict())
checkpoint = tf.train.Checkpoint(model=model)
eval_metric = eval_metric_fn()
assert eval_metric, 'eval_metric does not exist'
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Starting to evaluate.')
if not checkpoint_path:
raise ValueError('checkpoint path is empty')
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
current_step = reader.get_tensor(
'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', checkpoint_path)
checkpoint.restore(checkpoint_path)
self.global_train_step = model.optimizer.iterations
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
logging.info('Step: %s evalation metric = %s.', current_step,
eval_metric_result)
summary_writer(metrics=eval_metric_result, step=current_step)
eval_metric.reset_states()
return eval_metric_result, current_step
def predict(self):
return NotImplementedError('Unimplmented function.')
class ExecutorBuilder(object):
"""Builder of DistributedExecutor.
Example 1: Builds an executor with supported Strategy.
builder = ExecutorBuilder(
strategy_type='tpu',
strategy_config={'tpu': '/bns/xxx'})
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 2: Builds an executor with customized Strategy.
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Example 3: Builds a customized executor with customized Strategy.
class MyDistributedExecutor(DistributedExecutor):
# implementation ...
builder = ExecutorBuilder()
builder.strategy = <some customized Strategy>
dist_executor = builder.build_executor(
class_ctor=MyDistributedExecutor,
params=params,
model_fn=my_model_fn,
loss_fn=my_loss_fn,
metric_fn=my_metric_fn)
Args:
strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'. If
None. User is responsible to set the strategy before calling
build_executor(...).
strategy_config: necessary config for constructing the proper Strategy.
Check strategy_flags_dict() for examples of the structure.
"""
def __init__(self, strategy_type=None, strategy_config=None):
_ = distribution_utils.configure_cluster(
strategy_config.worker_hosts, strategy_config.task_index)
self._strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
num_packs=strategy_config.num_packs,
tpu_address=strategy_config.tpu)
@property
def strategy(self):
"""Returns default checkpoint name."""
return self._strategy
@strategy.setter
def strategy(self, new_strategy):
"""Sets default summary writer for the current thread."""
self._strategy = new_strategy
def build_executor(self,
class_ctor=DistributedExecutor,
params=None,
model_fn=None,
loss_fn=None,
**kwargs):
"""Creates an executor according to strategy type.
See doc string of the DistributedExecutor.__init__ for more information of
the
input arguments.
Args:
class_ctor: A constructor of executor (default: DistributedExecutor).
params: ParamsDict, all the model parameters and runtime parameters.
model_fn: Keras model function.
loss_fn: loss function.
**kwargs: other arguments to the executor constructor.
Returns:
An instance of DistributedExecutor or its subclass.
"""
if self._strategy is None:
raise ValueError('`strategy` should not be None. You need to specify '
'`strategy_type` in the builder contructor or directly '
'set the `strategy` property of the builder.')
return class_ctor(
strategy=self._strategy,
params=params,
model_fn=model_fn,
loss_fn=loss_fn,
**kwargs)
@@ -0,0 +1,19 @@
# TensorFlow Natural Language Processing Modelling Toolkit
tensorflow/models/official/nlp provides a [modeling library](modeling) for constructing
NLP model achitectures, as well as TF2 reference implementations for
state-of-the-art models.
The repository contains the following models, with implementations, pre-trained
model weights, usage scripts and conversion utilities:
* [Albert](albert)
* [Bert](bert)
* [NHNet](nhnet)
* [XLNet](xlnet)
* [Transformer for translation](transformer)
Addtional features:
* Distributed trainable on both multi-GPU and TPU
* e2e training for custom models, including both pretraining and finetuning.
@@ -0,0 +1,332 @@
# ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
The academic paper which describes ALBERT in detail and provides full results on
a number of tasks can be found here: https://arxiv.org/abs/1909.11942.
This repository contains TensorFlow 2.x implementation for ALBERT.
## Contents
* [Contents](#contents)
* [Pre-trained Models](#pre-trained-models)
* [Restoring from Checkpoints](#restoring-from-checkpoints)
* [Set Up](#set-up)
* [Process Datasets](#process-datasets)
* [Fine-tuning with BERT](#fine-tuning-with-bert)
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
* [SQuAD 1.1](#squad-1.1)
## Pre-trained Models
We released both checkpoints and tf.hub modules as the pretrained models for
fine-tuning. They are TF 2.x compatible and are converted from the ALBERT v2
checkpoints released in TF 1.x official ALBERT repository
[google-research/albert](https://github.com/google-research/albert)
in order to keep consistent with ALBERT paper.
Our current released checkpoints are exactly the same as TF 1.x official ALBERT
repository.
### Access to Pretrained Checkpoints
Pretrained checkpoints can be found in the following links:
**Note: We implemented ALBERT using Keras functional-style networks in [nlp/modeling](../modeling).
ALBERT V2 models compatible with TF 2.x checkpoints are:**
* **[`ALBERT V2 Base`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base.tar.gz)**:
12-layer, 768-hidden, 12-heads, 12M parameters
* **[`ALBERT V2 Large`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_large.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 18M parameters
* **[`ALBERT V2 XLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xlarge.tar.gz)**:
24-layer, 2048-hidden, 32-heads, 60M parameters
* **[`ALBERT V2 XXLarge`](https://storage.googleapis.com/cloud-tpu-checkpoints/albert/checkpoints/albert_v2_xxlarge.tar.gz)**:
12-layer, 4096-hidden, 64-heads, 235M parameters
We recommend to host checkpoints on Google Cloud storage buckets when you use
Cloud GPU/TPU.
### Restoring from Checkpoints
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
weights from provided pre-trained checkpoints, you can use the following code:
```python
init_checkpoint='the pretrained model checkpoint path.'
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
```
Checkpoints featuring native serialized Keras models
(i.e. model.load()/load_weights()) will be available soon.
### Access to Pretrained hub modules.
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`ALBERT V2 Base`](https://tfhub.dev/tensorflow/albert_en_base/1)**:
12-layer, 768-hidden, 12-heads, 12M parameters
* **[`ALBERT V2 Large`](https://tfhub.dev/tensorflow/albert_en_large/1)**:
24-layer, 1024-hidden, 16-heads, 18M parameters
* **[`ALBERT V2 XLarge`](https://tfhub.dev/tensorflow/albert_en_xlarge/1)**:
24-layer, 2048-hidden, 32-heads, 60M parameters
* **[`ALBERT V2 XXLarge`](https://tfhub.dev/tensorflow/albert_en_xxlarge/1)**:
12-layer, 4096-hidden, 64-heads, 235M parameters
## Set Up
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Install `tf-nightly` to get latest updates:
```shell
pip install tf-nightly-gpu
```
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
```shell
ctpu up -name <instance name> --tf-version=”nightly”
```
Second, you need to install TF 2 `tf-nightly` on your VM:
```shell
pip install tf-nightly
```
Warning: More details TPU-specific set-up instructions and tutorial should come
along with official TF 2.x release for TPU. Note that this repo is not
officially supported by Google Cloud TPU team yet until TF 2.1 released.
## Process Datasets
### Pre-training
Pre-train ALBERT using TF2.x will come soon.
For now, please use [ALBERT research repo](https://github.com/google-research/ALBERT)
to pretrain the model and convert the checkpoint to TF2.x compatible ones using
[tf2_albert_encoder_checkpoint_converter.py](tf2_albert_encoder_checkpoint_converter.py).
### Fine-tuning
To prepare the fine-tuning data for final model training, use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
Note that different from BERT models that use word piece tokenzer,
ALBERT models employ sentence piece tokenizer. So the FLAG tokenizer_impl has
to be set to 'sentence_piece'.
Resulting datasets in `tf_record` format and training meta data should be later
passed to training or evaluation scripts. The task-specific arguments are
described in following sections:
* GLUE
Users can download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
```shell
export GLUE_DIR=~/glue
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME} \
--tokenizer_impl=sentence_piece
```
* SQUAD
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation.
The necessary files can be found here:
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
```shell
export SQUAD_DIR=~/squad
export SQUAD_VERSION=v1.1
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384 \
--tokenizer_impl=sentence_piece
```
## Fine-tuning with ALBERT
### Cloud GPUs and TPUs
* Cloud Storage
The unzipped pre-trained model files can also be found in the Google Cloud
Storage folder `gs://cloud-tpu-checkpoints/albert/checkpoints`. For example:
```shell
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export MODEL_DIR=gs://some_bucket/my_output_dir
```
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
script should run with `tf-nightly`.
* GPU -> TPU
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### Sentence and Sentence-pair Classification Tasks
This example code fine-tunes `albert_v2_base` on the Microsoft Research
Paraphrase Corpus (MRPC) corpus, which only contains 3,600 examples and can
fine-tune in a few minutes on most GPUs.
We use the `albert_v2_base` as an example throughout the
workflow.
```shell
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${ALBERT_DIR}/albert_config.json \
--init_checkpoint=${ALBERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--eval_batch_size=4 \
--steps_per_loop=1 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Alternatively, instead of specifying `init_checkpoint`, you can specify
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
` --hub_module_url=https://tfhub.dev/tensorflow/albert_en_base/1`.
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
information and use remote storage for model checkpoints.
```shell
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=$ALBERT_DIR/albert_config.json \
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
We use the `albert_v2_base` as an example throughout the
workflow.
```shell
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export SQUAD_DIR=gs://some_bucket/datasets
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
--bert_config_file=$ALBERT_DIR/albert_config.json \
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Similarily, you can replace `init_checkpoint` FLAGS with `hub_module_url` to
specify a hub module path.
To use TPU, you need switch distribution strategy type to `tpu` with TPU
information.
```shell
export ALBERT_DIR=gs://cloud-tpu-checkpoints/albert/checkpoints/albert_v2_base
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--sp_model_file=${ALBERT_DIR}/30k-clean.model \
--bert_config_file=$ALBERT_DIR/albert_config.json \
--init_checkpoint=$ALBERT_DIR/bert_model.ckpt \
--train_batch_size=32 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
The dev set predictions will be saved into a file called predictions.json in the
model_dir:
```shell
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
```
@@ -0,0 +1,61 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The ALBERT configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from official.nlp.bert import configs
class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self,
embedding_size,
num_hidden_groups=1,
inner_group_num=1,
**kwargs):
"""Constructs AlbertConfig.
Args:
embedding_size: Size of the factorized word embeddings.
num_hidden_groups: Number of group for the hidden layers, parameters in
the same group are shared. Note that this value and also the following
'inner_group_num' has to be 1 for now, because all released ALBERT
models set them to 1. We may support arbitary valid values in future.
inner_group_num: Number of inner repetition of attention and ffn.
**kwargs: The remaining arguments are the same as above 'BertConfig'.
"""
super(AlbertConfig, self).__init__(**kwargs)
self.embedding_size = embedding_size
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertTransformerEncoder
# if needed.
if inner_group_num != 1 or num_hidden_groups != 1:
raise ValueError("We only support 'inner_group_num' and "
"'num_hidden_groups' as 1.")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
config = AlbertConfig(embedding_size=None, vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@@ -0,0 +1,88 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the ALBERT core model as a TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from typing import Text
from official.nlp.albert import configs
from official.nlp.bert import bert_models
FLAGS = flags.FLAGS
flags.DEFINE_string("albert_config_file", None,
"Albert configuration file to define core albert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string(
"sp_model_file", None,
"The sentence piece model file that the ALBERT model was trained on.")
def create_albert_model(
albert_config: configs.AlbertConfig) -> tf.keras.Model:
"""Creates an ALBERT keras core model from ALBERT configuration.
Args:
albert_config: An `AlbertConfig` to create the core model.
Returns:
A keras model.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
albert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To keep consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_albert_tfhub(albert_config: configs.AlbertConfig,
model_checkpoint_path: Text, hub_destination: Text,
sp_model_file: Text):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, encoder = create_albert_model(albert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
albert_config = configs.AlbertConfig.from_json_file(
FLAGS.albert_config_file)
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.sp_model_file)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,89 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests official.nlp.albert.export_albert_tfhub."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.albert import configs
from official.nlp.albert import export_albert_tfhub
class ExportAlbertTfhubTest(tf.test.TestCase):
def test_export_albert_tfhub(self):
# Exports a savedmodel for TF-Hub
albert_config = configs.AlbertConfig(
vocab_size=100,
embedding_size=8,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_albert_tfhub.create_albert_model(albert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
with tf.io.gfile.GFile(sp_model_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_albert_tfhub.export_albert_tfhub(
albert_config,
model_checkpoint_path,
hub_destination,
sp_model_file=sp_model_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
with tf.io.gfile.GFile(
hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of hub module are "pooled_output" and "sequence_output",
# while the outputs of encoder is in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_classifier as run_classifier_bert
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def main(_):
# Users should always run this script under TF 2.x
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = run_classifier_bert.get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
eval_input_fn = run_classifier_bert.get_dataset_fn(
FLAGS.eval_data_path,
max_seq_length,
FLAGS.eval_batch_size,
is_training=False)
albert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
train_input_fn, eval_input_fn)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('input_meta_data_path')
flags.mark_flag_as_required('model_dir')
app.run(main)
@@ -0,0 +1,139 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Runs bert squad training."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for the squad dataset."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=FLAGS.sp_model_file)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_sp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=FLAGS.sp_model_file)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_sp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
# Users should always run this script under TF 2.x
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if 'train' in FLAGS.mode:
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Also write eval_metrics to json file.
squad_lib_sp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
@@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore a AlbertTransformerEncoder object.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.modeling import activations
from official.nlp.albert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import networks
FLAGS = flags.FLAGS
flags.DEFINE_string("albert_config_file", None,
"Albert configuration file to define core bert layers.")
flags.DEFINE_string(
"checkpoint_to_convert", None,
"Initial checkpoint from a pretrained BERT model core (that is, only the "
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
ALBERT_NAME_REPLACEMENTS = (
("bert/encoder/", ""),
("bert/", ""),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("embedding_hidden_mapping_in", "embedding_projection"),
("group_0/inner_group_0/", ""),
("attention_1/self", "self_attention"),
("attention_1/output/dense", "self_attention_output"),
("LayerNorm/", "self_attention_layer_norm/"),
("ffn_1/intermediate/dense", "intermediate"),
("ffn_1/intermediate/output/dense", "output"),
("LayerNorm_1/", "output_layer_norm/"),
("pooler/dense", "pooler_transform"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
)
def _create_albert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A keras model.
"""
albert_encoder = networks.AlbertTransformerEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
embedding_width=cfg.embedding_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=activations.gelu,
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
return albert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=ALBERT_NAME_REPLACEMENTS,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint.
model = _create_albert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def main(_):
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
convert_checkpoint(albert_config, output_path, v1_checkpoint)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,350 @@
# BERT (Bidirectional Encoder Representations from Transformers)
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains TensorFlow 2.x implementation for BERT.
## Contents
* [Contents](#contents)
* [Pre-trained Models](#pre-trained-models)
* [Restoring from Checkpoints](#restoring-from-checkpoints)
* [Set Up](#set-up)
* [Process Datasets](#process-datasets)
* [Fine-tuning with BERT](#fine-tuning-with-bert)
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
* [SQuAD 1.1](#squad-1.1)
## Pre-trained Models
We released both checkpoints and tf.hub modules as the pretrained models for
fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
released in TF 1.x official BERT repository
[google-research/bert](https://github.com/google-research/bert)
in order to keep consistent with BERT paper.
### Access to Pretrained Checkpoints
Pretrained checkpoints can be found in the following links:
**Note: We have switched BERT implementation
to use Keras functional-style networks in [nlp/modeling](../modeling).
The new checkpoints are:**
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
We recommend to host checkpoints on Google Cloud storage buckets when you use
Cloud GPU/TPU.
### Restoring from Checkpoints
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
weights from provided pre-trained checkpoints, you can use the following code:
```python
init_checkpoint='the pretrained model checkpoint path.'
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
```
Checkpoints featuring native serialized Keras models
(i.e. model.load()/load_weights()) will be available soon.
### Access to Pretrained hub modules.
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/1)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/1)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/1)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/1)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/1)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/1)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
110M parameters
## Set Up
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Install `tf-nightly` to get latest updates:
```shell
pip install tf-nightly-gpu
```
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
```shell
ctpu up -name <instance name> --tf-version=”nightly”
```
Second, you need to install TF 2 `tf-nightly` on your VM:
```shell
pip install tf-nightly
```
Warning: More details TPU-specific set-up instructions and tutorial should come
along with official TF 2.x release for TPU. Note that this repo is not
officially supported by Google Cloud TPU team yet until TF 2.1 released.
## Process Datasets
### Pre-training
There is no change to generate pre-training data. Please use the script
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
which is essentially branched from [BERT research repo](https://github.com/google-research/bert)
to get processed pre-training data and it adapts to TF2 symbols and python3
compatibility.
### Fine-tuning
To prepare the fine-tuning data for final model training, use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
Resulting datasets in `tf_record` format and training meta data should be later
passed to training or evaluation scripts. The task-specific arguments are
described in following sections:
* GLUE
Users can download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
```shell
export GLUE_DIR=~/glue
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME}
```
* SQUAD
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation.
The necessary files can be found here:
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
```shell
export SQUAD_DIR=~/squad
export SQUAD_VERSION=v1.1
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384
```
## Fine-tuning with BERT
### Cloud GPUs and TPUs
* Cloud Storage
The unzipped pre-trained model files can also be found in the Google Cloud
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
```
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
script should run with `tf-nightly`.
* GPU -> TPU
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### Sentence and Sentence-pair Classification Tasks
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
few minutes on most GPUs.
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--eval_batch_size=4 \
--steps_per_loop=1 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Alternatively, instead of specifying `init_checkpoint`, you can specify
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
information and use remote storage for model checkpoints.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
training steps inside a `tf.function` can significantly increase TPU utilization
and callbacks will not be called inside the loop.
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export SQUAD_DIR=gs://some_bucket/datasets
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Similarily, you can replace `init_checkpoint` FLAG with `hub_module_url` to
specify a hub module path.
To use TPU, you need switch distribution strategy type to `tpu` with TPU
information.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
The dev set predictions will be saved into a file called predictions.json in the
model_dir:
```shell
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
```
@@ -0,0 +1,110 @@
# BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
## Set up Cloud Storage and Compute Engine VM
1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
2. Create a variable for the project's name:
```
export PROJECT_NAME=your-project_name
```
3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
```
gcloud config set project ${PROJECT_NAME}
```
4. Create a Cloud Storage bucket using the following command:
```
gsutil mb -p ${PROJECT_NAME} -c standard -l europe-west4 -b on gs://your-bucket-name
```
This Cloud Storage bucket stores the data you use to train your model and the training results.
5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
```
ctpu up --tpu-size=v3-8 \
--machine-type=n1-standard-8 \
--zone=europe-west4-a \
--tf-version=2.1 [optional flags: --project, --name]
```
6. The configuration you specified appears. Enter y to approve or n to cancel.
7. When the ctpu up command has finished executing, verify that your shell prompt has changed from username@project to username@tpuname. This change shows that you are now logged into your Compute Engine VM.
```
gcloud compute ssh vm-name --zone=europe-west4-a
(vm)$ export TPU_NAME=vm-name
```
As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
## Prepare the Dataset
1. From your Compute Engine virtual machine (VM), install requirements.txt.
```
(vm)$ cd /usr/share/models
(vm)$ sudo pip3 install -r official/requirements.txt
```
2. Optional: download download_glue_data.py
This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
## Define parameter values
Next, define several parameter values that are required when you train and evaluate your model:
```
(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
(vm)$ export TASK=mnli
```
## Train the model
From your Compute Engine VM, run the following command.
```
(vm)$ python3 official/nlp/bert/run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=${TPU_NAME}
```
## Verify your results
The training takes approximately 1 hour on a v3-8 TPU. When script completes, you should see results similar to the following:
```
Training Summary:
{'train_loss': 0.28142181038856506,
'last_train_metrics': 0.9467429518699646,
'eval_metrics': 0.8599063158035278,
'total_training_steps': 36813}
```
## Clean up
To avoid incurring charges to your GCP account for the resources used in this topic:
1. Disconnect from the Compute Engine VM:
```
(vm)$ exit
```
2. In your Cloud Shell, run ctpu delete with the --zone flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
```
$ ctpu delete --zone=your-zone
```
3. Run ctpu status specifying your zone to make sure you have no instances allocated to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
```
$ ctpu status --zone=your-zone
```
4. Run gsutil as shown, replacing your-bucket with the name of the Cloud Storage bucket you created for this tutorial:
```
$ gsutil rm -r gs://your-bucket
```
@@ -0,0 +1,349 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gin
import tensorflow as tf
import tensorflow_hub as hub
from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Returns layer that computes custom loss and metrics for pretraining."""
def __init__(self, vocab_size, **kwargs):
super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
self._vocab_size = vocab_size
self.config = {
'vocab_size': vocab_size,
}
def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
lm_example_loss, sentence_output, sentence_labels,
next_sentence_loss):
"""Adds metrics."""
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
lm_labels, lm_output)
numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
denominator = tf.reduce_sum(lm_label_weights) + 1e-5
masked_lm_accuracy = numerator / denominator
self.add_metric(
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
if sentence_labels is not None:
next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
sentence_labels, sentence_output)
self.add_metric(
next_sentence_accuracy,
name='next_sentence_accuracy',
aggregation='mean')
if next_sentence_loss is not None:
self.add_metric(
next_sentence_loss, name='next_sentence_loss', aggregation='mean')
def call(self,
lm_output,
sentence_output,
lm_label_ids,
lm_label_weights,
sentence_labels=None):
"""Implements call() for the layer."""
lm_label_weights = tf.cast(lm_label_weights, tf.float32)
lm_output = tf.cast(lm_output, tf.float32)
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
if sentence_labels is not None:
sentence_output = tf.cast(sentence_output, tf.float32)
sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels, predictions=sentence_output)
loss = mask_label_loss + sentence_loss
else:
sentence_loss = None
loss = mask_label_loss
batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
# TODO(hongkuny): Avoids the hack and switches add_loss.
final_loss = tf.fill(batch_shape, loss)
self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
mask_label_loss, sentence_output, sentence_labels,
sentence_loss)
return final_loss
@gin.configurable
def get_transformer_encoder(bert_config,
sequence_length,
transformer_encoder_cls=None):
"""Gets a 'TransformerEncoder' object.
Args:
bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
sequence_length: Maximum sequence length of the training data.
transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
default BERT encoder implementation.
Returns:
A networks.TransformerEncoder object.
"""
if transformer_encoder_cls is not None:
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
embedding_cfg = dict(
vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size,
seq_length=sequence_length,
max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
dropout_rate=bert_config.hidden_dropout_prob,
)
hidden_cfg = dict(
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers,
pooled_output_dim=bert_config.hidden_size,
)
# Relies on gin configuration to define the Transformer encoder arguments.
return transformer_encoder_cls(**kwargs)
kwargs = dict(
vocab_size=bert_config.vocab_size,
hidden_size=bert_config.hidden_size,
num_layers=bert_config.num_hidden_layers,
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
sequence_length=sequence_length,
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
if isinstance(bert_config, albert_configs.AlbertConfig):
kwargs['embedding_width'] = bert_config.embedding_size
return networks.AlbertTransformerEncoder(**kwargs)
else:
assert isinstance(bert_config, configs.BertConfig)
return networks.TransformerEncoder(**kwargs)
def pretrain_model(bert_config,
seq_length,
max_predictions_per_seq,
initializer=None,
use_next_sentence_label=True):
"""Returns model to be used for pre-training.
Args:
bert_config: Configuration that defines the core BERT model.
seq_length: Maximum sequence length of the training data.
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
and use for pretraining.
initializer: Initializer for weights in BertPretrainer.
use_next_sentence_label: Whether to use the next sentence label.
Returns:
Pretraining model as well as core BERT submodel from which to save
weights after pretraining.
"""
input_word_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
input_mask = tf.keras.layers.Input(
shape=(seq_length,), name='input_mask', dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
masked_lm_positions = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_positions',
dtype=tf.int32)
masked_lm_ids = tf.keras.layers.Input(
shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
masked_lm_weights = tf.keras.layers.Input(
shape=(max_predictions_per_seq,),
name='masked_lm_weights',
dtype=tf.int32)
if use_next_sentence_label:
next_sentence_labels = tf.keras.layers.Input(
shape=(1,), name='next_sentence_labels', dtype=tf.int32)
else:
next_sentence_labels = None
transformer_encoder = get_transformer_encoder(bert_config, seq_length)
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
pretrainer_model = models.BertPretrainer(
network=transformer_encoder,
num_classes=2, # The next sentence prediction label has two classes.
num_token_predictions=max_predictions_per_seq,
initializer=initializer,
output='predictions')
lm_output, sentence_output = pretrainer_model(
[input_word_ids, input_mask, input_type_ids, masked_lm_positions])
pretrain_loss_layer = BertPretrainLossAndMetricLayer(
vocab_size=bert_config.vocab_size)
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
masked_lm_weights, next_sentence_labels)
inputs = {
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids,
'masked_lm_positions': masked_lm_positions,
'masked_lm_ids': masked_lm_ids,
'masked_lm_weights': masked_lm_weights,
}
if use_next_sentence_label:
inputs['next_sentence_labels'] = next_sentence_labels
keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
return keras_model, transformer_encoder
def squad_model(bert_config,
max_seq_length,
initializer=None,
hub_module_url=None,
hub_module_trainable=True):
"""Returns BERT Squad model along with core BERT model to import weights.
Args:
bert_config: BertConfig, the config defines the core Bert model.
max_seq_length: integer, the maximum input sequence length.
initializer: Initializer for the final dense layer in the span labeler.
Defaulted to TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module.
hub_module_trainable: True to finetune layers in the hub module.
Returns:
A tuple of (1) keras model that outputs start logits and end logits and
(2) the core BERT transformer encoder.
"""
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
if not hub_module_url:
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
return models.BertSpanLabeler(
network=bert_encoder, initializer=initializer), bert_encoder
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
pooled_output, sequence_output = core_model(
[input_word_ids, input_mask, input_type_ids])
bert_encoder = tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids,
},
outputs=[sequence_output, pooled_output],
name='core_model')
return models.BertSpanLabeler(
network=bert_encoder, initializer=initializer), bert_encoder
def classifier_model(bert_config,
num_labels,
max_seq_length,
final_layer_initializer=None,
hub_module_url=None,
hub_module_trainable=True):
"""BERT classifier model in functional API style.
Construct a Keras model for predicting `num_labels` outputs from an input with
maximum sequence length `max_seq_length`.
Args:
bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
ALBERT model.
num_labels: integer, the number of classes.
max_seq_length: integer, the maximum input sequence length.
final_layer_initializer: Initializer for final dense layer. Defaulted
TruncatedNormal initializer.
hub_module_url: TF-Hub path/url to Bert module.
hub_module_trainable: True to finetune layers in the hub module.
Returns:
Combined prediction model (words, mask, type) -> (one-hot labels)
BERT sub-model (words, mask, type) -> (bert_outputs)
"""
if final_layer_initializer is not None:
initializer = final_layer_initializer
else:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
if not hub_module_url:
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
return models.BertClassifier(
bert_encoder,
num_classes=num_labels,
dropout_rate=bert_config.hidden_dropout_prob,
initializer=initializer), bert_encoder
input_word_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
pooled_output)
output = tf.keras.layers.Dense(
num_labels, kernel_initializer=initializer, name='output')(
output)
return tf.keras.Model(
inputs={
'input_word_ids': input_word_ids,
'input_mask': input_mask,
'input_type_ids': input_type_ids
},
outputs=output), bert_model
@@ -0,0 +1,118 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defining common flags used across all BERT models/applications."""
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
def define_gin_flags():
"""Define common gin configurable flags."""
flags.DEFINE_multi_string('gin_file', None,
'List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
def define_common_bert_flags():
"""Define common flags for BERT tasks."""
flags_core.define_base(
data_dir=False,
model_dir=True,
clean=False,
train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=False,
num_gpu=True,
hooks=False,
export_dir=False,
distribution_strategy=True,
run_eagerly=True)
flags_core.define_distribution()
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory, where trainined model will be '
'exported.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
'steps_per_loop', 1,
'Number of steps per graph-mode loop. Only training step '
'happens inside the loop. Callbacks will not be called '
'inside.')
flags.DEFINE_float('learning_rate', 5e-5,
'The initial learning rate for Adam.')
flags.DEFINE_float('end_lr', 0.0,
'The end learning rate for learning rate decay.')
flags.DEFINE_string('optimizer_type', 'adamw',
'The type of optimizer to use for training (adamw|lamb)')
flags.DEFINE_boolean(
'scale_loss', False,
'Whether to divide the loss by number of replica inside the per-replica '
'loss function.')
flags.DEFINE_boolean(
'use_keras_compile_fit', False,
'If True, uses Keras compile/fit() API for training logic. Otherwise '
'use custom training loop.')
flags.DEFINE_string(
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
flags_core.define_log_steps()
# Adds flags for mixed precision and multi-worker training.
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
intra_op=False,
synthetic_data=False,
max_train_steps=False,
dtype=True,
dynamic_loss_scale=True,
loss_scale=True,
all_reduce_alg=True,
num_packs=False,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
fp16_implementation=True,
)
def dtype():
return flags_core.get_tf_dtype(flags.FLAGS)
def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def use_graph_rewrite():
return flags.FLAGS.fp16_implementation == 'graph_rewrite'
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
@@ -0,0 +1,105 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main BERT model and related functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import six
import tensorflow as tf
class BertConfig(object):
"""Configuration for `BertModel`."""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
backward_compatible=True):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
backward_compatible: Boolean, whether the variables shape are compatible
with checkpoints converted from TF 1.x BERT.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.backward_compatible = backward_compatible
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
@@ -0,0 +1,86 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the BERT core model as a TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from typing import Text
from official.nlp.bert import bert_models
from official.nlp.bert import configs
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
"""Creates a BERT keras core model from BERT configuration.
Args:
bert_config: A `BertConfig` to create the core model.
Returns:
A keras model.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
bert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To keep consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_bert_tfhub(bert_config: configs.BertConfig,
model_checkpoint_path: Text, hub_destination: Text,
vocab_file: Text):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, encoder = create_bert_model(bert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
core_model.do_lower_case = tf.Variable(
"uncased" in vocab_file, trainable=False)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests official.nlp.bert.export_tfhub."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
class ExportTfhubTest(tf.test.TestCase):
def test_export_tfhub(self):
# Exports a savedmodel for TF-Hub
bert_config = configs.BertConfig(
vocab_size=100,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
hub_destination, vocab_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
# Checks meta attributes.
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
with tf.io.gfile.GFile(
hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of hub module are "pooled_output" and "sequence_output",
# while the outputs of encoder is in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,231 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT model input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def single_file_dataset(input_file, name_to_features):
"""Creates a single-file dataset to be passed for BERT custom training."""
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: decode_record(record, name_to_features))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
d = d.with_options(options)
return d
def create_pretrain_dataset(input_patterns,
seq_length,
max_predictions_per_seq,
batch_size,
is_training=True,
input_pipeline_context=None,
use_next_sentence_label=True):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
'input_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask':
tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids':
tf.io.FixedLenFeature([seq_length], tf.int64),
'masked_lm_positions':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_ids':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
'masked_lm_weights':
tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
}
if use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
for input_pattern in input_patterns:
if not tf.io.gfile.glob(input_pattern):
raise ValueError('%s does not match any files.' % input_pattern)
dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
if is_training:
dataset = dataset.repeat()
# We set shuffle buffer to exactly match total number of
# training files to ensure that training data is well shuffled.
input_files = []
for input_pattern in input_patterns:
input_files.extend(tf.io.gfile.glob(input_pattern))
dataset = dataset.shuffle(len(input_files))
# In parallel, create tf record dataset for each train files.
# cycle_length = 8 means that up to 8 files will be read and deserialized in
# parallel. You may want to increase this number if you have a large number of
# CPU cores.
dataset = dataset.interleave(
tf.data.TFRecordDataset, cycle_length=8,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids'],
'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'],
}
if use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels']
y = record['masked_lm_weights']
return (x, y)
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024)
return dataset
def create_classifier_dataset(file_path,
seq_length,
batch_size,
is_training=True,
input_pipeline_context=None):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
dataset = single_file_dataset(file_path, name_to_features)
# The dataset is always sharded by number of hosts.
# num_input_pipelines is the number of hosts rather than number of cores.
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
def _select_data_from_record(record):
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
y = record['label_ids']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024)
return dataset
def create_squad_dataset(file_path,
seq_length,
batch_size,
is_training=True,
input_pipeline_context=None):
"""Creates input dataset from (tf)records files for train/eval."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
}
if is_training:
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
else:
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
dataset = single_file_dataset(file_path, name_to_features)
# The dataset is always sharded by number of hosts.
# num_input_pipelines is the number of hosts rather than number of cores.
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
def _select_data_from_record(record):
"""Dispatches record to features and labels."""
x, y = {}, {}
for name, tensor in record.items():
if name in ('start_positions', 'end_positions'):
y[name] = tensor
elif name == 'input_ids':
x['input_word_ids'] = tensor
elif name == 'segment_ids':
x['input_type_ids'] = tensor
else:
x[name] = tensor
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
return dataset
@@ -0,0 +1,101 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
import typing
def export_bert_model(model_export_path: typing.Text,
model: tf.keras.Model,
checkpoint_dir: typing.Optional[typing.Text] = None,
restore_model_using_load_weights: bool = False) -> None:
"""Export BERT model for serving which does not include the optimizer.
Arguments:
model_export_path: Path to which exported model will be saved.
model: Keras model object to export.
checkpoint_dir: Path from which model weights will be loaded, if
specified.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API.
There are 2 different ways to save checkpoints. One is using
tf.train.Checkpoint and another is using Keras model.save_weights().
Custom training loop implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint callback internally uses model.save_weights()
API. Since these two API's cannot be used toghether, model loading logic
must be take into account how model checkpoint was saved.
Raises:
ValueError when either model_export_path or model is not specified.
"""
if not model_export_path:
raise ValueError('model_export_path must be specified.')
if not isinstance(model, tf.keras.Model):
raise ValueError('model must be a tf.keras.Model object.')
if checkpoint_dir:
# Keras compile/fit() was used to save checkpoint using
# model.save_weights().
if restore_model_using_load_weights:
model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
assert tf.io.gfile.exists(model_weight_path)
model.load_weights(model_weight_path)
# tf.train.Checkpoint API was used via custom training loop logic.
else:
checkpoint = tf.train.Checkpoint(model=model)
# Restores the model from latest checkpoint.
latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
model.save(model_export_path, include_optimizer=False, save_format='tf')
class BertModelCheckpoint(tf.keras.callbacks.Callback):
"""Keras callback that saves model at the end of every epoch."""
def __init__(self, checkpoint_dir, checkpoint):
"""Initializes BertModelCheckpoint.
Arguments:
checkpoint_dir: Directory of the to be saved checkpoint file.
checkpoint: tf.train.Checkpoint object.
"""
super(BertModelCheckpoint, self).__init__()
self.checkpoint_file_name = os.path.join(
checkpoint_dir, 'bert_training_checkpoint_step_{global_step}.ckpt')
assert isinstance(checkpoint, tf.train.Checkpoint)
self.checkpoint = checkpoint
def on_epoch_end(self, epoch, logs=None):
global_step = tf.keras.backend.get_value(self.model.optimizer.iterations)
formatted_file_name = self.checkpoint_file_name.format(
global_step=global_step)
saved_path = self.checkpoint.save(formatted_file_name)
logging.info('Saving model TF checkpoint to : %s', saved_path)
@@ -0,0 +1,491 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A light weight utilities to train NLP models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
from absl import logging
import tensorflow as tf
from official.staging.training import grad_utils
from official.utils.misc import distribution_utils
_SUMMARY_TXT = 'training_summary.txt'
_MIN_SUMMARY_STEPS = 10
def _should_export_checkpoint(strategy):
return (not strategy) or strategy.extended.should_checkpoint
def _should_export_summary(strategy):
return (not strategy) or strategy.extended.should_save_summary
def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix."""
if _should_export_checkpoint(strategy):
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info('Saving model as TF checkpoint: %s', saved_path)
else:
# In multi worker training we need every worker to save checkpoint, because
# variables can trigger synchronization on read and synchronization needs
# all workers to participate. To avoid workers overriding each other we save
# to a temporary directory on non-chief workers.
tmp_dir = tempfile.mkdtemp()
checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
tf.io.gfile.rmtree(tmp_dir)
return
def _get_input_iterator(input_fn, strategy):
"""Returns distributed dataset iterator."""
# When training with TPU pods, datasets needs to be cloned across
# workers. Since Dataset instance cannot be cloned in eager mode, we instead
# pass callable that returns a dataset.
if not callable(input_fn):
raise ValueError('`input_fn` should be a closure that returns a dataset.')
iterator = iter(
strategy.experimental_distribute_datasets_from_function(input_fn))
return iterator
def _float_metric_value(metric):
"""Gets the value of a float-value keras metric."""
return metric.result().numpy().astype(float)
def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 0:
raise ValueError('steps_per_loop should be positive integer.')
if steps_per_loop == 1:
return steps_per_loop
remainder_in_epoch = current_step % steps_per_epoch
if remainder_in_epoch != 0:
return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
else:
return steps_per_loop
def write_txt_summary(training_summary, summary_dir):
"""Writes a summary text file to record stats."""
summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
logging.info('Training Summary: \n%s', str(training_summary))
f.write(json.dumps(training_summary, indent=4))
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
# pylint: enable=invalid-name
strategy=None,
model_fn=None,
loss_fn=None,
scale_loss=True,
model_dir=None,
train_input_fn=None,
steps_per_epoch=None,
steps_per_loop=1,
epochs=1,
eval_input_fn=None,
eval_steps=None,
metric_fn=None,
init_checkpoint=None,
custom_callbacks=None,
run_eagerly=False,
sub_model_export_name=None,
explicit_allreduce=False,
pre_allreduce_callbacks=None,
post_allreduce_callbacks=None):
"""Run BERT pretrain model training using low-level API.
Arguments:
_sentinel: Used to prevent positional parameters. Internal, do not use.
strategy: Distribution strategy on which to run low level training loop.
model_fn: Function that returns a tuple (model, sub_model). Caller of this
function should add optimizer to the `model` via calling
`model.compile()` API or manually setting `model.optimizer` attribute.
Second element of the returned tuple(sub_model) is an optional sub model
to be used for initial checkpoint -- if provided.
loss_fn: Function with signature func(labels, logits) and returns a loss
tensor.
scale_loss: Whether to divide the raw loss by number of replicas before
gradients calculation.
model_dir: Model directory used during training for restoring/saving model
weights.
train_input_fn: Function that returns a tf.data.Dataset used for training.
steps_per_epoch: Number of steps to run per epoch. At the end of each
epoch, model checkpoint will be saved and evaluation will be conducted
if evaluation dataset is provided.
steps_per_loop: Number of steps per graph-mode loop. In order to reduce
communication in eager context, training logs are printed every
steps_per_loop.
epochs: Number of epochs to train.
eval_input_fn: Function that returns evaluation dataset. If none,
evaluation is skipped.
eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
is not none.
metric_fn: A metrics function that returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
run_eagerly: Whether to run model training in pure eager execution. This
should be disable for TPUStrategy.
sub_model_export_name: If not None, will export `sub_model` returned by
`model_fn` into checkpoint files. The name of intermediate checkpoint
file is {sub_model_export_name}_step_{step}.ckpt and the last
checkpint's name is {sub_model_export_name}.ckpt;
if None, `sub_model` will not be exported as checkpoint.
explicit_allreduce: Whether to explicitly perform gradient allreduce,
instead of relying on implicit allreduce in optimizer.apply_gradients().
default is False. For now, if training using FP16 mixed precision,
explicit allreduce will aggregate gradients in FP16 format. For TPU and
GPU training using FP32, explicit allreduce will aggregate gradients in
FP32 format.
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables paris. The callback functions will be
invoked in the list order and before gradients are allreduced.
With mixed precision training, the pre_allreduce_allbacks will be
applied on scaled_gradients. Default is no callbacks.
Only used when explicit_allreduce=True.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables paris. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks. Only used
when explicit_allreduce=True.
Returns:
Trained model.
Raises:
ValueError: (1) When model returned by `model_fn` does not have optimizer
attribute or when required parameters are set to none. (2) eval args are
not specified correctly. (3) metric_fn must be a callable if specified.
(4) sub_model_checkpoint_name is specified, but `sub_model` returned
by `model_fn` is None.
"""
if _sentinel is not None:
raise ValueError('only call `run_customized_training_loop()` '
'with named arguments.')
required_arguments = [
strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
]
if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'`steps_per_loop` and `steps_per_epoch` are required '
'parameters.')
if steps_per_loop > steps_per_epoch:
logging.error(
'steps_per_loop: %d is specified to be greater than '
' steps_per_epoch: %d, we will use steps_per_epoch as'
' steps_per_loop.', steps_per_loop, steps_per_epoch)
steps_per_loop = steps_per_epoch
assert tf.executing_eagerly()
if run_eagerly:
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
raise ValueError(
'TPUStrategy should not run eagerly as it heavily relies on graph'
' optimization for the distributed system.')
if eval_input_fn and (eval_steps is None or metric_fn is None):
raise ValueError(
'`eval_step` and `metric_fn` are required when `eval_input_fn ` '
'is not none.')
if metric_fn and not callable(metric_fn):
raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.')
total_training_steps = steps_per_epoch * epochs
train_iterator = _get_input_iterator(train_input_fn, strategy)
with distribution_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
if sub_model_export_name and sub_model is None:
raise ValueError('sub_model_export_name is specified as %s, but '
'sub_model is None.' % sub_model_export_name)
optimizer = model.optimizer
if init_checkpoint:
logging.info(
'Checkpoint file %s found and restoring from '
'initial checkpoint for core model.', init_checkpoint)
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else []
# If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation.
train_metrics = [
metric.__class__.from_config(metric.get_config())
for metric in eval_metrics
]
# Create summary writers
if _should_export_summary(strategy):
summary_dir = os.path.join(model_dir, 'summaries')
else:
# In multi worker training we need every worker to write summary, because
# variables can trigger synchronization on read and synchronization needs
# all workers to participate.
summary_dir = tempfile.mkdtemp()
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'eval'))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summary when the stats are collected sufficiently over
# enough steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'train'))
else:
train_summary_writer = None
# Collects training variables.
training_vars = model.trainable_variables
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
model_outputs = model(inputs, training=True)
loss = loss_fn(labels, model_outputs)
# Raw loss is used for reporting in metrics/logs.
raw_loss = loss
if scale_loss:
# Scales down the loss for gradients to be invariant from replicas.
loss = loss / strategy.num_replicas_in_sync
if explicit_allreduce:
grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
training_vars,
pre_allreduce_callbacks,
post_allreduce_callbacks)
else:
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
with tape:
scaled_loss = optimizer.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, training_vars)
grads = optimizer.get_unscaled_gradients(scaled_grads)
else:
grads = tape.gradient(loss, training_vars)
optimizer.apply_gradients(zip(grads, training_vars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(raw_loss)
for metric in train_metrics:
metric.update_state(labels, model_outputs)
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop.
Args:
iterator: the distributed iterator of training datasets.
steps: an tf.int32 integer tensor to specify number of steps to run
inside host training loop.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError('steps should be an Tensor. Python object may cause '
'retracing.')
for _ in tf.range(steps):
strategy.run(_replicated_step, args=(next(iterator),))
def train_single_step(iterator):
"""Performs a distributed training step.
Args:
iterator: the distributed iterator of training datasets.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
strategy.run(_replicated_step, args=(next(iterator),))
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
for metric in eval_metrics:
metric.update_state(labels, model_outputs)
strategy.run(_test_step_fn, args=(next(iterator),))
if not run_eagerly:
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics:
metric_value = _float_metric_value(metric)
logging.info('Step: [%d] Validation %s = %f', current_training_step,
metric.name, metric_value)
tf.summary.scalar(
metric.name, metric_value, step=current_training_step)
eval_summary_writer.flush()
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch, logs):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch, logs)
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
sub_model_checkpoint = tf.train.Checkpoint(
model=sub_model) if sub_model_export_name else None
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info('Loading from checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
# Training loss/metric are taking average over steps inside micro
# training loop. We reset the their values before each round.
train_loss_metric.reset_states()
for metric in train_metrics + model.metrics:
metric.reset_states()
_run_callbacks_on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
if tf.config.list_physical_devices('GPU'):
# TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed.
for _ in range(steps):
train_single_step(train_iterator)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric)
current_step += steps
_run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps, train_loss)
if train_summary_writer:
with train_summary_writer.as_default():
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
for metric in train_metrics + model.metrics:
metric_value = _float_metric_value(metric)
training_status += ' %s = %f' % (metric.name, metric_value)
tf.summary.scalar(metric.name, metric_value, step=current_step)
train_summary_writer.flush()
logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_training_steps:
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name:
_save_checkpoint(
strategy, sub_model_checkpoint, model_dir,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name:
_save_checkpoint(strategy, sub_model_checkpoint, model_dir,
'%s.ckpt' % sub_model_export_name)
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
training_summary = {
'total_training_steps': total_training_steps,
'train_loss': _float_metric_value(train_loss_metric),
}
if eval_metrics:
# TODO(hongkuny): Cleans up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
write_txt_summary(training_summary, summary_dir)
if not _should_export_summary(strategy):
tf.io.gfile.rmtree(summary_dir)
return model
@@ -0,0 +1,236 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.modeling.training.model_training_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
from absl.testing import parameterized
from absl.testing.absltest import mock
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.bert import model_training_utils
def eager_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode='eager',
)
def eager_gpu_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode='eager',
)
def create_fake_data_input_fn(batch_size, features_shape, num_classes):
"""Creates a dummy input function with the given feature and label shapes.
Args:
batch_size: integer.
features_shape: list[int]. Feature shape for an individual example.
num_classes: integer. Number of labels.
Returns:
An input function that is usable in the executor.
"""
def _dataset_fn(input_context=None):
"""An input function for generating fake data."""
local_batch_size = input_context.get_per_replica_batch_size(batch_size)
features = np.random.rand(64, *features_shape)
labels = np.random.randint(2, size=[64, num_classes])
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
def _assign_dtype(features, labels):
features = tf.cast(features, tf.float32)
labels = tf.cast(labels, tf.float32)
return features, labels
# Shuffle, repeat, and batch the examples.
dataset = dataset.map(_assign_dtype)
dataset = dataset.shuffle(64).repeat()
dataset = dataset.batch(local_batch_size, drop_remainder=True)
dataset = dataset.prefetch(buffer_size=64)
return dataset
return _dataset_fn
def create_model_fn(input_shape, num_classes, use_float16=False):
def _model_fn():
"""A one-layer softmax model suitable for testing."""
input_layer = tf.keras.layers.Input(shape=input_shape)
x = tf.keras.layers.Dense(num_classes, activation='relu')(input_layer)
output_layer = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
sub_model = tf.keras.models.Model(input_layer, x, name='sub_model')
model = tf.keras.models.Model(input_layer, output_layer, name='model')
model.add_metric(
tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
model.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
if use_float16:
model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
model.optimizer, loss_scale='dynamic'))
return model, sub_model
return _model_fn
def metric_fn():
"""Gets a tf.keras metric object."""
return tf.keras.metrics.CategoricalAccuracy(name='accuracy', dtype=tf.float32)
def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, 'events*'))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
logging.error(event)
yield event.summary
def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))
class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ModelTrainingUtilsTest, self).setUp()
self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
input_fn = create_fake_data_input_fn(
batch_size=8, features_shape=[128], num_classes=3)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=self._model_fn,
loss_fn=tf.keras.losses.categorical_crossentropy,
model_dir=model_dir,
steps_per_epoch=20,
steps_per_loop=steps_per_loop,
epochs=2,
train_input_fn=input_fn,
eval_input_fn=input_fn,
eval_steps=10,
init_checkpoint=None,
metric_fn=metric_fn,
custom_callbacks=None,
run_eagerly=run_eagerly)
@combinations.generate(eager_strategy_combinations())
def test_train_eager_single_step(self, distribution):
model_dir = self.get_temp_dir()
if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
with self.assertRaises(ValueError):
self.run_training(
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
else:
self.run_training(
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
@combinations.generate(eager_gpu_strategy_combinations())
def test_train_eager_mixed_precision(self, distribution):
model_dir = self.get_temp_dir()
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
self._model_fn = create_model_fn(
input_shape=[128], num_classes=3, use_float16=True)
self.run_training(
distribution, model_dir, steps_per_loop=1, run_eagerly=True)
@combinations.generate(eager_strategy_combinations())
def test_train_check_artifacts(self, distribution):
model_dir = self.get_temp_dir()
self.run_training(
distribution, model_dir, steps_per_loop=10, run_eagerly=False)
# Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*')))
self.assertNotEmpty(
tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*')))
# Loss and accuracy values should be written into summaries.
self.assertTrue(
check_eventfile_for_keyword('loss',
os.path.join(model_dir, 'summaries/train')))
self.assertTrue(
check_eventfile_for_keyword('accuracy',
os.path.join(model_dir, 'summaries/train')))
self.assertTrue(
check_eventfile_for_keyword('mean_input',
os.path.join(model_dir, 'summaries/train')))
self.assertTrue(
check_eventfile_for_keyword('accuracy',
os.path.join(model_dir, 'summaries/eval')))
self.assertTrue(
check_eventfile_for_keyword('mean_input',
os.path.join(model_dir, 'summaries/eval')))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
))
def test_train_check_artifacts_non_chief(self, distribution):
# We shouldn't export artifacts on non-chief workers. Since there's no easy
# way to test with real MultiWorkerMirroredStrategy, we patch the strategy
# to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
extended = distribution.extended
with mock.patch.object(extended.__class__, 'should_checkpoint',
new_callable=mock.PropertyMock, return_value=False), \
mock.patch.object(extended.__class__, 'should_save_summary',
new_callable=mock.PropertyMock, return_value=False):
model_dir = self.get_temp_dir()
self.run_training(
distribution, model_dir, steps_per_loop=10, run_eagerly=False)
self.assertEmpty(tf.io.gfile.listdir(model_dir))
if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,432 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT classification finetuning runner in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import math
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.')
# Model training specific flags.
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def get_loss_fn(num_classes):
"""Gets the classification loss function."""
def classification_loss_fn(labels, logits):
"""Classification loss."""
labels = tf.squeeze(labels)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(
tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
return tf.reduce_mean(per_example_loss)
return classification_loss_fn
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_classifier_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def run_bert_classifier(strategy,
bert_config,
input_meta_data,
model_dir,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
warmup_steps,
initial_lr,
init_checkpoint,
train_input_fn,
eval_input_fn,
custom_callbacks=None,
run_eagerly=False,
use_keras_compile_fit=False):
"""Run BERT classifier training using low-level API."""
max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
def _get_classifier_model():
"""Gets a classifier model."""
classifier_model, core_model = (
bert_models.classifier_model(
bert_config,
num_classes,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable))
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps,
FLAGS.end_lr, FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model
loss_fn = get_loss_fn(num_classes)
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
if use_keras_compile_fit:
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.0 Keras compile/fit API with '
'distribution strategy.')
return run_keras_compile_fit(
model_dir,
strategy,
_get_classifier_model,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
custom_callbacks=custom_callbacks)
# Use user-defined loop to start training.
logging.info('Training using customized training loop TF 2.0 with '
'distribution strategy.')
return model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_classifier_model,
loss_fn=loss_fn,
model_dir=model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_steps=eval_steps,
init_checkpoint=init_checkpoint,
metric_fn=metric_fn,
custom_callbacks=custom_callbacks,
run_eagerly=run_eagerly)
def run_keras_compile_fit(model_dir,
strategy,
model_fn,
train_input_fn,
eval_input_fn,
loss_fn,
metric_fn,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
custom_callbacks=None):
"""Runs BERT classifier model using Keras compile/fit API."""
with strategy.scope():
training_dataset = train_input_fn()
evaluation_dataset = eval_input_fn()
bert_model, sub_model = model_fn()
optimizer = bert_model.optimizer
if init_checkpoint:
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
bert_model.compile(
optimizer=optimizer,
loss=loss_fn,
metrics=[metric_fn()],
experimental_steps_per_execution=steps_per_loop)
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
if custom_callbacks is not None:
custom_callbacks += [summary_callback, checkpoint_callback]
else:
custom_callbacks = [summary_callback, checkpoint_callback]
bert_model.fit(
x=training_dataset,
validation_data=evaluation_dataset,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_steps=eval_steps,
callbacks=custom_callbacks)
return bert_model
def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
eval_steps):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
order changes on distributing dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation steps.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@tf.function
def test_step(iterator):
"""Computes predictions on distributed devices."""
def _test_step_fn(inputs):
"""Replicated predictions."""
inputs, labels = inputs
model_outputs = trained_model(inputs, training=False)
return model_outputs, labels
outputs, labels = strategy.run(
_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
return outputs, labels
def _run_evaluation(test_iterator):
"""Runs evaluation steps."""
preds, golds = list(), list()
for _ in range(eval_steps):
logits, labels = test_step(test_iterator)
for cur_logits, cur_labels in zip(logits, labels):
preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
golds.extend(cur_labels.numpy().tolist())
return preds, golds
test_iter = iter(
strategy.experimental_distribute_datasets_from_function(eval_input_fn))
predictions, labels = _run_evaluation(test_iter)
return predictions, labels
def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights, bert_config, model_dir):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API. There are 2
different ways to save checkpoints. One is using tf.train.Checkpoint and
another is using Keras model.save_weights(). Custom training loop
implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
callback internally uses model.save_weights() API. Since these two API's
cannot be used together, model loading logic must be take into account how
model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
Raises:
Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
if not model_dir:
raise ValueError('Export path is not specified: %s' % model_dir)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
model_saving_utils.export_bert_model(
model_export_path,
model=classifier_model,
checkpoint_dir=model_dir,
restore_model_using_load_weights=restore_model_using_load_weights)
def run_bert(strategy,
input_meta_data,
model_config,
train_input_fn=None,
eval_input_fn=None):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
export_classifier(FLAGS.model_export_path, input_meta_data,
FLAGS.use_keras_compile_fit,
model_config, FLAGS.model_dir)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs
train_data_size = input_meta_data['train_data_size']
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if not strategy:
raise ValueError('Distribution strategy has not been specified.')
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
trained_model = run_bert_classifier(
strategy,
model_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
FLAGS.steps_per_loop,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
train_input_fn,
eval_input_fn,
run_eagerly=FLAGS.run_eagerly,
use_keras_compile_fit=FLAGS.use_keras_compile_fit,
custom_callbacks=custom_callbacks)
if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
model_saving_utils.export_bert_model(
FLAGS.model_export_path,
model=trained_model,
restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
return trained_model
def main(_):
# Users should always run this script under TF 2.x
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
eval_input_fn = get_dataset_fn(
FLAGS.eval_data_path,
max_seq_length,
FLAGS.eval_batch_size,
is_training=False)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('input_meta_data_path')
flags.mark_flag_as_required('model_dir')
app.run(main)
@@ -0,0 +1,187 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
flags.DEFINE_string('input_files', None,
'File path to retrieve training data for pre-training.')
# Model training specific flags.
flags.DEFINE_integer(
'max_seq_length', 128,
'The maximum total input sequence length after WordPiece tokenization. '
'Sequences longer than this will be truncated, and sequences shorter '
'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
'Maximum predictions per sequence_output.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
max_predictions_per_seq, global_batch_size,
use_next_sentence_label=True):
"""Returns input dataset from input file string."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
input_patterns = input_file_pattern.split(',')
batch_size = ctx.get_per_replica_batch_size(global_batch_size)
train_dataset = input_pipeline.create_pretrain_dataset(
input_patterns,
seq_length,
max_predictions_per_seq,
batch_size,
is_training=True,
input_pipeline_context=ctx,
use_next_sentence_label=use_next_sentence_label)
return train_dataset
return _dataset_fn
def get_loss_fn():
"""Returns loss function for BERT pretraining."""
def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
return tf.reduce_mean(losses)
return _bert_pretrain_loss_fn
def run_customized_training(strategy,
bert_config,
max_seq_length,
max_predictions_per_seq,
model_dir,
steps_per_epoch,
steps_per_loop,
epochs,
initial_lr,
warmup_steps,
end_lr,
optimizer_type,
input_files,
train_batch_size,
use_next_sentence_label=True):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
max_predictions_per_seq,
train_batch_size,
use_next_sentence_label)
def _get_pretrain_model():
"""Gets a pretraining model."""
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq,
use_next_sentence_label=use_next_sentence_label)
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps,
end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_pretrain_model,
loss_fn=get_loss_fn(),
scale_loss=FLAGS.scale_loss,
model_dir=model_dir,
train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs,
sub_model_export_name='pretrained/bert_model')
return trained_model
def run_bert_pretrain(strategy):
"""Runs BERT pre-training."""
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if not strategy:
raise ValueError('Distribution strategy is not specified.')
# Runs customized training loop.
logging.info('Training using customized training loop TF 2.0 with distrubuted'
'strategy.')
performance.set_mixed_precision_policy(common_flags.dtype())
return run_customized_training(
strategy,
bert_config,
FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq,
FLAGS.model_dir,
FLAGS.num_steps_per_epoch,
FLAGS.steps_per_loop,
FLAGS.num_train_epochs,
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type,
FLAGS.input_files,
FLAGS.train_batch_size,
FLAGS.use_next_sentence_label)
def main(_):
# Users should always run this script under TF 2.x
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
if strategy:
print('***** Number of cores used : ', strategy.num_replicas_in_sync)
run_bert_pretrain(strategy)
if __name__ == '__main__':
app.run(main)
@@ -0,0 +1,149 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly, init_checkpoint)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
# Users should always run this script under TF 2.x
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if 'train' in FLAGS.mode:
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir,
)]
else:
custom_callbacks = None
train_squad(
strategy,
input_meta_data,
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
)
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
summary_writer = tf.summary.create_file_writer(summary_dir)
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Also write eval_metrics to json file.
squad_lib_wp.write_to_json_files(
eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
time.sleep(60)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
@@ -0,0 +1,432 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_eval',
['train_and_eval', 'train_and_predict',
'train', 'eval', 'predict', 'export_only'],
'One of {"train_and_eval", "train_and_predict", '
'"train", "eval", "predict", "export_only"}. '
'`train_and_eval`: train & predict to json files & compute eval metrics. '
'`train_and_predict`: train & predict to json files. '
'`train`: only trains the model. '
'`eval`: predict answers from squad json file & compute eval metrics. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be '
'printed. A number of warnings are expected for a normal SQuAD '
'evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one '
'another.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
return total_loss
def get_loss_fn():
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits)
return _loss_fn
RawResult = collections.namedtuple('RawResult',
['unique_id', 'start_logits', 'end_logits'])
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy,
input_meta_data,
bert_config,
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
if checkpoint_path is None:
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.run(_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
optimizer = optimization.create_optimizer(FLAGS.learning_rate,
steps_per_epoch * epochs,
warmup_steps,
FLAGS.end_lr,
FLAGS.optimizer_type)
squad_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return squad_model, core_model
# If explicit_allreduce = True, apply_gradients() no longer implicitly
# allreduce gradients, users manually allreduce gradient and pass the
# allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
# applied to allreduced gradients.
def clip_by_global_norm_callback(grads_and_vars):
grads, variables = zip(*grads_and_vars)
(clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return zip(clipped_grads, variables)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=get_loss_fn(),
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks,
explicit_allreduce=False,
post_allreduce_callbacks=[clip_by_global_norm_callback])
def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(
strategy, input_meta_data, bert_config,
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging))
return all_predictions, all_nbest_json, scores_diff_json
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative):
"""Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file))
squad_lib.write_to_json_files(all_predictions, output_prediction_file)
squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def predict_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
def eval_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
dataset_json = json.load(reader)
pred_dataset = dataset_json['data']
if input_meta_data.get('version_2_with_negative', False):
eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
all_predictions,
scores_diff_json)
else:
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
return eval_metrics
def export_squad(model_export_path, input_meta_data, bert_config):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
Raises:
Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
@@ -0,0 +1,108 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order
def _normalize_answer(s):
"""Lowers text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _f1_score(prediction, ground_truth):
"""Computes F1 score by comparing prediction to ground truth."""
prediction_tokens = _normalize_answer(prediction).split()
ground_truth_tokens = _normalize_answer(ground_truth).split()
prediction_counter = collections.Counter(prediction_tokens)
ground_truth_counter = collections.Counter(ground_truth_tokens)
common = prediction_counter & ground_truth_counter
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def _exact_match_score(prediction, ground_truth):
"""Checks if predicted answer exactly matches ground truth answer."""
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Computes the max over all metric scores."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
"""Evaluates predictions for a dataset."""
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
logging.error(message)
continue
ground_truths = [entry["text"] for entry in qa["answers"]]
prediction = predictions[qa["id"]]
exact_match += _metric_max_over_ground_truths(_exact_match_score,
prediction, ground_truths)
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
ground_truths)
exact_match = exact_match / total
f1 = f1 / total
return {"exact_match": exact_match, "final_f1": f1}
@@ -0,0 +1,252 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.
The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
from absl import logging
def _make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid_to_has_ans[qa['id']] = bool(qa['answers'])
return qid_to_has_ans
def _normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _get_tokens(s):
if not s: return []
return _normalize_answer(s).split()
def _compute_exact(a_gold, a_pred):
return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
def _compute_f1(a_gold, a_pred):
"""Compute F1-score."""
gold_toks = _get_tokens(a_gold)
pred_toks = _get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if not gold_toks or not pred_toks:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def _get_raw_scores(dataset, predictions):
"""Compute raw scores."""
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid = qa['id']
gold_answers = [a['text'] for a in qa['answers']
if _normalize_answer(a['text'])]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = ['']
if qid not in predictions:
logging.error('Missing prediction for %s', qid)
continue
a_pred = predictions[qid]
# Take max over all gold answers
exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def _apply_no_ans_threshold(
scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
"""Make evaluation result dictionary."""
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
def _merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
"""Make evaluation dictionary containing average recision recall."""
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i+1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
return {'ap': 100.0 * avg_prec}
def _run_precision_recall_analysis(
main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
"""Run precision recall analysis and return result dictionary."""
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = _make_precision_recall_eval(
exact_raw, na_probs, num_true_pos, qid_to_has_ans)
pr_f1 = _make_precision_recall_eval(
f1_raw, na_probs, num_true_pos, qid_to_has_ans)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = _make_precision_recall_eval(
oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
_merge_eval(main_eval, pr_exact, 'pr_exact')
_merge_eval(main_eval, pr_f1, 'pr_f1')
_merge_eval(main_eval, pr_oracle, 'pr_oracle')
def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
"""Find the best threshold for no answer probability."""
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for qid in qid_list:
if qid not in scores: continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if predictions[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def _find_all_best_thresh(
main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = _find_best_thresh(
predictions, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = _find_best_thresh(
predictions, f1_raw, na_probs, qid_to_has_ans)
main_eval['final_exact'] = best_exact
main_eval['final_exact_thresh'] = exact_thresh
main_eval['final_f1'] = best_f1
main_eval['final_f1_thresh'] = f1_thresh
def evaluate(dataset, predictions, na_probs=None):
"""Evaluate prediction results."""
new_orig_data = []
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
if qa['id'] in predictions:
new_para = {'qas': [qa]}
new_article = {'paragraphs': [new_para]}
new_orig_data.append(new_article)
dataset = new_orig_data
if na_probs is None:
na_probs = {k: 0.0 for k in predictions}
qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
out_eval = _make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = _make_eval_dict(
exact_thresh, f1_thresh, qid_list=has_ans_qids)
_merge_eval(out_eval, has_ans_eval, 'HasAns')
if no_ans_qids:
no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
_merge_eval(out_eval, no_ans_eval, 'NoAns')
_find_all_best_thresh(
out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
_run_precision_recall_analysis(
out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
return out_eval
@@ -0,0 +1,195 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf # TF 1.x
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
BERT_NAME_REPLACEMENTS = (
("bert", "bert_model"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings",
"embedding_postprocessor/type_embeddings"),
("embeddings/position_embeddings",
"embedding_postprocessor/position_embeddings"),
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
)
BERT_V2_NAME_REPLACEMENTS = (
("bert/", ""),
("encoder", "transformer"),
("embeddings/word_embeddings", "word_embeddings/embeddings"),
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
("embeddings/position_embeddings", "position_embedding/embeddings"),
("embeddings/LayerNorm", "embeddings/layer_norm"),
("attention/self", "self_attention"),
("attention/output/dense", "self_attention_output"),
("attention/output/LayerNorm", "self_attention_layer_norm"),
("intermediate/dense", "intermediate"),
("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
)
BERT_PERMUTATIONS = ()
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
def _bert_name_replacement(var_name, name_replacements):
"""Gets the variable name replacement."""
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
return var_name
def _has_exclude_patterns(name, exclude_patterns):
"""Checks if a string contains substrings that match patterns to exclude."""
for p in exclude_patterns:
if p in name:
return True
return False
def _get_permutation(name, permutations):
"""Checks whether a variable requires transposition by pattern matching."""
for src_pattern, permutation in permutations:
if src_pattern in name:
tf.logging.info("Permuted: %s --> %s", name, permutation)
return permutation
return None
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "self_attention_output/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "self_attention_output/bias" in name:
return shape
patterns = [
"self_attention/query", "self_attention/value", "self_attention/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def create_v2_checkpoint(model, src_checkpoint, output_path):
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.save(output_path)
def convert(checkpoint_from_path,
checkpoint_to_path,
num_heads,
name_replacements,
permutations,
exclude_patterns=None):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
num_heads: The number of heads of the model.
name_replacements: A list of tuples of the form (match_str, replace_str)
describing variable names to adjust.
permutations: A list of tuples of the form (match_str, permutation)
describing permutations to apply to given variables. Note that match_str
should match the original variable name, not the replaced one.
exclude_patterns: A list of string patterns to exclude variables from
checkpoint conversion.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with tf.Graph().as_default():
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
new_variable_map = {}
conversion_map = {}
for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue
# Get the original tensor data.
tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name, name_replacements)
# See if we need to reshape the underlying tensor.
new_shape = None
if num_heads > 0:
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
if new_shape:
tf.logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name, permutations)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var
# Keep a list of converter variables for sanity checking.
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
saver = tf.train.Saver(new_variable_map)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
saver.save(sess, checkpoint_to_path)
tf.logging.info("Summary:")
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
tf.logging.info(" Converted: %s", str(conversion_map))
@@ -0,0 +1,108 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
The conversion will yield an object-oriented checkpoint that can be used
to restore a TransformerEncoder object.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.modeling import activations
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import networks
FLAGS = flags.FLAGS
flags.DEFINE_string("bert_config_file", None,
"Bert configuration file to define core bert layers.")
flags.DEFINE_string(
"checkpoint_to_convert", None,
"Initial checkpoint from a pretrained BERT model core (that is, only the "
"BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
"Name for the created object-based V2 checkpoint.")
def _create_bert_model(cfg):
"""Creates a BERT keras core model from BERT configuration.
Args:
cfg: A `BertConfig` to create the core model.
Returns:
A TransformerEncoder netowork.
"""
bert_encoder = networks.TransformerEncoder(
vocab_size=cfg.vocab_size,
hidden_size=cfg.hidden_size,
num_layers=cfg.num_hidden_layers,
num_attention_heads=cfg.num_attention_heads,
intermediate_size=cfg.intermediate_size,
activation=activations.gelu,
dropout_rate=cfg.hidden_dropout_prob,
attention_dropout_rate=cfg.attention_probs_dropout_prob,
sequence_length=cfg.max_position_embeddings,
type_vocab_size=cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range))
return bert_encoder
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
output_dir, _ = os.path.split(output_path)
# Create a temporary V1 name-converted checkpoint in the output directory.
temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
tf1_checkpoint_converter_lib.convert(
checkpoint_from_path=v1_checkpoint,
checkpoint_to_path=temporary_checkpoint,
num_heads=bert_config.num_attention_heads,
name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
exclude_patterns=["adam", "Adam"])
# Create a V2 checkpoint from the temporary checkpoint.
model = _create_bert_model(bert_config)
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
output_path)
# Clean up the temporary checkpoint, if it exists.
try:
tf.io.gfile.rmtree(temporary_checkpoint_dir)
except tf.errors.OpError:
# If it doesn't exist, we don't need to clean it up; continue.
pass
def main(_):
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
convert_checkpoint(bert_config, output_path, v1_checkpoint)
if __name__ == "__main__":
app.run(main)
@@ -0,0 +1,545 @@
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tokenization classes implementation.
The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
import tensorflow as tf
import sentencepiece as spm
SPIECE_UNDERLINE = ""
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.io.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, split_on_punc=split_on_punc)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, split_on_punc=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
split_on_punc: Whether to apply split on punctuations. By default BERT
starts a new token for punctuations. This makes detokenization difficult
for tasks like seq2seq decoding.
"""
self.do_lower_case = do_lower_case
self.split_on_punc = split_on_punc
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
if self.split_on_punc:
split_tokens.extend(self._run_split_on_punc(token))
else:
split_tokens.append(token)
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def preprocess_text(inputs, remove_space=True, lower=False):
"""Preprocesses data by removing extra space and normalize data.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
inputs: The input text.
remove_space: Whether to remove the extra space.
lower: Whether to lowercase the text.
Returns:
The preprocessed text.
"""
outputs = inputs
if remove_space:
outputs = " ".join(inputs.strip().split())
if six.PY2 and isinstance(outputs, str):
try:
outputs = six.ensure_text(outputs, "utf-8")
except UnicodeDecodeError:
outputs = six.ensure_text(outputs, "latin-1")
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if lower:
outputs = outputs.lower()
return outputs
def encode_pieces(sp_model, text, sample=False):
"""Segements text into pieces.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
sp_model: A spm.SentencePieceProcessor object.
text: The input text to be segemented.
sample: Whether to randomly sample a segmentation output or return a
deterministic one.
Returns:
A list of token pieces.
"""
if six.PY2 and isinstance(text, six.text_type):
text = six.ensure_binary(text, "utf-8")
if not sample:
pieces = sp_model.EncodeAsPieces(text)
else:
pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
piece = printable_text(piece)
if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
cur_pieces[0] = cur_pieces[0][1:]
cur_pieces.append(piece[-1])
new_pieces.extend(cur_pieces)
else:
new_pieces.append(piece)
return new_pieces
def encode_ids(sp_model, text, sample=False):
"""Segments text and return token ids.
This method is used together with sentence piece tokenizer and is forked from:
https://github.com/google-research/google-research/blob/master/albert/tokenization.py
Args:
sp_model: A spm.SentencePieceProcessor object.
text: The input text to be segemented.
sample: Whether to randomly sample a segmentation output or return a
deterministic one.
Returns:
A list of token ids.
"""
pieces = encode_pieces(sp_model, text, sample=sample)
ids = [sp_model.PieceToId(piece) for piece in pieces]
return ids
class FullSentencePieceTokenizer(object):
"""Runs end-to-end sentence piece tokenization.
The interface of this class is intended to keep the same as above
`FullTokenizer` class for easier usage.
"""
def __init__(self, sp_model_file):
"""Inits FullSentencePieceTokenizer.
Args:
sp_model_file: The path to the sentence piece model file.
"""
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(sp_model_file)
self.vocab = {
self.sp_model.IdToPiece(i): i
for i in six.moves.range(self.sp_model.GetPieceSize())
}
def tokenize(self, text):
"""Tokenizes text into pieces."""
return encode_pieces(self.sp_model, text)
def convert_tokens_to_ids(self, tokens):
"""Converts a list of tokens to a list of ids."""
return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]
def convert_ids_to_tokens(self, ids):
"""Converts a list of ids ot a list of tokens."""
return [self.sp_model.IdToPiece(id_) for id_ in ids]
@@ -0,0 +1,160 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import six
import tensorflow as tf
from official.nlp.bert import tokenization
class TokenizationTest(tf.test.TestCase):
"""Tokenization test.
The implementation is forked from
https://github.com/google-research/bert/blob/master/tokenization_test.py."
"""
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
if six.PY2:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
else:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file)
os.unlink(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertAllEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self):
tokenizer = tokenization.BasicTokenizer()
self.assertAllEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello", "!", "how", "are", "you", "?"])
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
def test_basic_tokenizer_no_split_on_punc(self):
tokenizer = tokenization.BasicTokenizer(
do_lower_case=True, split_on_punc=False)
self.assertAllEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
["hello!how", "are", "you?"])
def test_wordpiece_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", "##!", "!"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
self.assertAllEqual(tokenizer.tokenize(""), [])
self.assertAllEqual(
tokenizer.tokenize("unwanted running"),
["un", "##want", "##ed", "runn", "##ing"])
self.assertAllEqual(
tokenizer.tokenize("unwanted running !"),
["un", "##want", "##ed", "runn", "##ing", "!"])
self.assertAllEqual(
tokenizer.tokenize("unwanted running!"),
["un", "##want", "##ed", "runn", "##ing", "##!"])
self.assertAllEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_convert_tokens_to_ids(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
self.assertAllEqual(
tokenization.convert_tokens_to_ids(
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
def test_is_whitespace(self):
self.assertTrue(tokenization._is_whitespace(u" "))
self.assertTrue(tokenization._is_whitespace(u"\t"))
self.assertTrue(tokenization._is_whitespace(u"\r"))
self.assertTrue(tokenization._is_whitespace(u"\n"))
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
self.assertFalse(tokenization._is_whitespace(u"A"))
self.assertFalse(tokenization._is_whitespace(u"-"))
def test_is_control(self):
self.assertTrue(tokenization._is_control(u"\u0005"))
self.assertFalse(tokenization._is_control(u"A"))
self.assertFalse(tokenization._is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r"))
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-"))
self.assertTrue(tokenization._is_punctuation(u"$"))
self.assertTrue(tokenization._is_punctuation(u"`"))
self.assertTrue(tokenization._is_punctuation(u"."))
self.assertFalse(tokenization._is_punctuation(u"A"))
self.assertFalse(tokenization._is_punctuation(u" "))
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,676 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
self.process_text_fn = process_text_fn
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@staticmethod
def get_processor_name():
"""Gets the string identifier of the processor."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.io.gfile.GFile(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = self.process_text_fn(line[0])
if language != self.process_text_fn(self.language):
continue
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = self.process_text_fn(line[1])
label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class SstProcessor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "SST-2"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "QNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else:
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class TfdsProcessor(DataProcessor):
"""Processor for generic text classification TFDS data set.
The TFDS parameters are expected to be provided in the tfds_params string, in
a comma-separated list of parameter assignments.
Examples:
tfds_params="dataset=scicite,text_key=string"
tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
tfds_params="dataset=glue/cola,text_key=sentence"
tfds_params="dataset=glue/sst2,text_key=sentence"
tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
Possible parameters (please refer to the documentation of Tensorflow Datasets
(TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number).
data_dir: Optional TFDS source root directory.
train_split: Name of the train split (defaults to `train`).
dev_split: Name of the dev split (defaults to `validation`).
test_split: Name of the test split (defaults to `test`).
text_key: Key of the text_a feature (defaults to `text`).
text_b_key: Key of the second text feature if available.
label_key: Key of the label feature (defaults to `label`).
test_text_key: Key of the text feature to use in test set.
test_text_b_key: Key of the second text feature to use in test set.
test_label: String to be used as the label for all test examples.
"""
def __init__(self, tfds_params,
process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
with_info=True)
self._labels = list(range(info.features[self.label_key].num_classes))
def _process_tfds_params_str(self, params_str):
"""Extracts TFDS parameters from a comma-separated assignements string."""
tuples = [x.split("=") for x in params_str.split(",")]
d = {k.strip(): v.strip() for k, v in tuples}
self.dataset_name = d["dataset"] # Required.
self.data_dir = d.get("data_dir", None)
self.train_split = d.get("train_split", "train")
self.dev_split = d.get("dev_split", "validation")
self.test_split = d.get("test_split", "test")
self.text_key = d.get("text_key", "text")
self.text_b_key = d.get("text_b_key", None)
self.label_key = d.get("label_key", "label")
self.test_text_key = d.get("test_text_key", self.text_key)
self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
self.test_label = d.get("test_label", "test_example")
def get_train_examples(self, data_dir):
assert data_dir is None
return self._create_examples(self.train_split, "train")
def get_dev_examples(self, data_dir):
assert data_dir is None
return self._create_examples(self.dev_split, "dev")
def get_test_examples(self, data_dir):
assert data_dir is None
return self._create_examples(self.test_split, "test")
def get_labels(self):
return self._labels
def get_processor_name(self):
return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets."""
if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator()
examples = []
text_b = None
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = self.process_text_fn(example[self.test_text_key])
if self.test_text_b_key:
text_b = self.process_text_fn(example[self.test_text_b_key])
label = self.test_label
else:
text_a = self.process_text_fn(example[self.text_key])
if self.text_b_key:
text_b = self.process_text_fn(example[self.text_b_key])
label = int(example[self.label_key])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
label_id = label_map[example.label]
if ex_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s", (example.guid))
logging.info("tokens: %s",
" ".join([tokenization.printable_text(x) for x in tokens]))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %d)", example.label, label_id)
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id,
is_real_example=True)
return feature
def file_based_convert_examples_to_features(examples, label_list,
max_seq_length, tokenizer,
output_file):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples))
feature = convert_single_example(ex_index, example, label_list,
max_seq_length, tokenizer)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def generate_tf_record_from_data_file(processor,
data_dir,
tokenizer,
train_data_output_path=None,
eval_data_output_path=None,
max_seq_length=128):
"""Generates and saves training data into a tf record file.
Arguments:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval data to process. Data files
should be in from "dev.tsv", "test.tsv", or "train.tsv".
tokenizer: The tokenizer to be applied on the data.
train_data_output_path: Output to which processed tf record for training
will be saved.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
max_seq_length: Maximum sequence length of the to be generated
training/eval data.
Returns:
A dictionary containing input meta data.
"""
assert train_data_output_path or eval_data_output_path
label_list = processor.get_labels()
assert train_data_output_path
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(data_dir)
file_based_convert_examples_to_features(eval_input_data_examples,
label_list, max_seq_length,
tokenizer, eval_data_output_path)
meta_data = {
"task_type": "bert_classification",
"processor_type": processor.get_processor_name(),
"num_labels": len(processor.get_labels()),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
return meta_data
@@ -0,0 +1,203 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
FLAGS = flags.FLAGS
flags.DEFINE_enum(
"fine_tuning_task_type", "classification", ["classification", "squad"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
# BERT classification specific flags.
flags.DEFINE_string(
"input_data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "SST-2", "XNLI"],
"The name of the task to train BERT classifier.")
# BERT Squad task specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
flags.DEFINE_integer(
"doc_stride", 128,
"When splitting up a long document into chunks, how much stride to "
"take between chunks.")
flags.DEFINE_integer(
"max_query_length", 64,
"The maximum number of tokens for the question. Questions longer than "
"this will be truncated to this length.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
"train_data_output_path", None,
"The path in which generated training input data will be written as tf"
" records.")
flags.DEFINE_string(
"eval_data_output_path", None,
"The path in which generated training input data will be written as tf"
" records.")
flags.DEFINE_string("meta_data_file_path", None,
"The path in which input meta data will be written.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_string("sp_model_file", "",
"The path to the model used by sentence piece tokenizer.")
flags.DEFINE_enum(
"tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
"Specifies the tokenizer implementation, i.e., whehter to use word_piece "
"or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
"while ALBERT uses sentence_piece tokenizer.")
flags.DEFINE_string("tfds_params", "",
"Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation).")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.classification_task_name
or FLAGS.tfds_params)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
if FLAGS.tfds_params:
processor = classifier_data_lib.TfdsProcessor(
tfds_params=FLAGS.tfds_params,
process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
tokenizer,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
max_seq_length=FLAGS.max_seq_length)
else:
processors = {
"cola": classifier_data_lib.ColaProcessor,
"mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor,
"qnli": classifier_data_lib.QnliProcessor,
"sst-2": classifier_data_lib.SstProcessor,
"xnli": classifier_data_lib.XnliProcessor,
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name](processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
FLAGS.input_data_dir,
tokenizer,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
max_seq_length=FLAGS.max_seq_length)
def generate_squad_dataset():
"""Generates squad training dataset and returns input meta data."""
assert FLAGS.squad_data_file
if FLAGS.tokenizer_impl == "word_piece":
return squad_lib_wp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
return squad_lib_sp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.sp_model_file,
FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
def main(_):
if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file:
raise ValueError(
"FLAG vocab_file for word-piece tokenizer is not specified.")
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
if not FLAGS.sp_model_file:
raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset()
else:
input_meta_data = generate_squad_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
writer.write(json.dumps(input_meta_data, indent=4) + "\n")
if __name__ == "__main__":
flags.mark_flag_as_required("train_data_output_path")
flags.mark_flag_as_required("meta_data_file_path")
app.run(main)
@@ -0,0 +1,486 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import random
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.bert import tokenization
FLAGS = flags.FLAGS
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_bool(
"gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(
tf.io.TFRecordWriter(
output_file, options="GZIP" if gzip_compress else ""))
writer_index = 0
total_written = 0
for (inst_index, instance) in enumerate(instances):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writers[writer_index].write(tf_example.SerializeToString())
writer_index = (writer_index + 1) % len(writers)
total_written += 1
if inst_index < 20:
logging.info("*** Example ***")
logging.info("tokens: %s", " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
values = []
if feature.int64_list.value:
values = feature.int64_list.value
elif feature.float_list.value:
values = feature.float_list.value
logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
for writer in writers:
writer.close()
logging.info("Wrote %d total instances", total_written)
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature
def create_training_instances(input_files,
tokenizer,
max_seq_length,
dupe_factor,
short_seq_prob,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
with tf.io.gfile.GFile(input_file, "rb") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def main(_):
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.io.gfile.glob(input_pattern))
logging.info("*** Reading from input files ***")
for input_file in input_files:
logging.info(" %s", input_file)
rng = random.Random(FLAGS.random_seed)
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
for output_file in output_files:
logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("output_file")
flags.mark_flag_as_required("vocab_file")
app.run(main)
@@ -0,0 +1,880 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
# pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import six
from absl import logging
import tensorflow as tf
from official.nlp.bert import tokenization
class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (
tokenization.printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
class FeatureWriter(object):
"""Writes InputFeature to TF example file."""
def __init__(self, filename, is_training):
self.filename = filename
self.is_training = is_training
self.num_features = 0
self._writer = tf.io.TFRecordWriter(filename)
def process_feature(self, feature):
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
self.num_features += 1
def create_int_feature(values):
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values)))
return feature
features = collections.OrderedDict()
features["unique_ids"] = create_int_feature([feature.unique_id])
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position])
impossible = 0
if feature.is_impossible:
impossible = 1
features["is_impossible"] = create_int_feature([impossible])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
self._writer.write(tf_example.SerializeToString())
def close(self):
self._writer.close()
def read_squad_examples(input_file, is_training, version_2_with_negative):
"""Read a SQuAD json file into a list of SquadExample."""
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in paragraph_text:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
end_position = None
orig_answer_text = None
is_impossible = False
if is_training:
if version_2_with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer.")
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
answer_length = len(orig_answer_text)
start_position = char_to_word_offset[answer_offset]
end_position = char_to_word_offset[answer_offset + answer_length -
1]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text = " ".join(
doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logging.warning("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
continue
else:
start_position = -1
end_position = -1
orig_answer_text = ""
example = SquadExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position,
is_impossible=is_impossible)
examples.append(example)
return examples
def convert_examples_to_features(examples,
tokenizer,
max_seq_length,
doc_stride,
max_query_length,
is_training,
output_fn,
batch_size=None):
"""Loads a data file into a list of `InputBatch`s."""
base_id = 1000000000
unique_id = base_id
feature = None
for (example_index, example) in enumerate(examples):
query_tokens = tokenizer.tokenize(example.question_text)
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
tok_start_position = None
tok_end_position = None
if is_training and example.is_impossible:
tok_start_position = -1
tok_end_position = -1
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
else:
tok_end_position = len(all_doc_tokens) - 1
(tok_start_position, tok_end_position) = _improve_answer_span(
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
example.orig_answer_text)
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of the up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
break
start_offset += min(length, doc_stride)
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_to_orig_map = {}
token_is_max_context = {}
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for i in range(doc_span.length):
split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
start_position = None
end_position = None
if is_training and not example.is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
out_of_span = False
if not (tok_start_position >= doc_start and
tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
start_position = 0
end_position = 0
else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 20:
logging.info("*** Example ***")
logging.info("unique_id: %s", (unique_id))
logging.info("example_index: %s", (example_index))
logging.info("doc_span_index: %s", (doc_span_index))
logging.info("tokens: %s",
" ".join([tokenization.printable_text(x) for x in tokens]))
logging.info(
"token_to_orig_map: %s", " ".join([
"%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
]))
logging.info(
"token_is_max_context: %s", " ".join([
"%d:%s" % (x, y)
for (x, y) in six.iteritems(token_is_max_context)
]))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible:
logging.info("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(tokens[start_position:(end_position + 1)])
logging.info("start_position: %d", (start_position))
logging.info("end_position: %d", (end_position))
logging.info("answer: %s", tokenization.printable_text(answer_text))
feature = InputFeatures(
unique_id=unique_id,
example_index=example_index,
doc_span_index=doc_span_index,
tokens=tokens,
token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context,
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
start_position=start_position,
end_position=end_position,
is_impossible=example.is_impossible)
# Run callback
if is_training:
output_fn(feature)
else:
output_fn(feature, is_padding=False)
unique_id += 1
if not is_training and feature:
assert batch_size
num_padding = 0
num_examples = unique_id - base_id
if unique_id % batch_size != 0:
num_padding = batch_size - (num_examples % batch_size)
logging.info("Adding padding examples to make sure no partial batch.")
logging.info("Adds %d padding examples for inference.", num_padding)
dummy_feature = copy.deepcopy(feature)
for _ in range(num_padding):
dummy_feature.unique_id = unique_id
# Run callback
output_fn(feature, is_padding=True)
unique_id += 1
return unique_id - base_id
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
orig_answer_text):
"""Returns tokenized answer spans that better match the annotated answer."""
# The SQuAD annotations are character based. We first project them to
# whitespace-tokenized words. But then after WordPiece tokenization, we can
# often find a "better match". For example:
#
# Question: What year was John Smith born?
# Context: The leader was John Smith (1895-1943).
# Answer: 1895
#
# The original whitespace-tokenized answer will be "(1895-1943).". However
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
# the exact answer, 1895.
#
# However, this is not always possible. Consider the following:
#
# Question: What country is the top exporter of electornics?
# Context: The Japanese electronics industry is the lagest in the world.
# Answer: Japan
#
# In this case, the annotator chose "Japan" as a character sub-span of
# the word "Japanese". Since our WordPiece tokenizer does not split
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
# in SQuAD, but does happen.
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
if text_span == tok_answer_text:
return (new_start, new_end)
return (input_start, input_end)
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word 'bought' will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for 'bought' would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span.start + doc_span.length - 1
if position < doc_span.start:
continue
if position > end:
continue
num_left_context = position - doc_span.start
num_right_context = end - position
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
def write_predictions(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Write final predictions to the json file and log-odds of null if needed."""
logging.info("Writing predictions to: %s", (output_prediction_file))
logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples,
all_features=all_features,
all_results=all_results,
n_best_size=n_best_size,
max_answer_length=max_answer_length,
do_lower_case=do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
write_to_json_files(scores_diff_json, output_null_log_odds_file)
def postprocess_output(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Postprocess model output, to form predicton results."""
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(
tok_text, orig_text, do_lower_case, verbose=verbose)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="", start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
# pytype: enable=attribute-error
all_nbest_json[example.qas_id] = nbest_json
return all_predictions, all_nbest_json, scores_diff_json
def write_to_json_files(json_records, json_file):
with tf.io.gfile.GFile(json_file, "w") as writer:
writer.write(json.dumps(json_records, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heruistic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose:
logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose:
logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose:
logging.info("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose:
logging.info("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
def generate_tf_record_from_json_file(input_file_path,
vocab_file_path,
output_path,
max_seq_length=384,
do_lower_case=True,
max_query_length=64,
doc_stride=128,
version_2_with_negative=False):
"""Generates and saves training data into a tf record file."""
train_examples = read_squad_examples(
input_file=input_file_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file_path, do_lower_case=do_lower_case)
train_writer = FeatureWriter(filename=output_path, is_training=True)
number_of_examples = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=True,
output_fn=train_writer.process_feature)
train_writer.close()
meta_data = {
"task_type": "bert_squad",
"train_data_size": number_of_examples,
"max_seq_length": max_seq_length,
"max_query_length": max_query_length,
"doc_stride": doc_stride,
"version_2_with_negative": version_2_with_negative,
}
return meta_data
@@ -0,0 +1,890 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
The file is forked from:
https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
from absl import logging
import numpy as np
import tensorflow as tf
from official.nlp.bert import tokenization
class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
paragraph_text,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.paragraph_text = paragraph_text
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (
tokenization.printable_text(self.question_text))
s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tok_start_to_orig_index,
tok_end_to_orig_index,
token_is_max_context,
tokens,
input_ids,
input_mask,
segment_ids,
paragraph_len,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tok_start_to_orig_index = tok_start_to_orig_index
self.tok_end_to_orig_index = tok_end_to_orig_index
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.paragraph_len = paragraph_len
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def read_squad_examples(input_file, is_training, version_2_with_negative):
"""Read a SQuAD json file into a list of SquadExample."""
del version_2_with_negative
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
orig_answer_text = None
is_impossible = False
if is_training:
is_impossible = qa.get("is_impossible", False)
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer.")
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
start_position = answer["answer_start"]
else:
start_position = -1
orig_answer_text = ""
example = SquadExample(
qas_id=qas_id,
question_text=question_text,
paragraph_text=paragraph_text,
orig_answer_text=orig_answer_text,
start_position=start_position,
is_impossible=is_impossible)
examples.append(example)
return examples
def _convert_index(index, pos, m=None, is_start=True):
"""Converts index."""
if index[pos] is not None:
return index[pos]
n = len(index)
rear = pos
while rear < n - 1 and index[rear] is None:
rear += 1
front = pos
while front > 0 and index[front] is None:
front -= 1
assert index[front] is not None or index[rear] is not None
if index[front] is None:
if index[rear] >= 1:
if is_start:
return 0
else:
return index[rear] - 1
return index[rear]
if index[rear] is None:
if m is not None and index[front] < m - 1:
if is_start:
return index[front] + 1
else:
return m - 1
return index[front]
if is_start:
if index[rear] > index[front] + 1:
return index[front] + 1
else:
return index[rear]
else:
if index[rear] > index[front] + 1:
return index[rear] - 1
else:
return index[front]
def convert_examples_to_features(examples,
tokenizer,
max_seq_length,
doc_stride,
max_query_length,
is_training,
output_fn,
do_lower_case,
batch_size=None):
"""Loads a data file into a list of `InputBatch`s."""
cnt_pos, cnt_neg = 0, 0
base_id = 1000000000
unique_id = base_id
max_n, max_m = 1024, 1024
f = np.zeros((max_n, max_m), dtype=np.float32)
for (example_index, example) in enumerate(examples):
if example_index % 100 == 0:
logging.info("Converting %d/%d pos %d neg %d", example_index,
len(examples), cnt_pos, cnt_neg)
query_tokens = tokenization.encode_ids(
tokenizer.sp_model,
tokenization.preprocess_text(
example.question_text, lower=do_lower_case))
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
paragraph_text = example.paragraph_text
para_tokens = tokenization.encode_pieces(
tokenizer.sp_model,
tokenization.preprocess_text(
example.paragraph_text, lower=do_lower_case))
chartok_to_tok_index = []
tok_start_to_chartok_index = []
tok_end_to_chartok_index = []
char_cnt = 0
for i, token in enumerate(para_tokens):
new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
chartok_to_tok_index.extend([i] * len(new_token))
tok_start_to_chartok_index.append(char_cnt)
char_cnt += len(new_token)
tok_end_to_chartok_index.append(char_cnt - 1)
tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
" ")
n, m = len(paragraph_text), len(tok_cat_text)
if n > max_n or m > max_m:
max_n = max(n, max_n)
max_m = max(m, max_m)
f = np.zeros((max_n, max_m), dtype=np.float32)
g = {}
# pylint: disable=cell-var-from-loop
def _lcs_match(max_dist, n=n, m=m):
"""Longest-common-substring algorithm."""
f.fill(0)
g.clear()
### longest common sub sequence
# f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
for i in range(n):
# unlike standard LCS, this is specifically optimized for the setting
# because the mismatch between sentence pieces and original text will
# be small
for j in range(i - max_dist, i + max_dist):
if j >= m or j < 0:
continue
if i > 0:
g[(i, j)] = 0
f[i, j] = f[i - 1, j]
if j > 0 and f[i, j - 1] > f[i, j]:
g[(i, j)] = 1
f[i, j] = f[i, j - 1]
f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
if (tokenization.preprocess_text(
paragraph_text[i], lower=do_lower_case,
remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
g[(i, j)] = 2
f[i, j] = f_prev + 1
# pylint: enable=cell-var-from-loop
max_dist = abs(n - m) + 5
for _ in range(2):
_lcs_match(max_dist)
if f[n - 1, m - 1] > 0.8 * n:
break
max_dist *= 2
orig_to_chartok_index = [None] * n
chartok_to_orig_index = [None] * m
i, j = n - 1, m - 1
while i >= 0 and j >= 0:
if (i, j) not in g:
break
if g[(i, j)] == 2:
orig_to_chartok_index[i] = j
chartok_to_orig_index[j] = i
i, j = i - 1, j - 1
elif g[(i, j)] == 1:
j = j - 1
else:
i = i - 1
if (all(v is None for v in orig_to_chartok_index) or
f[n - 1, m - 1] < 0.8 * n):
logging.info("MISMATCH DETECTED!")
continue
tok_start_to_orig_index = []
tok_end_to_orig_index = []
for i in range(len(para_tokens)):
start_chartok_pos = tok_start_to_chartok_index[i]
end_chartok_pos = tok_end_to_chartok_index[i]
start_orig_pos = _convert_index(
chartok_to_orig_index, start_chartok_pos, n, is_start=True)
end_orig_pos = _convert_index(
chartok_to_orig_index, end_chartok_pos, n, is_start=False)
tok_start_to_orig_index.append(start_orig_pos)
tok_end_to_orig_index.append(end_orig_pos)
if not is_training:
tok_start_position = tok_end_position = None
if is_training and example.is_impossible:
tok_start_position = 0
tok_end_position = 0
if is_training and not example.is_impossible:
start_position = example.start_position
end_position = start_position + len(example.orig_answer_text) - 1
start_chartok_pos = _convert_index(
orig_to_chartok_index, start_position, is_start=True)
tok_start_position = chartok_to_tok_index[start_chartok_pos]
end_chartok_pos = _convert_index(
orig_to_chartok_index, end_position, is_start=False)
tok_end_position = chartok_to_tok_index[end_chartok_pos]
assert tok_start_position <= tok_end_position
def _piece_to_id(x):
return tokenizer.sp_model.PieceToId(x)
all_doc_tokens = list(map(_piece_to_id, para_tokens))
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of the up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
break
start_offset += min(length, doc_stride)
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_is_max_context = {}
segment_ids = []
cur_tok_start_to_orig_index = []
cur_tok_end_to_orig_index = []
tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
segment_ids.append(0)
for i in range(doc_span.length):
split_token_index = doc_span.start + i
cur_tok_start_to_orig_index.append(
tok_start_to_orig_index[split_token_index])
cur_tok_end_to_orig_index.append(
tok_end_to_orig_index[split_token_index])
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
segment_ids.append(1)
paragraph_len = len(tokens)
input_ids = tokens
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
span_is_impossible = example.is_impossible
start_position = None
end_position = None
if is_training and not span_is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
out_of_span = False
if not (tok_start_position >= doc_start and
tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
# continue
start_position = 0
end_position = 0
span_is_impossible = True
else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and span_is_impossible:
start_position = 0
end_position = 0
if example_index < 20:
logging.info("*** Example ***")
logging.info("unique_id: %s", (unique_id))
logging.info("example_index: %s", (example_index))
logging.info("doc_span_index: %s", (doc_span_index))
logging.info("tok_start_to_orig_index: %s",
" ".join([str(x) for x in cur_tok_start_to_orig_index]))
logging.info("tok_end_to_orig_index: %s",
" ".join([str(x) for x in cur_tok_end_to_orig_index]))
logging.info(
"token_is_max_context: %s", " ".join(
["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
logging.info(
"input_pieces: %s",
" ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
if is_training and span_is_impossible:
logging.info("impossible example span")
if is_training and not span_is_impossible:
pieces = [
tokenizer.sp_model.IdToPiece(token)
for token in tokens[start_position:(end_position + 1)]
]
answer_text = tokenizer.sp_model.DecodePieces(pieces)
logging.info("start_position: %d", (start_position))
logging.info("end_position: %d", (end_position))
logging.info("answer: %s", (tokenization.printable_text(answer_text)))
# With multi processing, the example_index is actually the index
# within the current process therefore we use example_index=None
# to avoid being used in the future.
# The current code does not use example_index of training data.
if is_training:
feat_example_index = None
else:
feat_example_index = example_index
feature = InputFeatures(
unique_id=unique_id,
example_index=feat_example_index,
doc_span_index=doc_span_index,
tok_start_to_orig_index=cur_tok_start_to_orig_index,
tok_end_to_orig_index=cur_tok_end_to_orig_index,
token_is_max_context=token_is_max_context,
tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
paragraph_len=paragraph_len,
start_position=start_position,
end_position=end_position,
is_impossible=span_is_impossible)
# Run callback
if is_training:
output_fn(feature)
else:
output_fn(feature, is_padding=False)
unique_id += 1
if span_is_impossible:
cnt_neg += 1
else:
cnt_pos += 1
if not is_training and feature:
assert batch_size
num_padding = 0
num_examples = unique_id - base_id
if unique_id % batch_size != 0:
num_padding = batch_size - (num_examples % batch_size)
dummy_feature = copy.deepcopy(feature)
for _ in range(num_padding):
dummy_feature.unique_id = unique_id
# Run callback
output_fn(feature, is_padding=True)
unique_id += 1
logging.info("Total number of instances: %d = pos %d neg %d",
cnt_pos + cnt_neg, cnt_pos, cnt_neg)
return unique_id - base_id
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word 'bought' will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for 'bought' would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span.start + doc_span.length - 1
if position < doc_span.start:
continue
if position > end:
continue
num_left_context = position - doc_span.start
num_right_context = end - position
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
def write_predictions(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Write final predictions to the json file and log-odds of null if needed."""
logging.info("Writing predictions to: %s", (output_prediction_file))
logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples,
all_features=all_features,
all_results=all_results,
n_best_size=n_best_size,
max_answer_length=max_answer_length,
do_lower_case=do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
write_to_json_files(scores_diff_json, output_null_log_odds_file)
def postprocess_output(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Postprocess model output, to form predicton results."""
del do_lower_case, verbose
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
doc_offset = feature.tokens.index("[SEP]") + 1
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
continue
if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
continue
# if start_index not in feature.tok_start_to_orig_index:
# continue
# if end_index not in feature.tok_end_to_orig_index:
# continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index - doc_offset,
end_index=end_index - doc_offset,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=-1,
end_index=-1,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index >= 0: # this is a non-null prediction
tok_start_to_orig_index = feature.tok_start_to_orig_index
tok_end_to_orig_index = feature.tok_end_to_orig_index
start_orig_pos = tok_start_to_orig_index[pred.start_index]
end_orig_pos = tok_end_to_orig_index[pred.end_index]
paragraph_text = example.paragraph_text
final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="", start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
assert best_non_null_entry is not None
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
return all_predictions, all_nbest_json, scores_diff_json
def write_to_json_files(json_records, json_file):
with tf.io.gfile.GFile(json_file, "w") as writer:
writer.write(json.dumps(json_records, indent=4) + "\n")
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
class FeatureWriter(object):
"""Writes InputFeature to TF example file."""
def __init__(self, filename, is_training):
self.filename = filename
self.is_training = is_training
self.num_features = 0
self._writer = tf.io.TFRecordWriter(filename)
def process_feature(self, feature):
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
self.num_features += 1
def create_int_feature(values):
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values)))
return feature
features = collections.OrderedDict()
features["unique_ids"] = create_int_feature([feature.unique_id])
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position])
impossible = 0
if feature.is_impossible:
impossible = 1
features["is_impossible"] = create_int_feature([impossible])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
self._writer.write(tf_example.SerializeToString())
def close(self):
self._writer.close()
def generate_tf_record_from_json_file(input_file_path,
sp_model_file,
output_path,
max_seq_length=384,
do_lower_case=True,
max_query_length=64,
doc_stride=128,
version_2_with_negative=False):
"""Generates and saves training data into a tf record file."""
train_examples = read_squad_examples(
input_file=input_file_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=sp_model_file)
train_writer = FeatureWriter(filename=output_path, is_training=True)
number_of_examples = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=True,
output_fn=train_writer.process_feature,
do_lower_case=do_lower_case)
train_writer.close()
meta_data = {
"task_type": "bert_squad",
"train_data_size": number_of_examples,
"max_seq_length": max_seq_length,
"max_query_length": max_query_length,
"doc_stride": doc_stride,
"version_2_with_negative": version_2_with_negative,
}
return meta_data
@@ -0,0 +1,43 @@
# NLP Modeling Library
This libary provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible.
* [`layers`](layers) are the fundamental building blocks for NLP models.
They can be used to assemble new layers, networks, or models.
* [`networks`](networks) are combinations of layers (and possibly other networks). They are sub-units of models that would not be trained alone. They
encapsulate common network structures like a classification head
or a transformer encoder into an easily handled object with a
standardized configuration.
* [`models`](models) are combinations of layers and networks that would be trained. Pre-built canned models are provided as both convenience functions and canonical examples.
* [`losses`](losses) contains common loss computation used in NLP tasks.
Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with noval achitectures, e.g., you dont need to fork a whole Transformer object to try a different kind of attention primitive, for instance.
* [`TransformerScaffold`](layers/transformer_scaffold.py) implements the
Transformer from ["Attention Is All You Need"]
(https://arxiv.org/abs/1706.03762), with a customizable attention layer
option. Users can pass a class to `attention_cls` and associated config to
`attention_cfg`, in which case the scaffold will instantiate the class with
the config, or pass a class instance to `attention_cls`.
* [`EncoderScaffold`](networks/encoder_scaffold.py) implements the transformer
encoder from ["BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding"](https://arxiv.org/abs/1810.04805), with customizable
embedding subnetwork (which will replace the standard embedding logic) and/or a
custom hidden layer (which will replace the Transformer instantiation in the
encoder).
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
@@ -0,0 +1,29 @@
# Layers
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given initialization
parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between two tensors, from_tensor and to_tensor, as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
If `from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache used
for auto-agressive decoding.
* [Transformer](transformer.py) implements an optionally masked transformer as
described in ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [OnDeviceEmbedding](on_device_embedding.py) implements efficient embedding lookups designed for TPU-based models.
* [PositionalEmbedding](position_embedding.py) creates a positional embedding
as described in ["BERT: Pre-training
of Deep Bidirectional Transformers for Language Understanding"]
(https://arxiv.org/abs/1810.04805).
* [SelfAttentionMask](self_attention_mask.py) creates a 3D attention mask from a 2D tensor mask.
* [MaskedSoftmax](masked_softmax.py) implements a softmax with an optional masking input. If no mask is provided to this layer, it performs a standard softmax; however, if a mask tensor is applied (which should be 1 in positions where the data should be allowed through, and 0 where the data should be masked), the output will have masked positions set to approximately zero.
@@ -0,0 +1,23 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers package definition."""
from official.nlp.modeling.layers.attention import * # pylint: disable=wildcard-import
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.transformer import Transformer
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
@@ -0,0 +1,264 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import math
import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `from_tensor` and `to_tensor` are the same, then
this is self-attention. Each timestep in `from_tensor` attends to the
corresponding sequence in `to_tensor`, and returns a fixed-width vector.
This function first projects `from_tensor` into a "query" tensor and
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
of tensors of length `num_attention_heads`, where each tensor is of shape
[batch_size, seq_length, size_per_head].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor and returned.
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
dropout: Dropout probability.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
"""
def __init__(self,
num_heads,
head_size,
dropout_rate=0.0,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._dropout_rate = dropout_rate
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="value")
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
config = {
"num_heads":
self._num_heads,
"head_size":
self._head_size,
"dropout_rate":
self._dropout_rate,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) == 3 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
**kwargs: Other keyword arguments inherit from `Attention` class.
"""
def __init__(self, num_heads, head_size, **kwargs):
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
if decode_loop_step is not None:
# TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
[1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
[1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices
else:
key_tensor = tf.concat(
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
# Update cache
cache["key"] = key_tensor
cache["value"] = value_tensor
return key_tensor, value_tensor
def call(self, inputs, decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) >= 3 else None
cache = inputs[3] if len(inputs) >= 4 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
@@ -0,0 +1,157 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import attention
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
def test_non_masked_attention(self):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
to_tensor = tf.keras.Input(shape=(20, 80))
output = test_layer([from_tensor, to_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=12, head_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
def test_masked_attention(self):
"""Test with a mask tensor."""
test_layer = attention.MultiHeadAttention(num_heads=2, head_size=2)
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(4, 8))
to_tensor = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([from_tensor, to_tensor, mask_tensor])
# Create a model containing the test layer.
model = tf.keras.Model([from_tensor, to_tensor, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, 2, 8))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(3, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((3, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.MultiHeadAttention(
num_heads=12,
head_size=64,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
from_tensor = tf.keras.Input(shape=(40, 80))
output = test_layer([from_tensor, from_tensor])
self.assertEqual(output.shape.as_list(), [None, 40, 12, 64])
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
"key":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
"value":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class CachedAttentionTest(keras_parameterized.TestCase):
def test_masked_attention(self):
"""Test with a mask tensor."""
num_heads, head_size = 2, 2
# Create a 3-dimensional input (the first dimension is implicit).
from_seq_length = 4
batch_size = 3
# GPU/CPU case.
init_decode_length = 0
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertIsNone(cache)
def test_padded_decode(self):
"""Test with a mask tensor."""
num_heads, head_size = 2, 2
from_seq_length = 4
# TPU decoding should pre-allocate the entire sequence.
batch_size = 3
init_decode_length = from_seq_length
# Directly tests the keras layer.
cache = _create_cache(batch_size, init_decode_length, num_heads, head_size)
layer = attention.CachedAttention(num_heads=num_heads, head_size=head_size)
# Generate data for the input (non-mask) tensors.
from_data = tf.zeros((batch_size, from_seq_length, 8), dtype=np.float32)
decode_loop_step = 2
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 2, 2))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,180 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
@tf.keras.utils.register_keras_serializable(package="Text")
class DenseEinsum(tf.keras.layers.Layer):
"""A densely connected layer that uses tf.einsum as the backing computation.
This layer can perform einsum calculations of arbitrary dimensionality.
Arguments:
output_shape: Positive integer or tuple, dimensionality of the output space.
num_summed_dimensions: The number of dimensions to sum over. Standard 2D
matmul should use 1, 3D matmul should use 2, and so forth.
activation: Activation function to use. If you don't specify anything, no
activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation")..
kernel_constraint: Constraint function applied to the `kernel` weights
matrix.
bias_constraint: Constraint function applied to the bias vector.
Input shape:
N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
situation would be a 2D input with shape `(batch_size, input_dim)`.
Output shape:
N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
input with shape `(batch_size, input_dim)`, the output would have shape
`(batch_size, units)`.
"""
def __init__(self,
output_shape,
num_summed_dimensions=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(DenseEinsum, self).__init__(**kwargs)
self._output_shape = output_shape if isinstance(
output_shape, (list, tuple)) else (output_shape,)
self._activation = tf.keras.activations.get(activation)
self._use_bias = use_bias
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._num_summed_dimensions = num_summed_dimensions
self._einsum_string = None
def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
input_str = ""
kernel_str = ""
output_str = ""
letter_offset = 0
for i in range(free_input_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_input_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
return input_str + "," + kernel_str + "->" + output_str
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
input_rank = input_shape.rank
free_input_dims = input_rank - self._num_summed_dimensions
output_dims = len(self._output_shape)
self._einsum_string = self._build_einsum_string(free_input_dims,
self._num_summed_dimensions,
output_dims)
# This is only saved for testing purposes.
self._kernel_shape = (
input_shape[free_input_dims:].concatenate(self._output_shape))
self._kernel = self.add_weight(
"kernel",
shape=self._kernel_shape,
initializer=self._kernel_initializer,
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
trainable=True)
if self._use_bias:
self._bias = self.add_weight(
"bias",
shape=self._output_shape,
initializer=self._bias_initializer,
regularizer=self._bias_regularizer,
constraint=self._bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self._bias = None
super(DenseEinsum, self).build(input_shape)
def get_config(self):
config = {
"output_shape":
self._output_shape,
"num_summed_dimensions":
self._num_summed_dimensions,
"activation":
tf.keras.activations.serialize(self._activation),
"use_bias":
self._use_bias,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(DenseEinsum, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
ret = tf.einsum(self._einsum_string, inputs, self._kernel)
if self._use_bias:
ret += self._bias
if self._activation is not None:
ret = self._activation(ret)
return ret
@@ -0,0 +1,123 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based einsum layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import dense_einsum
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DenseEinsumLayer(keras_parameterized.TestCase):
def test_3D_einsum_with_two_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,), num_summed_dimensions=2)
# Create a 4-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 40, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
def test_3D_einsum_with_one_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64, 32), num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cde->abde")
self.assertEqual(test_layer._kernel_shape, (80, 64, 32))
def test_2D_einsum_with_one_bound_dimensions(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,), num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
self.assertEqual(test_layer._kernel_shape, (80, 64))
def test_bias_term_can_be_disabled(self):
# A layer created using the bias should have two weights.
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, use_bias=True)
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(2, len(test_layer.get_weights()))
# A layer created without the bias should have only one weight.
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, use_bias=False)
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(1, len(test_layer.get_weights()))
def test_activation(self):
# Create a model that does not use an activation.
no_activation_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, activation=None)
input_tensor = tf.keras.Input(shape=(None, 80))
output_tensor = no_activation_layer(input_tensor)
no_activation_model = tf.keras.Model(input_tensor, output_tensor)
# Create a model that uses a softmax activation.
activation_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1, activation="softmax")
input_tensor = tf.keras.Input(shape=(None, 80))
output_tensor = activation_layer(input_tensor)
activation_model = tf.keras.Model(input_tensor, output_tensor)
# Make sure the models' weights are identical.
activation_model.set_weights(no_activation_model.get_weights())
# Predict using each model on the same input data. The output should be
# different, since one is using a softmax - even though the models' weights
# are the same.
input_values = 10 * np.random.random_sample((10, 4, 80))
non_activated_data = no_activation_model.predict(input_values)
activated_data = activation_model.predict(input_values)
self.assertNotAllClose(activated_data, non_activated_data)
def test_non_iterable_output_shape(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=64, num_summed_dimensions=1)
# Create a 3-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abc,cd->abd")
self.assertEqual(test_layer._kernel_shape, (80, 64))
def test_with_explicit_initializer(self):
test_layer = dense_einsum.DenseEinsum(
output_shape=(64,),
num_summed_dimensions=2,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 4-dimensional input (the first dimension is implicit).
input_tensor = tf.keras.Input(shape=(None, 40, 80))
_ = test_layer(input_tensor)
self.assertEqual(test_layer._einsum_string, "abcd,cde->abe")
self.assertEqual(test_layer._kernel_shape, (40, 80, 64))
if __name__ == "__main__":
tf.test.main()
@@ -0,0 +1,61 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based softmax layer with optional masking."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedSoftmax(tf.keras.layers.Layer):
"""Performs a softmax with optional masking on a tensor.
Arguments:
mask_expansion_axes: Any axes that should be padded on the mask tensor.
"""
def __init__(self, mask_expansion_axes=None, **kwargs):
self._mask_expansion_axes = mask_expansion_axes
super(MaskedSoftmax, self).__init__(**kwargs)
def call(self, inputs):
if isinstance(inputs, list) and len(inputs) == 2:
scores, mask = inputs
else:
scores, mask = (inputs, None)
if mask is not None:
if self._mask_expansion_axes is not None:
mask = tf.expand_dims(mask, axis=self._mask_expansion_axes)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
scores += adder
return tf.nn.softmax(scores)
def get_config(self):
config = {'mask_expansion_axes': self._mask_expansion_axes}
base_config = super(MaskedSoftmax, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -0,0 +1,88 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based masked softmax layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import masked_softmax
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class MaskedSoftmaxLayerTest(keras_parameterized.TestCase):
def test_non_masked_softmax(self):
test_layer = masked_softmax.MaskedSoftmax()
input_tensor = tf.keras.Input(shape=(4, 8))
output = test_layer(input_tensor)
model = tf.keras.Model(input_tensor, output)
input_data = 10 * np.random.random_sample((3, 4, 8))
output_data = model.predict(input_data)
expected_data = tf.nn.softmax(input_data)
self.assertAllClose(expected_data, output_data)
def test_masked_softmax(self):
test_layer = masked_softmax.MaskedSoftmax()
input_tensor = tf.keras.Input(shape=(4, 8))
mask_tensor = tf.keras.Input(shape=(4, 8))
output = test_layer([input_tensor, mask_tensor])
model = tf.keras.Model([input_tensor, mask_tensor], output)
input_data = 10 * np.random.random_sample((3, 4, 8))
mask_data = np.random.randint(2, size=(3, 4, 8))
output_data = model.predict([input_data, mask_data])
expected_zeros = np.greater(mask_data, 0)
is_zeros = np.greater(output_data, 0)
self.assertAllEqual(expected_zeros, is_zeros)
def test_masked_softmax_with_none_mask(self):
test_layer = masked_softmax.MaskedSoftmax()
input_tensor = tf.keras.Input(shape=(4, 8))
output = test_layer([input_tensor, None])
model = tf.keras.Model(input_tensor, output)
input_data = 10 * np.random.random_sample((3, 4, 8))
output_data = model.predict(input_data)
expected_data = tf.nn.softmax(input_data)
self.assertAllClose(expected_data, output_data)
def test_softmax_with_axes_expansion(self):
test_layer = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
input_tensor = tf.keras.Input(shape=(4, 8))
mask_tensor = tf.keras.Input(shape=(8))
output = test_layer([input_tensor, mask_tensor])
model = tf.keras.Model([input_tensor, mask_tensor], output)
input_data = 10 * np.random.random_sample((3, 4, 8))
mask_data = np.random.randint(2, size=(3, 8))
output_data = model.predict([input_data, mask_data])
expanded_mask = np.expand_dims(mask_data, axis=1) * np.ones_like(input_data)
expected_zeros = np.greater(expanded_mask, 0)
is_zeros = np.greater(output_data, 0)
self.assertAllEqual(expected_zeros, is_zeros)
if __name__ == '__main__':
tf.test.main()

Some files were not shown because too many files have changed in this diff Show More