WIP: Feature/mlm repl #915

Open · wants to merge 4 commits into master

Changes from all commits
22 changes: 20 additions & 2 deletions layers/eight_mile/pytorch/serialize.py
@@ -91,6 +91,7 @@
    'bert.embeddings.LayerNorm.gamma': 'embeddings.reduction.ln.weight',
    'bert.embeddings.LayerNorm.bias': 'embeddings.reduction.ln.bias',
    'bert.embeddings.LayerNorm.weight': 'embeddings.reduction.ln.weight',
    'bert.cls.predictions.bias': 'output_layer.bias',
}

ROBERTA_HF_LAYER_MAP = {
@@ -162,12 +163,16 @@ def convert_transformers_keys(num_layers: int, d: Dict, nested_layer_map: Dict =
            try:
                m[v.format(i)] = d[k.format(i)]
            except KeyError:
                # Old-style checkpoints name LayerNorm params gamma/beta while newer
                # ones use weight/bias, so one of the two mapped variants is always
                # missing; that miss is expected and not worth reporting.
                if 'LayerNorm.weight' not in k and 'LayerNorm.bias' not in k:
                    print(f"Bad key. Skipping {k.format(i)}")
    for k, v in flat_map.items():
        try:
            m[v] = d[k]
        except KeyError:
            # Same gamma/beta vs. weight/bias naming split as above: one mapped
            # variant is always absent, which is expected rather than an error.
            if 'LayerNorm.weight' not in k and 'LayerNorm.bias' not in k:
                print(f"Bad key. Skipping {k}")

    return m
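To see why the LayerNorm misses are benign, here is a minimal sketch with a toy state dict (hypothetical values): both naming variants map to the same target key, so whichever variant the checkpoint lacks simply fails the lookup while the other resolves it.

# Minimal sketch with a toy state dict: old-style checkpoints store gamma/beta,
# newer ones weight/bias; both variants map to the same target key, so exactly
# one lookup succeeds and the other misses harmlessly.
d = {'bert.embeddings.LayerNorm.gamma': 1.0}   # old-style naming only
flat_map = {
    'bert.embeddings.LayerNorm.gamma': 'embeddings.reduction.ln.weight',
    'bert.embeddings.LayerNorm.weight': 'embeddings.reduction.ln.weight',
}
m = {}
for k, v in flat_map.items():
    try:
        m[v] = d[k]
    except KeyError:
        pass  # the weight/bias variant is absent here; gamma already resolved it
assert m == {'embeddings.reduction.ln.weight': 1.0}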

@@ -453,6 +458,8 @@ def to_tlm_array(pytorch_tlm: nn.Module, embeddings_keys: List[str] = None, name

    if hasattr(pytorch_tlm.embeddings.reduction, 'ln'):
        d.update(to_weight_array(pytorch_tlm.embeddings.reduction.ln, name=f"{name}/Embeddings/reduction/ln"))

    return d


@@ -467,6 +474,12 @@ def save_tlm_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] =
    :return: None
    """
    d = to_tlm_array(pytorch_tlm, embeddings_keys, name)
    # Serialize the output bias here rather than in to_tlm_array: other callers
    # of to_tlm_array need the output layer handled differently.
    if hasattr(pytorch_tlm, 'output_layer') and getattr(pytorch_tlm.output_layer, 'bias', None) is not None:
        bias = pytorch_tlm.output_layer.bias.cpu().detach().numpy()
        d.update({f"{name}/output/bias": bias})

    if verbose:
        print(d.keys())
    np.savez(npz, **d)
@@ -774,6 +787,11 @@ def load_tlm_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] =
    d = np.load(npz)
    from_tlm_array(pytorch_tlm, d, embeddings_keys, name)

    if hasattr(pytorch_tlm, 'output_layer') and getattr(pytorch_tlm.output_layer, 'bias', None) is not None:
        device = pytorch_tlm.output_layer.bias.device
        pytorch_tlm.output_layer.bias = nn.Parameter(
            torch.from_numpy(d[f"{name}/output/bias"]).to(device=device),
            requires_grad=True,
        )


def load_tlm_output_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] = None, name: str = "TLM"):
    """Restore a TLM-like model (possibly a `nn.Module` for fine-tuning
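As a quick sanity check on the new bias plumbing, a minimal round-trip sketch (assuming `model` is a TransformerMaskedLanguageModel built with output_bias=True; the path is hypothetical):

# Round-trip sketch; `model` is assumed to have been created with output_bias=True
# so that model.output_layer.bias exists as an nn.Parameter.
save_tlm_npz(model, '/tmp/mlm.npz')    # now also writes the 'TLM/output/bias' array
d = np.load('/tmp/mlm.npz')
assert 'TLM/output/bias' in d.files
load_tlm_npz(model, '/tmp/mlm.npz')    # restores the bias onto the model's device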
2 changes: 2 additions & 0 deletions mead/api_examples/convert_hf2npz.py
@@ -106,6 +106,8 @@ def create_transformer_lm(config_url: str) -> Tuple[TransformerMaskedLanguageMod
        embeddings_dropout=pdrop,
        dropout=pdrop,
        activation=activation,
        output_bias=True,
        layer_norm_eps=layer_norm_eps,
        layer_norms_after=True,
        embeddings_reduction='sum-layer-norm')
    return model, num_layers
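An npz converted with these settings must be loaded with matching flags on the generate_mlm side, since its defaults differ (output_bias=False, layer_norms_after=False, embeddings_reduction='sum'). A hypothetical invocation (checkpoint and vocab paths are placeholders; BERT configs typically use layer_norm_eps=1e-12):

python mead/api_examples/generate_mlm.py --checkpoint bert-base.npz \
    --subword_type wordpiece --subword_vocab_file vocab.txt \
    --output_bias true --layer_norms_after true \
    --embeddings_reduction sum-layer-norm --layer_norm_eps 1e-12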
83 changes: 64 additions & 19 deletions mead/api_examples/generate_mlm.py
@@ -9,7 +9,7 @@
from eight_mile.pytorch.serialize import tlm_load_state_dict, load_tlm_npz
from baseline.pytorch.lm import TransformerMaskedLanguageModel
from eight_mile.utils import str2bool, read_json, Offsets, revlut
from baseline.vectorizers import Token1DVectorizer, BPEVectorizer1D, WordpieceVectorizer1D
from baseline.pytorch.embeddings import *
from mead.api_examples.transformer_utils import find_latest_checkpoint
logger = logging.getLogger(__file__)
@@ -52,7 +52,8 @@ def decode_sentence(model, vectorizer, query, word2index, index2word, device, sa
    return words


def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name,
                 activation, layer_norm_eps, layer_norms_after, embeddings_reduction, output_bias):
    rpr_k = listify(rpr_k)

    if len(rpr_k) == 0 or rpr_k[0] < 1:
@@ -63,6 +64,7 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
    logger.info("Creating tied encoder decoder model")
    model = TransformerMaskedLanguageModel.create(
        {'x': embeddings},
        hsz=d_model,
        embeddings_reduction=embeddings_reduction,
        d_ff=d_ff,
        tie_weights=True,
        dropout=0,
@@ -72,7 +74,10 @@
        rpr_k=rpr_k,
        rpr_value_on=rpr_value_on,
        d_k=d_k,
        layer_norm_eps=layer_norm_eps,
        layer_norms_after=layer_norms_after,
        activation=activation,
        output_bias=output_bias,
        src_keys=['x'], tgt_key='x')
    if checkpoint_name.endswith('npz'):
        load_tlm_npz(model, checkpoint_name)
@@ -85,12 +90,22 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
    return model


def get_subword_vec1d(subword_type):
    """Select a subword vectorizer class by name, defaulting to SentencePiece."""
    if subword_type == 'bpe':
        return BPEVectorizer1D
    elif subword_type == 'wordpiece':
        return WordpieceVectorizer1D
    else:
        from baseline.vectorizers import SentencePieceVectorizer1D
        return SentencePieceVectorizer1D
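For instance, mirroring the call in main() below with hypothetical file names:

# Usage sketch (hypothetical vocab path). With --subword_type wordpiece and no
# --subword_model_file, main() below passes model_file=None in exactly this way.
Vec1D = get_subword_vec1d('wordpiece')
vectorizer = Vec1D(model_file=None, vocab_file='vocab.txt', mxlen=128,
                   emit_begin_tok='[CLS]', emit_end_tok='[SEP]')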


def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
    parser.add_argument("--query", type=str, help='Single query to decode; if omitted, start an interactive REPL')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
@@ -100,12 +115,17 @@ def main():
parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
parser.add_argument("--embed_type", type=str, default='default',
help="register label of the embeddings, so far support positional or learned-positional")
parser.add_argument("--subword_model_file", type=str, required=True)
parser.add_argument("--subword_vocab_file", type=str, required=True)
parser.add_argument("--use_cls", type=str2bool, default=False)
help="register label of the embeddings")
parser.add_argument("--subword_model_file", type=str, required=False)
parser.add_argument("--subword_vocab_file", type=str, required=False)
parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
parser.add_argument("--rpr_value_on", type=str2bool, default=False)
parser.add_argument('--end_token', default='<EOU>')
parser.add_argument('--begin_token', default='[CLS]')
parser.add_argument('--output_bias', default=False, type=str2bool)
parser.add_argument('--embeddings_reduction', default='sum')
parser.add_argument("--layer_norms_after", type=str2bool, default=False, help="Layer norms after (set True for BERT)")
parser.add_argument('--layer_norm_eps', default=1e-6, type=float)
parser.add_argument("--activation", type=str, default='gelu')
parser.add_argument('--rpr_k', help='Relative attention positional sizes pass 0 if you dont want relative attention',
type=int, default=[8], nargs='+')
@@ -117,33 +137,58 @@
help="Device (cuda or cpu)")
args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    Vec1D = get_subword_vec1d(args.subword_type)
    vectorizer = Vec1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                       mxlen=args.nctx, emit_begin_tok=args.begin_token, emit_end_tok=args.end_token,
                       extra_tokens=args.extra_tokens)

    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab, embed_type=args.embed_type, preserve_vocab_indices=True)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads, num_layers=args.num_layers,
                         rpr_k=args.rpr_k, rpr_value_on=args.rpr_value_on,
                         d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation,
                         layer_norm_eps=args.layer_norm_eps, layer_norms_after=args.layer_norms_after,
                         embeddings_reduction=args.embeddings_reduction, output_bias=args.output_bias)
    model.to(args.device)


    index2word = revlut(vocab)

    if args.query:
        print('[Query]', args.query)
        bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
        print('[Response]', ' '.join(bpe_out))
        return

    from prompt_toolkit import prompt
    from prompt_toolkit.history import FileHistory

    prompt_name = '->> '
    history_file = '.history'
    history = FileHistory(history_file)
    while True:
        query = prompt(prompt_name, history=history)
        query = query.strip()
        if query == 'quit':
            break
        if query == ':sample':
            args.sample = True
            print("Sampling mode on")
            continue
        if query == ':max':
            args.sample = False
            print("Sampling mode off")
            continue
        print('[Query]', query)
        bpe_out = decode_sentence(model, vectorizer, query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
        print('[Response]', ' '.join(bpe_out))
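A REPL session then looks like the following (hypothetical input; model output elided):

->> :sample
Sampling mode on
->> hello , how are you today ?
[Query] hello , how are you today ?
[Response] ...
->> quit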

if __name__ == '__main__':
    main()