apply formatting with black (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
This commit is contained in:
Jong Wook Kim
2023-03-06 18:50:37 -05:00
committed by GitHub
parent 500d0fe966
commit b80bcf610d
21 changed files with 533 additions and 227 deletions
+19 -9
View File
@@ -1,7 +1,7 @@
import os
import string
from dataclasses import dataclass
from functools import lru_cache, cached_property
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -138,7 +138,9 @@ class Tokenizer:
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
def decode(
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
@@ -154,8 +156,9 @@ class Tokenizer:
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property
def eot(self) -> int:
@@ -197,7 +200,7 @@ class Tokenizer:
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
@@ -242,8 +245,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -255,7 +260,10 @@ class Tokenizer:
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
for tokens in [
self.tokenizer.encode(symbol),
self.tokenizer.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
@@ -367,4 +375,6 @@ def get_tokenizer(
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)