This repository has been archived by the owner on Feb 25, 2022. It is now read-only.

Commit

Merge pull request #230 from nostalgebraist/tfrecords-prepend-fix
Fix trailing token bug in create_tfrecords
StellaAthena committed Aug 28, 2021
2 parents afe5e69 + b12f6b2 commit d741ddf
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions data/create_tfrecords.py
@@ -106,7 +106,7 @@ def split_list(l, n):
return [l[i:i + n] for i in range(0, len(l), n)]
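As an aside on the helper just above: split_list cuts a token list into pieces of length n, and the final piece can be shorter than n. That short trailing piece is exactly what this commit stops throwing away. A minimal standalone illustration (not part of the repository's code):

def split_list(l, n):
    # Split l into consecutive chunks of length n; the last chunk may be shorter.
    return [l[i:i + n] for i in range(0, len(l), n)]

print(split_list(list(range(10)), 4))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- the trailing [8, 9] is shorter than n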


def archive_to_tokens(f, encoder, args):
def archive_to_tokens(f, encoder, args, prefix=[]):
# Generator that yields the contents of the files in an archive
# if data_to_prepend is not None, prepend data_to_prepend + an EOS separator to the encoded data
reader = Reader(f)
@@ -116,7 +116,8 @@ def archive_to_tokens(f, encoder, args):
if args.wikitext_detokenize:
doc = wikitext_detokenizer(doc)
doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token
yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks
yield split_list(prefix + doc, args.chunk_size) # split into n_ctx + 1 size chunks
prefix = []


def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
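To see what the new prefix argument does: tokens carried over from the previous archive are prepended to the first document only, after which prefix is reset so later documents are unaffected; because the generator rebinds prefix rather than mutating it, the mutable default prefix=[] is harmless here. A rough standalone sketch of that behaviour, with a hypothetical docs list of pre-encoded token ids standing in for the Reader/encoder pipeline:

def split_list(l, n):
    return [l[i:i + n] for i in range(0, len(l), n)]

def archive_to_tokens_sketch(docs, separator, chunk_size, prefix=[]):
    # Stand-in for the patched generator: docs are already lists of token ids here.
    for doc in docs:
        tokens = doc + separator                       # append the separator/EOS token(s)
        yield split_list(prefix + tokens, chunk_size)  # carried-over tokens go in front
        prefix = []                                    # only the first document gets them

# Leftover tokens [98, 99] from a previous archive land in front of the first document only.
for chunks in archive_to_tokens_sketch([[1, 2, 3], [4, 5]], [0], 4, prefix=[98, 99]):
    print(chunks)
# [[98, 99, 1, 2], [3, 0]]
# [[4, 5, 0]]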
@@ -189,24 +190,21 @@ def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_c
tokenized_files_array = []

for f in files:
for tokenized_files in archive_to_tokens(f, enc, args):
for tokenized_files in archive_to_tokens(f, enc, args, prefix=data_to_prepend):
files_processed += 1
if files_processed < resume_files_processed:
continue # resume from checkpoint

# if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
data_to_prepend = []
n_tokens = len(tokenized_files[-1])
if n_tokens < args.chunk_size:
data = tokenized_files.pop(-1)
if n_tokens >= args.minimum_size:
data_to_prepend.extend(data)
data_to_prepend = data
else:
discarded_files += 1

if len(data_to_prepend) >= args.chunk_size:
# if length of data_to_prepend becomes greater than chunk size, add concatted files to tokenized files
tokenized_files_array.append(data_to_prepend[:args.chunk_size])
data_to_prepend = data_to_prepend[args.chunk_size:]
# add tokenized files > chunk size to main array
tokenized_files_array.extend(tokenized_files)
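To make the fix concrete: in the old code the short trailing chunk was stashed in data_to_prepend but reset on the next pass without ever being handed back to the tokenizer, so those tokens were silently dropped; now the caller passes the remainder to archive_to_tokens as prefix, and it surfaces at the front of the next archive's first chunk. A rough end-to-end sketch of the patched flow, reusing archive_to_tokens_sketch and split_list from the sketch above (toy archives and sizes, not the repository's real I/O):

archives = [[[1, 2, 3, 4, 5]], [[6, 7, 8]]]   # two toy archives of pre-encoded documents
chunk_size, out, data_to_prepend = 4, [], []

for docs in archives:
    for tokenized_files in archive_to_tokens_sketch(docs, [0], chunk_size, prefix=data_to_prepend):
        data_to_prepend = []                       # already consumed as prefix above
        if len(tokenized_files[-1]) < chunk_size:  # short trailing chunk: carry it over
            data_to_prepend = tokenized_files.pop(-1)
        out.extend(tokenized_files)
        # (the minimum_size and checkpoint-resume checks from the real loop are omitted)

print(out)              # [[1, 2, 3, 4], [5, 0, 6, 7]]
print(data_to_prepend)  # [8, 0] -- final remainder, handled by write_remainder in the real script
# Before this commit, [5, 0] would have been discarded instead of starting the second archive's first chunk.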

