Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

推理和生成相关调研和设计 #265

Open
L1aoXingyu opened this issue Apr 13, 2022 · 7 comments
Open

推理和生成相关调研和设计 #265

L1aoXingyu opened this issue Apr 13, 2022 · 7 comments

Comments

@L1aoXingyu
Copy link
Collaborator

调研了不同的 NLP 库在预测阶段的处理方式

FairSeq

针对生成任务的代码主要在 https://github.com/pytorch/fairseq/blob/main/fairseq/sequence_generator.py

class SequenceGenerator(nn.Module):
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        ...
    ):
        """Generates translations of a given source sentence."""
        ...
        
    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
    	...

针对序列预测任务的代码主要在 https://github.com/pytorch/fairseq/blob/7e758841da9e05cb21826a60d30a563a9e189d1d/fairseq/sequence_scorer.py#L12

class SequenceScorer(object):
   """Scores the target for a given source sentence."""

   def __init__(
       self,
       tgt_dict,
       softmax_batch=None,
       compute_alignment=False,
       eos=None,
       symbols_to_strip_from_output=None,
   ):
     ...
   
   @torch.no_grad()
   def generate(self, models, sample, **kwargs):
       """Score a batch of translations."""
       net_input = sample["net_input"]
       ...

主要针对生成的任务进行构建的,tasks 支持比较少,而且两种风格不统一,同时不支持模型并行模式的推理。

AllenNLP

主要代码在 https://github.com/allenai/allennlp/blob/426d894ceef591b406cb77a7b094c88c85ad0068/allennlp/models/model.py#L193

在模型层面进行实现,每种模型绑定一个推理方式,这种方式下,模型和任务没有解耦,在训练中耦合 generation 的逻辑

Megatron-LM

提供了 api 代码 https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/text_generation/api.py#L30

def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              temperature=1.0,
                              add_BOS=False,
                              use_eod_token_for_early_termination=True):
    """Run inference and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, lengths, output_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        temperature=temperature,
        add_BOS=add_BOS,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)
    ...

支持的 tasks 比较少,不过可以支持复杂并行的模型推理,比如 pipeline 并行,但是整体实现以及调用流程比较复杂,对用户不友好

HuggingFace

主要代码在 https://github.com/huggingface/transformers/blob/eb5bdcdfa51f743887ee1d9c7f230444d7a8b23c/src/transformers/pipelines/base.py#L710

在整个流程抽象为如下的处理流

Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

调用方式清晰简单

from transformers import pipeline

# Allocate a pipeline for sentiment-analysis
classifier = pipeline('sentiment-analysis')
classifier('We are very happy to introduce pipeline to the transformers repository.')
>>> [{'label': 'POSITIVE', 'score': 0.9996980428695679}]

扩展任务比较方便,可以继承基类 Pipeline,解耦了任务相关的流程和模型推理的流程。

@thinksoso @xiezipeng-ML 遗漏的内容可以补充一下,有错误的地方可以修正~

@L1aoXingyu
Copy link
Collaborator Author

目前倾向于参考 Huggingface 的方案,将整个 inference 流程分解为 task-specific 的部分和 model-related 部分,学习 Huggingface 的推理 API,在模型内部支持 tensor 并行和 pipeline 并行的调用,先支持经典的 text_classification 和 text_generation 任务。

@CPFLAME
Copy link
Contributor

CPFLAME commented Apr 14, 2022

目前来说我想好的整个pineline和huggingface流程差不多, 首先我们得有一个基类的pipeline作为可继承的类使用:

from libai.config import LazyConfig, try_get_key
from libai.engine.default import DefaultTrainer
from libai.utils.checkpoint import Checkpointer
from libai.data.structures import DistTensorData, Instance

class BasicPipeline:
    def __init__(
            self,
            config_file,
            **kwargs):
        self.cfg = LazyConfig.load(config_file)
        self.model = self.load_model(config_file)
        self.tokenier = self.build_tokenizer(config_file)
        ...

    def load_model(cfg):
        model = DefaultTrainer.build_model(cfg).eval()
        # 这里除了加载libai的模型用checkpointer以外, 
        # 也可以用户支持自定义, 从其他框架导入weight, 比如load_huggingface_weight
        Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load(
            cfg.train.load_weight, resume=False
        )
        
        if try_get_key(cfg, "train.graph.enabled", default=False):
            model = DefaultTrainer.build_graph(cfg, model, is_train=False)
        return model

    def build_tokenizer(cfg):
        ...

    def __call__(self, inputs, *args, batch_size=None, **kwargs):
        model_inputs = self.preprocess(inputs, batch_size)
        model_outputs = self.forward(model_inputs)
        outputs = self.postprocess(model_outputs)
        return outputs

    def preprocess(self, inputs, batch_size, **kwargs):
        ...
        return Instance(
            input_ids=DistTensorData(...),
            attention_mask=DistTensorData(...),
            tokentype_ids=DistTensorData(...),
        )

    def forward(self, model_inputs, **kwargs):
        ...
        model_outputs = self.model(model_inputs)
        return model_outputs


    def postprocess(self, model_outputs, **kwargs):
        ...
        return outputs

对于其中tensor并行pipeline并行 因为是直接用了libai来build模型, 所以只需要修改lazyconfig里面并行配置就可以支持各种并行了. 在load_model()中用户唯一可以自定义修改的地方是 load_pretrain_weight() 的方式, 是直接从libai里面读, 还是读取其他框架里面的weights. 但是宗旨不变的是, 模型的构建代码 用libai的layers构建,这样就可以多卡启动和支持各种并行了.

@CPFLAME
Copy link
Contributor

CPFLAME commented Apr 14, 2022

对于不同的任务, 我们的inference代码会不一样,

分类任务

如果是对于只有encoder的分类任务, 那么模型会比较简单, 直接输出类别和分数就可以了.

生成任务

但是如果是包含decoder的生成任务, 在进行forward()的时候, 需要特别注意:

  1. 在生成任务里面, 由于decoder当前的输出 依赖于以前的输出, 所以需要用for循环一直调用, 类似的代码如下:
def couplet(model, src, data_loader, config):
    vocab = data_loader.vocab
    tokenizer = data_loader.tokenizer
    model.eval()
    tokens = [vocab.stoi[tok] for tok in tokenizer(src)]  # 构造一个样本
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1))  # 将src_len 作为第一个维度
    tgt_tokens = greedy_decode(model, src, max_len=num_tokens + 5,
                               start_symbol=data_loader.BOS_IDX, config=config,
                               data_loader=data_loader).flatten()  # 解码的预测结果
    return "".join([vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")

def greedy_decode(model, src, max_len, start_symbol, config, data_loader):
    src = src.to(config.device)
    memory = model.encoder(src)  # 对输入的Token序列进行解码翻译
    ys = torch.ones(1, 1).fill_(start_symbol). \
        type(torch.long).to(config.device)  # 解码的第一个输入,起始符号
    for i in range(max_len - 1):
        memory = memory.to(config.device)
        tgt_mask = (model.my_transformer.generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(config.device)  # 根据tgt_len产生一个注意力mask矩阵(对称的)
        out = model.decoder(ys, memory, tgt_mask)  # [tgt_len,tgt_vocab_size]
        out = out.transpose(0, 1)  # [tgt_vocab_size, tgt_len]
        prob = model.classification(out[:, -1])  # 只对对预测的下一个词进行分类
        _, next_word = torch.max(prob, dim=1)  # 选择概率最大者
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        # 将当前时刻解码的预测输出结果,同之前所有的结果堆叠作为输入再去预测下一个词。
        if next_word == data_loader.EOS_IDX:  # 如果当前时刻的预测输出为结束标志,则跳出循环结束预测。
            break
    return ys
  1. 可以看到上述的代码中, 在for训练里面ys会不断的concat当前的输出, 然后送到下一轮decoder里面去

生成任务的加速

从大体上看, 上述的代码是没有问题的, 但是有一个点我们可以加速的地方, 我们可以把decoder里面第一次运行的key-value保存起来, 在huggingface里面也是这么做的

由于decoder里面的key和value, 都是通过encoder的输出进行全连接得到的, 在网络是eval()模式, 而且encoder也只进行了一次前向的情况下, 在每次调用decoder期间, 用到的key和value都是同一个值, 也就是说在decoder里面key和value的生成只需要进行一次计算, 然后保存起来, 以后的计算都是重复的.

在LiBai的libai/layers/transformer_layer.py里面已经提供了这个接口:

def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        use_cache=False,
    ):
        ...
        if past_key_value is not None:
            if self.is_decoder:
                assert len(past_key_value) == 4
                self_attn_past_key_value = past_key_value[:2]
                cross_attn_past_key_value = past_key_value[2:]
            else:
                self_attn_past_key_value = past_key_value
                cross_attn_past_key_value = None
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        layernorm_output = self.input_layernorm(hidden_states)
        attention_output = self.self_attention(
            layernorm_output,
            attention_mask=attention_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
        )
        attention_output = self.drop_path(attention_output)

        if use_cache:
            attention_output, presents = attention_output
        hidden_states = hidden_states + attention_output

        layernorm_output = self.post_attention_layernorm(hidden_states)

        if self.is_decoder:
            # todo: use key-value to pass the arguments
            attention_output = self.cross_attention(
                layernorm_output,
                encoder_states,
                attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                use_cache=use_cache,
            )

            if use_cache:
                attention_output, decoder_presents = attention_output
                presents += decoder_presents

            attention_output = self.drop_path(attention_output)
            hidden_states = hidden_states + attention_output
            layernorm_output = self.post_cross_attention_layernorm(hidden_states)

        mlp_output = self.mlp(layernorm_output)
        mlp_output = self.drop_path(mlp_output)
        output = hidden_states + mlp_output

        if use_cache:
            output = (output, presents)
        return output

所以我们需要在写inference的时候, 需要在调用transformer_layer的地方, 设置use_cache=True, 然后把decoder第一次运行完后每个transformer_layer返回的presents全部保存起来. 然后在后续再次调用decoder的时候, 把每个transformer对应的presents作为past_key_value传进去, 避免重复计算.

关于怎么修改代码

有两个办法,

  • 方法1: 在inference里面 重新定义一下model.forward(),

大致代码如下:

from types import MethodType

def my_forward(self, ...):
        ...
        dec_embedding_output = self.embedding(decoder_input_ids)
        dec_hidden_states = dec_embedding_output
        presents = []
        if past_key_values is None:
            past_key_values = [None] * self.decoder.layers
        for layer, past_key_value in zip(self.decoder.layers, past_key_values):
            dec_hidden_states, present = layer(
                dec_hidden_states,
                decoder_attn_mask,
                encoder_states,
                encoder_decoder_attn_mask,
                past_key_value=past_key_value,
                use_cache=True,
            )
            presents.append(present)
        decoder_states = self.decoder.final_layernorm(dec_hidden_states)
        logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
        return logits, presents

# 重新指定model.forward()

model.forward = MethodType(my_forward, model)
  • 方法2 : 直接在libai/models/task_model.py下面直接修改的代码, 添加if else考虑past_key_valueuse_cache的情况

其中方法1的好处是不用修改libai里面本来的代码, libai里面的代码让人看上去觉得比较干净, 坏处就是每个包含decoder的model, 可能都需要单独写一个forward()来重构一下.

方法2的好处是可以一劳永逸, 在inference里面会比较干净, 坏处就是在libai/models/添加了if_else分支, 如果对于只想看网络结构的用户来说, past_key_value这个部分是多余的, 甚至会对整体网络的理解造成一定的困难.

@L1aoXingyu
Copy link
Collaborator Author

我倾向于用方法2,megatron 和 huggingface 应该都是这样做的~

@Ldpe2G
Copy link
Collaborator

Ldpe2G commented Apr 14, 2022

做推理生成任务的时候,输入序列是变长的是吧,那目前只能用 eager global 来做了?

@CPFLAME
Copy link
Contributor

CPFLAME commented Apr 15, 2022

做推理生成任务的时候,输入序列是变长的是吧,那目前只能用 eager global 来做了?

我理解 不止生成任务, 可能分类任务输入序列也是变长的, 只不过都会进行padding到max_length.
但是我感觉用eager global来做更加的灵活.
而且可能还有一种情况就是, 输入的序列, 超过了训练的max_length, 这种情况怎么弄可能还需要再讨论一下

@L1aoXingyu
Copy link
Collaborator Author

正好我们下午要和 idea 开会,这个部分的问题涉及到 NLP 的 domain knowledge,我们和他们请教一下

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants