apply formatting with black (#1038)
* applying black (with the default 88-column limit) * add flake8 * add isort * fix isort
This commit is contained in:
+19
-9
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache, cached_property
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -138,7 +138,9 @@ class Tokenizer:
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
def decode(
|
||||
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
|
||||
):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
@@ -154,8 +156,9 @@ class Tokenizer:
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
return "".join(
|
||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
@@ -197,7 +200,7 @@ class Tokenizer:
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
@@ -242,8 +245,10 @@ class Tokenizer:
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
@@ -255,7 +260,10 @@ class Tokenizer:
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
for tokens in [
|
||||
self.tokenizer.encode(symbol),
|
||||
self.tokenizer.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
@@ -367,4 +375,6 @@ def get_tokenizer(
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
||||
return Tokenizer(
|
||||
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user