Add Beam Search Decoding
ogrisel committed Mar 5, 2017
1 parent 20a4e31 commit 4584baa
Showing 1 changed file with 175 additions and 10 deletions.
185 changes: 175 additions & 10 deletions labs/06_seq2seq/Translation_of_Numeric_Phrases_with_Seq2Seq.ipynb
@@ -9,7 +9,13 @@
"source": [
"# Translation of Numeric Phrases with Seq2Seq\n",
"\n",
"In the following we will try to build a translation model from french phrases describing numbers to the corresponding digital representation (base 10).\n",
"In the following we will try to build a **translation model from french phrases describing numbers** to the corresponding **numeric representation** (base 10).\n",
"\n",
"This is a toy machine translation with a **restricted vocabulary** and a **single valid translation for each source phrase** which makes it more tractable to train on a laptop computer and easier to evaluate. Despite those limitations we expect that this task highlight interested properties of Seq2Seq models including\n",
"\n",
"- the ability to deal with source and target sequences with different lengths,\n",
"- token with a meaning that changes depending on the context (e.g \"quatre\" vs \"quatre vingts\" in \"quatre cents\"),\n",
"- basic counting and \"reasoning\" capabilities of LSTM and GRU models.\n",
"\n",
"The parallel text data is generated from a \"ground-truth\" Python function named `to_french_phrase` that captures common rules from the French language except hypenation to make the French strings more ambiguous:"
]
@@ -40,9 +46,9 @@
"source": [
"## Generating a Training Set\n",
"\n",
"The following will generate phrases 20000 example phrases for numbers between 1 and 1,000,000 (excluded). It will over-represent small numbers by generating all the possible short sequences between 1 and `exhaustive`.\n",
"The following will **generate phrases 20000 example phrases for numbers between 1 and 1,000,000** (excluded). We chose to over-represent small numbers by generating all the possible short sequences between 1 and `exhaustive`.\n",
"\n",
"Let's split the generated set into non-overlapping train, validation and test splits."
"We then split the generated set into non-overlapping train, validation and test splits."
]
},
{
@@ -1004,7 +1010,7 @@
"source": [
"def greedy_translate(model, source_sequence, shared_vocab, rev_shared_vocab,\n",
" word_level_source=True, word_level_target=True):\n",
" \"\"\"Greedy decoder recurisvely predicting one token at a time\"\"\"\n",
" \"\"\"Greedy decoder recursively predicting one token at a time\"\"\"\n",
" # Initialize the list of input token ids with the source sequence\n",
" source_tokens = tokenize(source_sequence, word_level=word_level_source)\n",
" input_ids = [shared_vocab.get(t, UNK) for t in source_tokens[::-1]]\n",
@@ -1120,16 +1126,17 @@
},
"outputs": [],
"source": [
"def phrase_accuracy(model, num_sequences, fr_sequences, n_samples=300):\n",
"def phrase_accuracy(model, num_sequences, fr_sequences, n_samples=300,\n",
" decoder_func=greedy_translate):\n",
" correct = []\n",
" n_samples = len(num_sequences) if n_samples is None else n_samples\n",
" for i, num_seq, fr_seq in zip(range(n_samples), num_sequences, fr_sequences):\n",
" if i % 100 == 0:\n",
" print(\"Decoding %d/%d\" % (i, n_samples))\n",
"\n",
" predicted_seq = greedy_translate(simple_seq2seq, fr_seq,\n",
" shared_vocab, rev_shared_vocab,\n",
" word_level_target=False)\n",
" predicted_seq = decoder_func(simple_seq2seq, fr_seq,\n",
" shared_vocab, rev_shared_vocab,\n",
" word_level_target=False)\n",
" correct.append(num_seq == predicted_seq)\n",
" return np.mean(correct)"
]
@@ -1144,7 +1151,7 @@
},
"outputs": [],
"source": [
"print(\"Phrase-level test accuracy: %0.2f\"\n",
"print(\"Phrase-level test accuracy: %0.3f\"\n",
" % phrase_accuracy(simple_seq2seq, num_test, fr_test))"
]
},
@@ -1158,10 +1165,21 @@
},
"outputs": [],
"source": [
"print(\"Phrase-level train accuracy: %0.2f\"\n",
"print(\"Phrase-level train accuracy: %0.3f\"\n",
" % phrase_accuracy(simple_seq2seq, num_train, fr_train))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decoding with a Beam Search\n",
"\n",
"Instead of decoding with greedy strategy that only considers the most-likely next token at each prediction, we can hold a priority queue of the most promising top-n sequences ordered by loglikelihoods.\n",
"\n",
"This could potentially improve the final accuracy of an imperfect model: indeed it can be the case that the most likely sequence (based on the conditional proability estimated by the model) starts with a character that is not the most likely alone."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -1171,6 +1189,153 @@
"editable": true
},
"outputs": [],
"source": [
"def beam_translate(model, source_sequence, shared_vocab, rev_shared_vocab,\n",
" word_level_source=True, word_level_target=True,\n",
" beam_size=10, return_ll=False):\n",
" \"\"\"Decode candidate translations with a beam search strategy\n",
" \n",
" If return_ll is False, only the best candidate string is returned.\n",
" If return_ll is True, all the candidate strings and their loglikelihoods\n",
" are returned.\n",
" \"\"\"\n",
" # Initialize the list of input token ids with the source sequence\n",
" source_tokens = tokenize(source_sequence, word_level=word_level_source)\n",
" input_ids = [shared_vocab.get(t, UNK) for t in source_tokens[::-1]]\n",
" input_ids += [shared_vocab[GO]]\n",
" \n",
" # initialize loglikelihood, input token ids, decoded tokens for\n",
" # each candidate in the beam\n",
" candidates = [(0, input_ids[:], [], False)]\n",
"\n",
" # Prepare a fixed size numpy array that matches the expected input\n",
" # shape for the model\n",
" input_array = np.empty(shape=(beam_size, simple_seq2seq.input_shape[1]),\n",
" dtype=np.int32)\n",
" while any([not done and (len(input_ids) < max_length)\n",
" for _, input_ids, _, done in candidates]):\n",
" # Vectorize a the list of input tokens and use zeros padding.\n",
" input_array.fill(shared_vocab[PAD])\n",
" for i, (_, input_ids, _, done) in enumerate(candidates):\n",
" if not done:\n",
" input_array[i, -len(input_ids):] = input_ids\n",
" \n",
" # Predict the next output in a single call to the model to amortize\n",
" # the overhead and benefit from vector data parallelism on GPU.\n",
" next_likelihood_batch = model.predict(input_array)\n",
" \n",
" # Build the new candidates list by summing the loglikelood of the\n",
" # next token with their parents for each new possible expansion.\n",
" new_candidates = []\n",
" for i, (ll, input_ids, decoded, done) in enumerate(candidates):\n",
" if done:\n",
" new_candidates.append((ll, input_ids, decoded, done))\n",
" else:\n",
" next_loglikelihoods = np.log(next_likelihood_batch[i, -1])\n",
" for next_token_id, next_ll in enumerate(next_loglikelihoods):\n",
" new_ll = ll + next_ll\n",
" new_input_ids = input_ids[:]\n",
" new_input_ids.append(next_token_id)\n",
" new_decoded = decoded[:]\n",
" new_done = done\n",
" if next_token_id == shared_vocab[EOS]:\n",
" new_done = True\n",
" if not new_done:\n",
" new_decoded.append(rev_shared_vocab[next_token_id])\n",
" new_candidates.append(\n",
" (new_ll, new_input_ids, new_decoded, new_done))\n",
" \n",
" # Only keep a beam of the most promising candidates\n",
" new_candidates.sort(reverse=True)\n",
" candidates = new_candidates[:beam_size]\n",
"\n",
" separator = \" \" if word_level_target else \"\"\n",
" if return_ll:\n",
" return [(separator.join(decoded), ll) for ll, _, decoded, _ in candidates]\n",
" else:\n",
" _, _, decoded, done = candidates[0]\n",
" return separator.join(decoded)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"candidates = beam_translate(simple_seq2seq, \"cent mille un\",\n",
" shared_vocab, rev_shared_vocab,\n",
" word_level_target=False,\n",
" return_ll=True, beam_size=10)\n",
"candidates"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"candidates = beam_translate(simple_seq2seq, \"quatre vingts\",\n",
" shared_vocab, rev_shared_vocab,\n",
" word_level_target=False,\n",
" return_ll=True, beam_size=10)\n",
"candidates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Accuracy with Beam Search Decoding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(\"Phrase-level test accuracy: %0.3f\"\n",
" % phrase_accuracy(simple_seq2seq, num_test, fr_test,\n",
" decoder_func=beam_translate))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(\"Phrase-level train accuracy: %0.3f\"\n",
" % phrase_accuracy(simple_seq2seq, num_train, fr_train,\n",
" decoder_func=beam_translate))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When using the partially trained model the test phrase-level is slightly better (0.38 vs 0.37) with the beam decoder than with greedy decoder but the change is not that important on such a toy dataset. Training the model to covergence would yield perfect scores anyway.\n",
"\n",
"Properly tuned beam search decoding can be critical to improve the quality of Machine Translation systems trained on natural language pairs though."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],