Skip to content

Latest commit

 

History

History
481 lines (313 loc) · 36.1 KB

File metadata and controls

481 lines (313 loc) · 36.1 KB

本文涉及的jupter notebook在篇章4代码库中

建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。 如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。

! pip install datasets transformers rouge-score nltk

分布式训练请查看 这里.

微调transformer模型解决摘要生成任务

在本notebook中,我们将展示如何微调 🤗 Transformers中的预训练模型来解决摘要生成任务。我们使用XSum dataset数据集。这个数据集包含了BBC的文章和一句对应的摘要。下面是一个例子:

Widget inference on a summarization task

对于摘要生成任务,我们将展示如何使用简单的加载数据集,同时针对相应的仍无使用transformer中的Trainer接口对模型进行微调。

model_checkpoint = "t5-small"

只要预训练的transformer模型包含seq2seq结构的head层,那么本notebook理论上可以使用各种各样的transformer模型,解决任何摘要生成任务。这里,我们使用t5-small模型checkpoint。

加载数据

我们将会使用🤗 Datasets库来加载数据和对应的评测方式。数据加载和评测方式加载只需要简单使用load_dataset和load_metric即可。

from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

这个datasets对象本身是一种DatasetDict数据结构. 对于训练集、验证集和测试集,只需要使用对应的key(train,validation,test)即可得到相应的数据。

raw_datasets
DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

给定一个数据切分的key(train、validation或者test)和下标即可查看数据:

raw_datasets["train"][0]
{'document': 'Recent reports have linked some France-based players with returns to Wales.\n"I\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\'d rather spend that money on keeping players in Wales," said Davies.\nThe WRU provides £2m to the fund and £1.3m comes from the regions.\nFormer Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections.\nHe is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September.\nDavies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year.\nIn the wake of that deal being done, Davies said the £3.3m should be spent on ensuring current Wales-based stars remain there.\nIn recent weeks, Racing Metro flanker Dan Lydiate was linked with returning to Wales.\nLikewise the Paris club\'s scrum-half Mike Phillips and centre Jamie Roberts were also touted for possible returns.\nWales coach Warren Gatland has said: "We haven\'t instigated contact with the players.\n"But we are aware that one or two of them are keen to return to Wales sooner rather than later."\nSpeaking to Scrum V on BBC Radio Wales, Davies re-iterated his stance, saying keeping players such as Scarlets full-back Liam Williams and Ospreys flanker Justin Tipuric in Wales should take precedence.\n"It\'s obviously a limited amount of money [available]. The union are contributing 60% of that contract and the regions are putting £1.3m in.\n"So it\'s a total pot of just over £3m and if you look at the sorts of salaries that the... guys... have been tempted to go overseas for [are] significant amounts of money.\n"So if we were to bring the players back, we\'d probably get five or six players.\n"And I\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\'d rather spend that money on keeping players in Wales.\n"There are players coming out of contract, perhaps in the next year or so… you\'re looking at your Liam Williams\' of the world; Justin Tipuric for example - we need to keep these guys in Wales.\n"We actually want them there. They are the ones who are going to impress the young kids, for example.\n"They are the sort of heroes that our young kids want to emulate.\n"So I would start off [by saying] with the limited pot of money, we have to retain players in Wales.\n"Now, if that can be done and there\'s some spare monies available at the end, yes, let\'s look to bring players back.\n"But it\'s a cruel world, isn\'t it?\n"It\'s fine to take the buck and go, but great if you can get them back as well, provided there\'s enough money."\nBritish and Irish Lions centre Roberts has insisted he will see out his Racing Metro contract.\nHe and Phillips also earlier dismissed the idea of leaving Paris.\nRoberts also admitted being hurt by comments in French Newspaper L\'Equipe attributed to Racing Coach Laurent Labit questioning their effectiveness.\nCentre Roberts and flanker Lydiate joined Racing ahead of the 2013-14 season while scrum-half Phillips moved there in December 2013 after being dismissed for disciplinary reasons by former club Bayonne.',
 'id': '29750031',
 'summary': 'New Welsh Rugby Union chairman Gareth Davies believes a joint £3.3m WRU-regions fund should be used to retain home-based talent such as Liam Williams, not bring back exiled stars.'}

为了能够进一步理解数据长什么样子,下面的函数将从数据集里随机选择几个例子进行展示。

import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=2):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(raw_datasets["train"])
document id summary
0 Media playback is not supported on this device\nThe 31-year-old Portugal captain has the chance to win the first international title of his glittering career when his side take on the tournament hosts at Stade de France.\n"Everyone in the country will be against him, but he will thrive off that hostility, and off their fear," Ferdinand told BBC Sport.\n"The French know Cristiano is the player capable of destroying their dream because he has produced magic moments in huge matches right through his career.\n"Part of the reason he is a superstar is because he is not fazed by the big occasion - quite the opposite in fact.\n"Superstars like him relish these situations - the pressure that goes with it brings the best out of him, when other players falter."\nRonaldo helped Portugal reach the final when they hosted Euro 2004, but was left crying on the pitch after defeat by Greece, and they have not reached this stage of a major tournament since.\n"He was only 19 then, so just a kid," added Ferdinand, who played alongside Ronaldo at Old Trafford between 2003 and 2009 and is working in France as a BBC pundit.\n"I remember his reaction but I think he was a bit too young to take it all in. At that age, you expect you will get back to another final soon to rectify what happened but he has had to wait 12 years for his chance.\n"I think he is very aware this is his last opportunity to win something with his country and, knowing him like I do, that makes him even more dangerous. He will be so desperate not to miss out again.\n"Cristiano has produced great performances for Portugal when it matters before, for example his hat-trick against Sweden in the play-off for the 2014 World Cup.\n"So France will know that it is not just in a Real Madrid or Manchester United shirt that he is capable of great things, especially because it was his moment of brilliance that helped decide Portugal's semi-final against Wales."\nRonaldo is now the joint highest scorer in European Championship history with nine goals, and holds the record for most headed goals with five.\nTwo of them have come in France, most notably when he soared to nod Portugal in front against Wales, but Ferdinand says that part of Ronaldo's game is nothing new.\n"He has always been amazing on the ball but even when he first joined United in 2003 he was great in the air too," he said.\n"Early in his career it was a part of his game that was quite undervalued but he always scored a lot of headers and, the way he does it, he is the closest thing in football to basketball legend Michael Jordan.\n"The way he jumps and hangs in the air is the same as Jordan and he has got the ability to stay up there, assess the situation and then put the ball where he wants to, with power.\n"I will always remember the header he scored for United away at Roma in the Champions League quarter-finals in April 2008.\n"He more or less jumped on the edge of the box to meet a cross that Paul Scholes put over but he met the ball a good way inside the area. If you watch TV footage of that game, he just appears from nowhere and smashes it into the bottom corner.\n"Just like his goal against Wales, it was an unbelievable jump and he generated incredible power. I was on the pitch that night, and it was amazing to see.\n"Cristiano's heading ability will be a huge threat in the final too, along with Nani - another former United team-mate of mine.\n"France have struggled to defend crosses for most of the tournament and, although they were better at it in the semi-final, Germany did not have anyone to aim at in the box."\nEighty-six players at Euro 2016 have completed more dribbles than the three Ronaldo has managed in six matches.\n"He was always well known for his brilliant runs forward but his game is not about that any more," said Ferdinand, who was in the same United team as Ronaldo when he scored 42 goals in the 2007-08 season.\n"Before, he used to exert a lot of energy trying to take people on from deep areas, running at goal from 30 or 40 yards, or even further out.\n"Now, he is very clever in where he tries to receive the ball. It has to be in good positions and, when he gets it, he finds a yard of space and hits it - either a shot or a cross.\n"Part of the reason he has been able to reinvent himself is because of how hard he works - right from the start of his career, when we were together at Old Trafford, he was totally committed to improving every part of his game.\n"But to be able to re-evaluate his game and change it is also down to his football intelligence.\n"Clearly he is clever - you do not score 50 goals a season, six seasons running, for Real Madrid if you are not.\n"But his extra intelligence has allowed him to evolve as a player, understand his body, where it can take him and how often.\n"He has become a much more efficient player, but is still an extremely effective one."\nSunday offers Ronaldo the opportunity for personal glory too, with the chance to get one over Lionel Messi in the battle to be viewed as the best player in the world.\nMessi has never won a major tournament for Argentina and announced his international retirement last month after they lost to Chile in the final of the Copa America.\n"If Ronaldo wins the European Championship, it will be massive for him," said Ferdinand, 37.\n"I don't think it will give him the edge over Messi in terms of who deserves the accolade of the best in the world, but it is a huge achievement.\n"And it will matter to both of them, because there is a definite battle between them in their own minds about who has done what for club and country.\n"It is far from a given that Ronaldo will manage it, of course. France are looking very good and they have a game-changer in Antoine Griezmann.\n"Even if Ronaldo is at his best, it is a difficult ask for them and I think they are going to have to play ugly, like they have done all the way through the tournament, to win.\n"I have known him a long time and I would love to see him do it, but it is awkward for me because I have friends in both camps - Nani and Cristiano for Portugal, and Paul Pogba and Patrice Evra for France.\n"So I am not really bothered about the result, I just want to see a good game. I would love to see Ronaldo and Griezmann perform to their potential and finish off this tournament on a high note." 36749142 Cristiano Ronaldo will relish taking on the whole of France in Sunday's Euro 2016 final, according to his former Manchester United team-mate Rio Ferdinand.
1 Media playback is unsupported on your device\n18 December 2014 Last updated at 10:28 GMT\nMalaysia has successfully tackled poverty over the last four decades by drawing on its rich natural resources.\nAccording to the World Bank, some 49% of Malaysians in 1970 were extremely poor, and that figure has been reduced to 1% today. However, the government's next challenge is to help the lower income group to move up to the middle class, the bank says.\nUlrich Zahau, the World Bank's Southeast Asia director, spoke to the BBC's Jennifer Pak. 30530533 In Malaysia the "aspirational" low-income part of the population is helping to drive economic growth through consumption, according to the World Bank.

metric是datasets.Metric类的一个实例,查看metric和使用的例子:

metric
Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Returns:
    rouge1: rouge_1 (precision, recall, f1),
    rouge2: rouge_2 (precision, recall, f1),
    rougeL: rouge_l (precision, recall, f1),
    rougeLsum: rouge_lsum (precision, recall, f1)
Examples:

    >>> rouge = datasets.load_metric('rouge')
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> results = rouge.compute(predictions=predictions, references=references)
    >>> print(list(results.keys()))
    ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
    >>> print(results["rouge1"])
    AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
    >>> print(results["rouge1"].mid.fmeasure)
    1.0
""", stored examples: 0)

我们使用compute方法来对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:

fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)
{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}

数据预处理

在将数据喂入模型之前,我们需要对数据进行预处理。预处理的工具叫Tokenizer。Tokenizer首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。

为了达到数据预处理的目的,我们使用AutoTokenizer.from_pretrained方法实例化我们的tokenizer,这样可以确保:

  • 我们得到一个与预训练模型一一对应的tokenizer。
  • 使用指定的模型checkpoint对应的tokenizer的时候,我们也下载了模型需要的词表库vocabulary,准确来说是tokens vocabulary。

这个被下载的tokens vocabulary会被缓存起来,从而再次使用的时候不会重新下载。

from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

tokenizer既可以对单个文本进行预处理,也可以对一对文本进行预处理,tokenizer预处理后得到的数据满足预训练模型输入格式

tokenizer("Hello, this one sentence!")
{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

上面看到的token IDs也就是input_ids一般来说随着预训练模型名字的不同而有所不同。原因是不同的预训练模型在预训练的时候设定了不同的规则。但只要tokenizer和model的名字一致,那么tokenizer预处理的输入格式就会满足model需求的。关于预处理更多内容参考这个教程

除了可以tokenize一句话,我们也可以tokenize一个list的句子。

tokenizer(["Hello, this one sentence!", "This is another sentence."])
{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

注意:为了给模型准备好翻译的targets,我们使用as_target_tokenizer来控制targets所对应的特殊token:

with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))
{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

如果您使用的是T5预训练模型的checkpoints,需要对特殊的前缀进行检查。T5使用特殊的前缀来告诉模型具体要做的任务,具体前缀例子如下:

if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

现在我们可以把所有内容放在一起组成我们的预处理函数了。我们对样本进行预处理的时候,我们还会truncation=True这个参数来确保我们超长的句子被截断。默认情况下,对与比较短的句子我们会自动padding。

max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。

preprocess_function(raw_datasets['train'][:2])
{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8, 7178, 13, 24, 1154, 271, 612, 6, 28571, 243, 8, 3996, 19660, 51, 225, 36, 1869, 30, 3, 5833, 750, 10256, 18, 390, 4811, 2367, 132, 5, 86, 1100, 1274, 6, 16046, 10730, 24397, 49, 2744, 31914, 17, 15, 47, 5229, 28, 7646, 12, 10256, 5, 3, 21322, 8, 1919, 1886, 31, 7, 14667, 440, 18, 17114, 4794, 16202, 7, 11, 2050, 17845, 2715, 7, 130, 92, 2633, 26, 21, 487, 5146, 5, 10256, 3763, 16700, 2776, 17, 40, 232, 65, 243, 10, 96, 1326, 43, 29, 31, 17, 16, 7, 2880, 920, 574, 28, 8, 1508, 5, 96, 11836, 62, 33, 2718, 24, 80, 42, 192, 13, 135, 33, 9805, 12, 1205, 12, 10256, 14159, 1066, 145, 865, 535, 14734, 12, 4712, 2781, 584, 30, 9938, 5061, 10256, 6, 28571, 3, 60, 18, 155, 15, 4094, 112, 3, 8389, 6, 2145, 2627, 1508, 224, 38, 14586, 7677, 423, 18, 1549, 1414, 265, 6060, 11, 411, 7, 2026, 63, 7, 24397, 49, 12446, 2262, 3791, 447, 16, 10256, 225, 240, 20799, 1433, 5, 96, 196, 17, 31, 7, 6865, 3, 9, 1643, 866, 13, 540, 784, 28843, 4275, 37, 7021, 33, 12932, 15436, 13, 24, 1696, 11, 8, 6266, 33, 3, 3131, 3996, 13606, 51, 16, 5, 96, 5231, 34, 31, 7, 3, 9, 792, 815, 13, 131, 147, 23395, 51, 11, 3, 99, 25, 320, 44, 8, 10549, 13, 21331, 24, 8, 233, 3413, 233, 43, 118, 3, 22765, 12, 281, 10055, 21, 784, 355, 908, 1516, 6201, 13, 540, 5, 96, 5231, 3, 99, 62, 130, 12, 830, 8, 1508, 223, 6, 62, 31, 26, 1077, 129, 874, 42, 1296, 1508, 5, 96, 7175, 27, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 5, 96, 7238, 33, 1508, 1107, 91, 13, 1696, 6, 2361, 16, 8, 416, 215, 42, 78, 233, 25, 31, 60, 479, 44, 39, 1414, 265, 6060, 31, 13, 8, 296, 117, 12446, 2262, 3791, 447, 21, 677, 3, 18, 62, 174, 12, 453, 175, 3413, 16, 10256, 5, 96, 1326, 700, 241, 135, 132, 5, 328, 33, 8, 2102, 113, 33, 352, 12, 18514, 8, 1021, 1082, 6, 21, 677, 5, 96, 10273, 33, 8, 1843, 13, 17736, 24, 69, 1021, 1082, 241, 12, 29953, 5, 96, 5231, 27, 133, 456, 326, 784, 969, 2145, 908, 28, 8, 1643, 815, 13, 540, 6, 62, 43, 12, 7365, 1508, 16, 10256, 5, 96, 17527, 6, 3, 99, 24, 54, 36, 612, 11, 132, 31, 7, 128, 8179, 3, 26413, 347, 44, 8, 414, 6, 4273, 6, 752, 31, 7, 320, 12, 830, 1508, 223, 5, 96, 11836, 34, 31, 7, 3, 9, 23958, 296, 6, 19, 29, 31, 17, 34, 58, 96, 196, 17, 31, 7, 1399, 12, 240, 8, 3, 13863, 11, 281, 6, 68, 248, 3, 99, 25, 54, 129, 135, 223, 38, 168, 6, 937, 132, 31, 7, 631, 540, 535, 2390, 11, 7262, 10371, 7, 2050, 2715, 7, 65, 16, 15777, 3, 88, 56, 217, 91, 112, 16046, 10730, 1696, 5, 216, 11, 16202, 7, 92, 2283, 19664, 8, 800, 13, 3140, 1919, 5, 2715, 7, 92, 10246, 271, 4781, 57, 2622, 16, 2379, 29494, 301, 31, 427, 23067, 15, 3, 20923, 12, 16046, 9493, 9906, 17, 325, 2360, 822, 53, 70, 9570, 5, 2969, 2715, 7, 11, 24397, 49, 31914, 17, 15, 3311, 16046, 2177, 13, 8, 2038, 11590, 774, 298, 14667, 440, 18, 17114, 16202, 7, 2301, 132, 16, 1882, 2038, 227, 271, 19664, 21, 3, 15471, 2081, 57, 1798, 1886, 2474, 5993, 5, 1], [21603, 10, 6788, 19758, 7, 2273, 130, 718, 91, 12, 1154, 28, 3, 9, 6220, 2642, 44, 8, 6036, 30, 8, 368, 3540, 986, 7, 2409, 30, 1701, 706, 5, 2409, 7, 130, 16645, 326, 11, 2117, 12355, 1054, 38, 3, 9, 6478, 16813, 47, 4006, 91, 5, 37, 12787, 6, 261, 57, 1932, 27874, 5220, 31195, 3230, 6, 43, 118, 7774, 3, 9, 381, 13, 648, 5, 1377, 1310, 6, 17602, 6417, 6032, 130, 4006, 91, 30, 8, 6036, 30, 12096, 8348, 16, 1186, 11, 932, 5, 37, 6032, 1553, 826, 3, 9, 27874, 896, 2063, 2902, 16, 1882, 1673, 18395, 53, 8, 7070, 13, 8, 7021, 5692, 44, 8, 896, 2501, 5, 1193, 1778, 29, 53, 8, 1251, 3534, 9, 226, 6, 11529, 283, 4569, 4409, 5225, 8692, 243, 10, 96, 196, 17, 19, 3, 9, 2261, 5415, 21, 8, 415, 616, 6, 34, 4110, 2261, 17879, 6, 34, 10762, 151, 31, 7, 1342, 44, 1020, 6, 34, 54, 1709, 3583, 364, 7232, 8, 616, 5, 96, 19494, 62, 174, 151, 28, 251, 12, 698, 24, 28, 8, 2095, 16, 455, 21, 135, 12, 103, 70, 613, 11, 830, 175, 151, 12, 4831, 535, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[368, 22982, 24895, 3545, 13404, 3121, 15, 189, 28571, 7228, 3, 9, 4494, 3996, 19660, 51, 549, 8503, 18, 18145, 7, 3069, 225, 36, 261, 12, 7365, 234, 18, 390, 3683, 224, 38, 1414, 265, 6060, 6, 59, 830, 223, 1215, 699, 26, 4811, 5, 1], [71, 21641, 2642, 646, 1067, 46, 11529, 3450, 828, 16, 5727, 27874, 65, 118, 10126, 3, 9, 3534, 9, 226, 5, 1]]}

接下来对数据集datasets里面的所有样本进行预处理,处理的方式是使用map函数,将预处理函数prepare_train_features应用到(map)所有样本上。

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

更好的是,返回的结果会自动被缓存,避免下次处理的时候重新计算(但是也要注意,如果输入有改动,可能会被缓存影响!)。datasets库函数会对输入的参数进行检测,判断是否有变化,如果没有变化就使用缓存数据,如果有变化就重新处理。但如果输入参数不变,想改变输入的时候,最好清理调这个缓存。清理的方式是使用load_from_cache_file=False参数。另外,上面使用到的batched=True这个参数是tokenizer的特点,以为这会使用多线程同时并行对输入进行处理。

微调模型

既然数据已经准备好了,现在我们需要下载并加载我们的预训练模型,然后微调预训练模型。既然我们是做seq2seq任务,那么我们需要一个能解决这个任务的模型类。我们使用AutoModelForSeq2SeqLM这个类。和tokenizer相似,from_pretrained方法同样可以帮助我们下载并加载模型,同时也会对模型进行缓存,就不会重复下载模型啦。

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

由于我们微调的任务是seq2seq任务,而我们加载的是预训练的seq2seq模型,所以不会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数(比如:预训练语言模型的神经网络head被扔掉了,同时随机初始化了机器翻译的神经网络head)。

为了能够得到一个Seq2SeqTrainer训练工具,我们还需要3个要素,其中最重要的是训练的设定/参数Seq2SeqTrainingArguments。这个训练设定包含了能够定义训练过程的所有属性

batch_size = 16
args = Seq2SeqTrainingArguments(
    "test-summarization",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
)

上面evaluation_strategy = "epoch"参数告诉训练代码:我们每个epcoh会做一次验证评估。

上面batch_size在这个notebook之前定义好了。

由于我们的数据集比较大,同时Seq2SeqTrainer会不断保存模型,所以我们需要告诉它至多保存save_total_limit=3个模型。

最后我们需要一个数据收集器data collator,将我们处理好的输入喂给模型。

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

设置好Seq2SeqTrainer还剩最后一件事情,那就是我们需要定义好评估方法。我们使用metric来完成评估。将模型预测送入评估之前,我们也会做一些数据后处理:

import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

最后将所有的参数/数据/模型传给Seq2SeqTrainer即可

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

调用train方法进行微调训练。

trainer.train()
<div>
    <style>
        /* Turns off some styling */
        progress {
            /* gets rid of default border in Firefox and Opera. */
            border: none;
            /* Needs to be in here for Safari polyfill so background images work as expected. */
            background-size: auto;
        }
    </style>

  <progress value='12753' max='12753' style='width:300px; height:20px; vertical-align: middle;'></progress>
  [12753/12753 1:21:48, Epoch 1/1]
</div>
<table border="1" class="dataframe">
Epoch Training Loss Validation Loss Rouge1 Rouge2 Rougel Rougelsum Gen Len Runtime Samples Per Second 1 2.721100 2.479327 28.300900 7.721100 22.243000 22.249600 18.822500 326.333800 34.725000

TrainOutput(global_step=12753, training_loss=2.7692033505520146, metrics={'train_runtime': 4909.3835, 'train_samples_per_second': 2.598, 'total_flos': 7.774481450954342e+16, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 335248, 'init_mem_gpu_alloc_delta': 242026496, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2637782, 'train_mem_gpu_alloc_delta': 728138240, 'train_mem_cpu_peaked_delta': 138226182, 'train_mem_gpu_peaked_delta': 14677017088})

最后别忘了,查看如何上传模型 ,上传模型到](https://huggingface.co/transformers/model_sharing.html) 到🤗 Model Hub。随后您就可以像这个notebook一开始一样,直接用模型名字就能使用您的模型啦。