This repository has been archived by the owner on Feb 16, 2022. It is now read-only.
/
text_encoder.py
760 lines (609 loc) · 25.2 KB
/
text_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoders for text data.
* TextEncoder: base class
* ByteTextEncoder: for ascii text
* TokenTextEncoder: with user-supplied vocabulary file
* SubwordTextEncoder: invertible
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from itertools import chain
import re
# Dependency imports
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tokenizer
import tensorflow as tf
# Reserved tokens for things like padding and EOS symbols. They occupy the
# first ids of encoders that reserve ids.
PAD = "<pad>"
EOS = "<EOS>"
RESERVED_TOKENS = [PAD, EOS]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1

# Byte-string spellings of RESERVED_TOKENS. On Python 2 `str` is already a
# byte string, so the very same list object is reused; on Python 3 each token
# is ASCII-encoded.
RESERVED_TOKENS_BYTES = (
    RESERVED_TOKENS if six.PY2
    else [tok.encode("ascii") for tok in RESERVED_TOKENS])
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_ESCAPE_CHARS = set(u"\\_u;0123456789")
# Conversion between Unicode and UTF-8, if required (on Python 2).
if six.PY2:

    def native_to_unicode(s):
        """Returns `s` as unicode, UTF-8 decoding it if it is not already."""
        if isinstance(s, unicode):  # pylint: disable=undefined-variable
            return s
        return s.decode("utf8")

    def unicode_to_native(s):
        """Returns the UTF-8 encoded native `str` for unicode string `s`."""
        return s.encode("utf-8")

else:
    # No conversion required on Python >= 3: native strings are unicode.

    def native_to_unicode(s):
        """Identity on Python 3."""
        return s

    def unicode_to_native(s):
        """Identity on Python 3."""
        return s
class TextEncoder(object):
    """Base class for converting between human-readable strings and int ids.

    The base implementation treats its input as whitespace-separated integers
    and shifts each one up by `num_reserved_ids`, leaving ids in
    [0, num_reserved_ids) free for the reserved tokens.
    """

    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        """Number of leading ids set aside for reserved tokens."""
        return self._num_reserved_ids

    def encode(self, s):
        """Transforms a human-readable string into a sequence of int ids.

        The ids should be in the range [num_reserved_ids, vocab_size). Ids in
        [0, num_reserved_ids) are reserved. EOS is not appended.

        Args:
            s: human-readable string to be converted.

        Returns:
            ids: list of integers
        """
        offset = self._num_reserved_ids
        return [int(piece) + offset for piece in s.split()]

    def decode(self, ids):
        """Transforms a sequence of int ids into a human-readable string.

        EOS is not expected in ids.

        Args:
            ids: list of integers to be converted.

        Returns:
            s: human-readable string (the decode_list entries joined by spaces).
        """
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Transforms a sequence of int ids into their string versions.

        This supports turning individual input/output ids into strings so that
        sequence to/from text conversions can be visualized in a human readable
        format. Reserved ids become the matching RESERVED_TOKENS entry; every
        other id becomes the decimal string of `id - num_reserved_ids`.

        Args:
            ids: list of integers to be converted.

        Returns:
            strs: list of human-readable strings.
        """
        offset = self._num_reserved_ids
        return [
            str(RESERVED_TOKENS[int(id_)]) if 0 <= id_ < offset
            else str(id_ - offset)
            for id_ in ids
        ]

    @property
    def vocab_size(self):
        """Size of the id space; must be provided by subclasses."""
        raise NotImplementedError()
class ByteTextEncoder(TextEncoder):
    """Encodes each byte to an id. For 8-bit strings only.

    Unicode text is UTF-8 encoded first, so a non-ASCII character maps to
    several ids. Byte values are shifted up by `num_reserved_ids`.
    """

    def encode(self, s):
        """Converts a string into a list of byte ids.

        Args:
            s: a native or unicode string; unicode input is UTF-8 encoded.

        Returns:
            A list of ints in [num_reserved_ids, num_reserved_ids + 256).
        """
        numres = self._num_reserved_ids
        if six.PY2:
            if isinstance(s, unicode):  # pylint: disable=undefined-variable
                s = s.encode("utf-8")
            return [ord(c) + numres for c in s]
        # Python3: explicitly convert to UTF-8; iterating bytes yields ints.
        return [c + numres for c in s.encode("utf-8")]

    def decode(self, ids):
        """Converts a sequence of ids back into a string.

        Args:
            ids: iterable of ints, typically produced by `encode`.

        Returns:
            A native string. On Python 3 the joined bytes are decoded as UTF-8
            with invalid sequences replaced (errors="replace").
        """
        byte_strings = self._ids_to_bytes(ids)
        if six.PY2:
            return "".join(byte_strings)
        # Python3: join byte strings and then decode.
        return b"".join(byte_strings).decode("utf-8", "replace")

    def decode_list(self, ids):
        """Converts a sequence of ids into a list of byte strings, one per id.

        Unlike `decode`, nothing is joined or UTF-8 decoded, so a multi-byte
        character is split across several entries.
        """
        return self._ids_to_bytes(ids)

    def _ids_to_bytes(self, ids):
        """Maps each id to its reserved-token bytes or to a single raw byte."""
        numres = self._num_reserved_ids
        int2byte = six.int2byte
        decoded = []
        for id_ in ids:
            if 0 <= id_ < numres:
                decoded.append(RESERVED_TOKENS_BYTES[int(id_)])
            else:
                decoded.append(int2byte(id_ - numres))
        return decoded

    @property
    def vocab_size(self):
        """256 byte values plus the reserved ids."""
        return 2**8 + self._num_reserved_ids
class ClassLabelEncoder(TextEncoder):
    """Encoder for class labels.

    A label string is encoded as its index in the label list; no ids are
    reserved.
    """

    def __init__(self, class_labels=None, class_labels_fname=None):
        """Initializes from a list of labels or a file with one label per line.

        Args:
            class_labels: list of label strings. Mutually exclusive with
                `class_labels_fname`.
            class_labels_fname: path of a file with one label per line; each
                line is stripped of surrounding whitespace.

        Raises:
            ValueError: if neither or both of the arguments are provided.
        """
        super(ClassLabelEncoder, self).__init__(num_reserved_ids=0)
        # Explicit checks rather than `assert`, which `python -O` strips and
        # would let an encoder with no labels be constructed silently.
        if not (class_labels or class_labels_fname):
            raise ValueError(
                "One of class_labels or class_labels_fname must be provided.")
        if class_labels and class_labels_fname:
            raise ValueError(
                "Only one of class_labels or class_labels_fname may be "
                "provided.")
        if class_labels_fname:
            with tf.gfile.Open(class_labels_fname) as f:
                class_labels = [label.strip() for label in f.readlines()]
        self._class_labels = class_labels

    def encode(self, label_str):
        """Returns the index of `label_str` (ValueError if it is unknown)."""
        return self._class_labels.index(label_str)

    def decode(self, label_id):
        """Returns the label string for `label_id`.

        Args:
            label_id: an int index, or a list/tuple holding exactly one index.

        Raises:
            ValueError: if `label_id` is a list/tuple whose length is not 1.
        """
        if isinstance(label_id, (list, tuple)):
            if len(label_id) != 1:
                raise ValueError(
                    "Expected exactly one label id, got %d." % len(label_id))
            label_id, = label_id
        return self._class_labels[label_id]

    @property
    def vocab_size(self):
        """Number of class labels."""
        return len(self._class_labels)
class TokenTextEncoder(TextEncoder):
    """Encoder based on a user-supplied vocabulary (file or list)."""

    def __init__(self,
                 vocab_filename,
                 reverse=False,
                 vocab_list=None,
                 replace_oov=None,
                 num_reserved_ids=NUM_RESERVED_TOKENS):
        """Initializes from a file or a list, one token per entry.

        Reserved tokens are handled as follows:
          - When initializing from a list, reserved tokens are prepended to the
            vocab (and any copies inside the list are dropped).
          - When initializing from a file, reserved tokens are not added.
          - When saving a vocab file, reserved tokens are written to it.

        Args:
            vocab_filename: If truthy, the full filename to read the vocab from;
                vocab_list is then expected to be None.
            reverse: Whether tokens are reversed during encoding and decoding.
            vocab_list: Used when vocab_filename is falsy: the list of
                vocabulary elements (must not be None).
            replace_oov: If not None, every out-of-vocabulary token seen while
                encoding is replaced by this string (which must be in vocab).
            num_reserved_ids: Number of IDs to save for reserved tokens like
                <EOS>.
        """
        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if not vocab_filename:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        else:
            self._init_vocab_from_file(vocab_filename)

    def encode(self, sentence):
        """Converts a space-separated string of tokens to a list of ids.

        A token missing from the vocab raises KeyError unless replace_oov is
        set (in which case the replacement itself must be in the vocab).
        """
        tokens = sentence.strip().split()
        lookup = self._token_to_id
        fallback = self._replace_oov
        if fallback is not None:
            tokens = [tok if tok in lookup else fallback for tok in tokens]
        ids = [lookup[tok] for tok in tokens]
        if self._reverse:
            ids.reverse()
        return ids

    def decode(self, ids):
        """Converts ids to a single space-separated string of tokens."""
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Converts ids to a list of tokens; unknown ids become 'ID_<n>'."""
        ordered = reversed(ids) if self._reverse else ids
        return list(map(self._safe_id_to_token, ordered))

    @property
    def vocab_size(self):
        """Number of entries in the id -> token map."""
        return len(self._id_to_token)

    def _safe_id_to_token(self, idx):
        """Looks up `idx`, falling back to the placeholder 'ID_<idx>'."""
        # The placeholder is formatted before the lookup, as in a plain
        # dict.get(idx, default) call.
        placeholder = "ID_%d" % idx
        return self._id_to_token.get(idx, placeholder)

    def _init_vocab_from_file(self, filename):
        """Loads the vocab from a file, one token per line.

        Each line is stripped of surrounding whitespace. Reserved tokens are not
        added, so the first line receives id 0.

        Args:
            filename: The file to load vocabulary from.
        """
        def stripped_lines():
            with tf.gfile.Open(filename) as vocab_file:
                for raw_line in vocab_file:
                    yield raw_line.strip()

        self._init_vocab(stripped_lines(), add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initializes tokens from a list of tokens.

        It is ok if reserved tokens appear in the vocab list; they are skipped
        here because _init_vocab prepends them. The set of tokens in vocab_list
        should be unique.

        Args:
            vocab_list: A list of tokens.
        """
        def without_reserved():
            for candidate in vocab_list:
                if candidate in RESERVED_TOKENS:
                    continue
                yield candidate

        self._init_vocab(without_reserved())

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Builds the id <-> token maps from the tokens in token_generator.

        When add_reserved_tokens is True, RESERVED_TOKENS take the first ids
        and the generated tokens are numbered after them.
        """
        self._id_to_token = {}
        first_id = 0
        if add_reserved_tokens:
            for reserved_id, reserved_token in enumerate(RESERVED_TOKENS):
                self._id_to_token[reserved_id] = reserved_token
            first_id = len(RESERVED_TOKENS)
        for token_id, token in enumerate(token_generator, start=first_id):
            self._id_to_token[token_id] = token
        # _token_to_id is the inverse of _id_to_token; if a token repeats, only
        # one of its ids survives.
        self._token_to_id = {
            token: token_id
            for token_id, token in six.iteritems(self._id_to_token)
        }

    def store_to_file(self, filename):
        """Writes the vocab file to disk.

        The file has one token per line, in id order starting at 0, and ends in
        a newline. Reserved tokens are written to the vocab file as well.

        Args:
            filename: Full path of the file to store the vocab to.
        """
        with tf.gfile.Open(filename, "w") as out:
            for token_id in xrange(len(self._id_to_token)):
                out.write(self._id_to_token[token_id] + "\n")
def _escape_token(token, alphabet):
    """Escapes away underscores and OOV characters and appends '_'.

    This allows the token to be expressed as the concatenation of a list of
    subtokens from the vocabulary. The underscore acts as a sentinel which
    allows us to invertibly concatenate multiple such lists.

    Escaping rules:
      - each backslash is doubled, then each underscore becomes backslash-u;
      - any resulting character not in `alphabet`, and any newline, becomes a
        backslash, its decimal code point, and a semicolon.

    Args:
        token: A unicode string to be escaped.
        alphabet: A set of all characters in the vocabulary's alphabet.

    Returns:
        escaped_token: An escaped unicode string ending in '_'.

    Raises:
        ValueError: If the provided token is not unicode.
    """
    if not isinstance(token, six.text_type):
        raise ValueError("Expected string type for token, got %s" % type(token))

    # Backslashes are doubled before underscores are rewritten, so the
    # backslash introduced for '_' is not doubled in turn.
    escaped = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")

    pieces = []
    for char in escaped:
        if char in alphabet and char != u"\n":
            pieces.append(char)
        else:
            pieces.append(r"\%d;" % ord(char))
    return u"".join(pieces) + "_"
def _unescape_token(escaped_token):
    """Inverse of _escape_token().

    A single trailing '_' (the token terminator) is dropped, then escape
    sequences are decoded. A numeric escape whose value is not a valid code
    point decodes to the empty string.

    Args:
        escaped_token: a unicode string

    Returns:
        token: a unicode string
    """

    def _decode_escape(found):
        code_point = found.group(1)
        if code_point is not None:
            try:
                return six.unichr(int(code_point))
            except (ValueError, OverflowError):
                return ""
        # Non-numeric escapes: '\u' stands for '_', '\\' for a backslash.
        return u"_" if found.group(0) == u"\\u" else u"\\"

    if escaped_token.endswith("_"):
        escaped_token = escaped_token[:-1]
    return _UNESCAPE_REGEX.sub(_decode_escape, escaped_token)
class SubwordTextEncoder(TextEncoder):
    """Class for invertibly encoding text using a limited vocabulary.

    Invertibly encodes a native string as a sequence of subtokens from a limited
    vocabulary.

    A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
    the corpus), and stored to a file. See text_encoder_build_subword.py.

    It can then be loaded and used to encode/decode any text.

    Encoding has four phases:

    1. Tokenize into a list of tokens. Each token is a unicode string of either
       all alphanumeric characters or all non-alphanumeric characters. We drop
       tokens consisting of a single space that are between two alphanumeric
       tokens.

    2. Escape each token. This escapes away special and out-of-vocabulary
       characters, and makes sure that each token ends with an underscore, and
       has no other underscores.

    3. Represent each escaped token as the concatenation of a list of subtokens
       from the limited vocabulary. Subtoken selection is done greedily from
       beginning to end. That is, we construct the list in order, always picking
       the longest subtoken in our vocabulary that matches a prefix of the
       remaining portion of the encoded token.

    4. Concatenate these lists. This concatenation is invertible due to the
       fact that the trailing underscores indicate when one list is finished.
    """

    def __init__(self, filename=None):
        """Initializes and reads from a file, if provided.

        Args:
            filename: filename from which to read vocab. If None, do not load a
                vocab.
        """
        self._alphabet = set()
        if filename is not None:
            self._load_from_file(filename)
        # num_reserved_ids is None: reserved tokens, when present, are stored
        # inside the subtoken list itself (see _init_subtokens_from_list).
        super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)

    def encode(self, raw_text):
        """Converts a native string to a list of subtoken ids.

        Args:
            raw_text: a native string.

        Returns:
            a list of integers in the range [0, vocab_size)
        """
        return self._tokens_to_subtoken_ids(
            tokenizer.encode(native_to_unicode(raw_text)))

    def decode(self, subtokens):
        """Converts a sequence of subtoken ids to a native string.

        Args:
            subtokens: a list of integers in the range [0, vocab_size)

        Returns:
            a native string
        """
        return unicode_to_native(
            tokenizer.decode(self._subtoken_ids_to_tokens(subtokens)))

    def decode_list(self, subtokens):
        """Maps each subtoken id to its subtoken string (u'' if out of range)."""
        return [self._subtoken_id_to_subtoken_string(s) for s in subtokens]

    @property
    def vocab_size(self):
        """The subtoken vocabulary size."""
        return len(self._all_subtoken_strings)

    def _tokens_to_subtoken_ids(self, tokens):
        """Converts a list of tokens to a list of subtoken ids.

        Args:
            tokens: a list of strings.

        Returns:
            a list of integers in the range [0, vocab_size)
        """
        ret = []
        for token in tokens:
            ret.extend(
                self._escaped_token_to_subtoken_ids(
                    _escape_token(token, self._alphabet)))
        return ret

    def _subtoken_ids_to_tokens(self, subtokens):
        """Converts a list of subtoken ids to a list of tokens.

        Args:
            subtokens: a list of integers in the range [0, vocab_size)

        Returns:
            a list of strings.
        """
        concatenated = "".join(
            [self._subtoken_id_to_subtoken_string(s) for s in subtokens])
        # Every escaped token ends with '_' and contains no other '_', so
        # splitting on '_' recovers the escaped tokens.
        split = concatenated.split("_")
        return [_unescape_token(t + "_") for t in split if t]

    def _subtoken_id_to_subtoken_string(self, subtoken):
        """Converts a subtoken integer ID to a subtoken string (u'' if invalid)."""
        if 0 <= subtoken < self.vocab_size:
            return self._all_subtoken_strings[subtoken]
        return u""

    def _escaped_token_to_subtoken_strings(self, escaped_token):
        """Converts an escaped token string to a list of subtoken strings.

        Args:
            escaped_token: An escaped token as a unicode string.

        Returns:
            A list of subtokens as unicode strings.

        Raises:
            AssertionError: if no vocabulary subtoken matches at some position.
        """
        # NOTE: This algorithm is greedy; it won't necessarily produce the "best"
        # list of subtokens.
        ret = []
        start = 0
        token_len = len(escaped_token)
        while start < token_len:
            for end in xrange(
                    min(token_len, start + self._max_subtoken_len), start, -1):
                subtoken = escaped_token[start:end]
                if subtoken in self._subtoken_string_to_id:
                    ret.append(subtoken)
                    start = end
                    break
            else:  # Did not break
                # If there is no possible encoding of the escaped token then one
                # of the characters in the token is not in the alphabet. This
                # should be impossible and would be indicative of a bug.
                # Raised explicitly instead of `assert False`: under `python -O`
                # the assert is stripped, `start` never advances, and this loop
                # would spin forever.
                raise AssertionError(
                    "Token substring not found in subtoken vocabulary.")
        return ret

    def _escaped_token_to_subtoken_ids(self, escaped_token):
        """Converts an escaped token string to a list of subtoken IDs.

        Args:
            escaped_token: An escaped token as a unicode string.

        Returns:
            A list of subtoken IDs as integers.
        """
        return [
            self._subtoken_string_to_id[subtoken]
            for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)
        ]

    @classmethod
    def build_to_target_size(cls,
                             target_size,
                             token_counts,
                             min_val,
                             max_val,
                             num_iterations=4):
        """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`.

        Uses simple recursive binary search to find a minimum token count that
        most closely matches the `target_size`.

        Args:
            target_size: Desired vocab_size to approximate.
            token_counts: A dictionary of token counts, mapping string to int.
            min_val: An integer; lower bound for the minimum token count.
            max_val: An integer; upper bound for the minimum token count.
            num_iterations: An integer; how many iterations of refinement.

        Returns:
            A SubwordTextEncoder instance.

        Raises:
            ValueError: If `min_val` is greater than `max_val`, or if
                `target_size` is less than 1.
        """
        if min_val > max_val:
            raise ValueError("Lower bound for the minimum token count "
                             "is greater than the upper bound.")
        if target_size < 1:
            raise ValueError("Target size must be positive.")

        def bisect(min_val, max_val):
            """Bisection to find the right size."""
            present_count = (max_val + min_val) // 2
            tf.logging.info("Trying min_count %d", present_count)
            subtokenizer = cls()
            subtokenizer.build_from_token_counts(token_counts, present_count,
                                                 num_iterations)
            # Report progress via logging rather than a stray debug print to
            # stdout.
            tf.logging.info("vocab_size = %d with min_count = %d",
                            subtokenizer.vocab_size, present_count)

            # Being within 1% of the target size is ok.
            is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size
            # If min_val == max_val, we can't do any better than this.
            if is_ok or min_val >= max_val or present_count < 2:
                return subtokenizer

            # A larger min_count drops more subtokens, shrinking the vocab.
            if subtokenizer.vocab_size > target_size:
                other_subtokenizer = bisect(present_count + 1, max_val)
            else:
                other_subtokenizer = bisect(min_val, present_count - 1)

            if other_subtokenizer is None:
                return subtokenizer

            if (abs(other_subtokenizer.vocab_size - target_size) <
                    abs(subtokenizer.vocab_size - target_size)):
                return other_subtokenizer
            return subtokenizer

        return bisect(min_val, max_val)

    def build_from_token_counts(self,
                                token_counts,
                                min_count,
                                num_iterations=4,
                                num_reserved_ids=NUM_RESERVED_TOKENS):
        """Trains a SubwordTextEncoder based on a dictionary of word counts.

        Args:
            token_counts: a dictionary of Unicode strings to int.
            min_count: an integer - discard subtokens with lower counts. Values
                below 1 are treated as 1.
            num_iterations: an integer. how many iterations of refinement.
            num_reserved_ids: an integer. how many ids to reserve for special
                tokens (0 or NUM_RESERVED_TOKENS).

        Raises:
            ValueError: if num_reserved_ids is not 0 or len(RESERVED_TOKENS). In
                this case, it is not clear what the space is being reserved for,
                or when it will be filled in.
        """
        # Initialize the alphabet. Note, this must include reserved tokens or it
        # can result in encoding failures.
        if num_reserved_ids == NUM_RESERVED_TOKENS:
            alphabet_tokens = chain(
                six.iterkeys(token_counts),
                [native_to_unicode(t) for t in RESERVED_TOKENS])
        elif num_reserved_ids == 0:
            alphabet_tokens = six.iterkeys(token_counts)
        else:
            raise ValueError(
                "Unexpected value for reserved. What is being reserved?")

        self._init_alphabet_from_tokens(alphabet_tokens)

        # Bootstrap the initial list of subtokens with the characters from the
        # alphabet plus the escaping characters.
        self._init_subtokens_from_list(
            list(self._alphabet), reserved=num_reserved_ids)

        # We build iteratively. On each iteration, we segment all the words,
        # then count the resulting potential subtokens, keeping the ones
        # with high enough counts for our new vocabulary.
        if min_count < 1:
            min_count = 1
        for i in xrange(num_iterations):
            tf.logging.info("Iteration %d", i)

            # Collect all substrings of the encoded token that break along
            # current subtoken boundaries.
            subtoken_counts = collections.defaultdict(int)
            for token, count in six.iteritems(token_counts):
                escaped_token = _escape_token(token, self._alphabet)
                subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
                start = 0
                for subtoken in subtokens:
                    for end in xrange(start + 1, len(escaped_token) + 1):
                        new_subtoken = escaped_token[start:end]
                        subtoken_counts[new_subtoken] += count
                    start += len(subtoken)

            # Array of sets of candidate subtoken strings, by length.
            len_to_subtoken_strings = []
            for subtoken_string, count in six.iteritems(subtoken_counts):
                lsub = len(subtoken_string)
                if count >= min_count:
                    while len(len_to_subtoken_strings) <= lsub:
                        len_to_subtoken_strings.append(set())
                    len_to_subtoken_strings[lsub].add(subtoken_string)

            # Consider the candidates longest to shortest, so that if we accept
            # a longer subtoken string, we can decrement the counts of its
            # prefixes.
            new_subtoken_strings = []
            for lsub in xrange(len(len_to_subtoken_strings) - 1, 0, -1):
                subtoken_strings = len_to_subtoken_strings[lsub]
                for subtoken_string in subtoken_strings:
                    count = subtoken_counts[subtoken_string]
                    if count >= min_count:
                        # Exclude alphabet tokens here, as they must be included
                        # later, explicitly, regardless of count.
                        if subtoken_string not in self._alphabet:
                            new_subtoken_strings.append((count, subtoken_string))
                        for l in xrange(1, lsub):
                            subtoken_counts[subtoken_string[:l]] -= count

            # Include the alphabet explicitly to guarantee all strings are
            # encodable.
            new_subtoken_strings.extend((subtoken_counts.get(a, 0), a)
                                        for a in self._alphabet)
            new_subtoken_strings.sort(reverse=True)

            # Reinitialize to the candidate vocabulary.
            self._init_subtokens_from_list(
                [subtoken for _, subtoken in new_subtoken_strings],
                reserved=num_reserved_ids)
            tf.logging.info("vocab_size = %d", self.vocab_size)

    def dump(self):
        """Debugging dump of the current subtoken vocabulary."""
        subtoken_strings = [
            (i, s) for s, i in six.iteritems(self._subtoken_string_to_id)
        ]
        print(u", ".join(u"{0} : '{1}'".format(i, s)
                         for i, s in sorted(subtoken_strings)))

    def _init_subtokens_from_list(self, subtoken_strings, reserved=0):
        """Initializes token information from a list of subtoken strings.

        Args:
            subtoken_strings: a list of subtokens
            reserved: number of spaces to save at the beginning for reserved
                tokens

        Raises:
            ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this
                case, it is not clear what the space is being reserved for, or
                when it will be filled in.
        """
        if reserved == 0:
            self._all_subtoken_strings = subtoken_strings
        elif reserved == len(RESERVED_TOKENS):
            self._all_subtoken_strings = RESERVED_TOKENS + subtoken_strings
        else:
            # TODO(dtarlow): or should we fall back to the previous behavior and
            # insert copies of "" for each reserved count?
            raise ValueError(
                "Unexpected value for reserved. What is being reserved?")

        # We remember the maximum length of any subtoken to avoid having to
        # check arbitrarily long strings. An empty list (e.g. an empty vocab
        # file) yields 0 instead of an opaque error from max().
        self._max_subtoken_len = max([len(s) for s in subtoken_strings] or [0])

        # Reserved tokens are not added to this map; ids are offset past them.
        self._subtoken_string_to_id = {
            s: i + reserved
            for i, s in enumerate(subtoken_strings) if s
        }

    def _init_alphabet_from_tokens(self, tokens):
        """Initializes alphabet from an iterable of token or subtoken strings."""
        # Include all characters from all tokens in the alphabet to guarantee
        # that any token can be encoded. Additionally, include all escaping
        # characters.
        self._alphabet = {c for token in tokens for c in token}
        self._alphabet |= _ESCAPE_CHARS

    def _load_from_file_object(self, f):
        """Loads from a file object, one subtoken per line.

        Args:
            f: File object to load vocabulary from
        """
        subtoken_strings = []
        for line in f:
            s = line.strip()
            # Some vocab files wrap words in single quotes, but others don't
            if ((s.startswith("'") and s.endswith("'")) or
                    (s.startswith("\"") and s.endswith("\""))):
                s = s[1:-1]
            subtoken_strings.append(native_to_unicode(s))
        self._init_subtokens_from_list(subtoken_strings)
        self._init_alphabet_from_tokens(subtoken_strings)

    def _load_from_file(self, filename):
        """Loads from a file.

        Args:
            filename: Filename to load vocabulary from
        """
        with tf.gfile.Open(filename) as f:
            self._load_from_file_object(f)

    def store_to_file(self, filename):
        """Writes every subtoken, wrapped in single quotes, one per line."""
        with tf.gfile.Open(filename, "w") as f:
            for subtoken_string in self._all_subtoken_strings:
                f.write("'" + unicode_to_native(subtoken_string) + "'\n")