It is dangerous to use the default non_blocking=True. #1146

Open
heshenghuan opened this issue Oct 27, 2023 · 0 comments

Hi all, I've recently been trying to run the LLaMA-2-70B model on a single GPU, with a lot of help from this project.

But I found that it is very dangerous to use the default non_blocking=True setting, as in:

https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L328
https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L332
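
For context, my understanding of the failure mode is that these copies are issued asynchronously and nothing forces the consumer to wait for them to finish. A minimal sketch of the hazard in plain PyTorch (not fairscale's actual code):

import torch

# An asynchronous device-to-host copy into pinned memory returns immediately;
# reading the destination before synchronizing is a race and can observe stale data.
src = torch.full((4096, 4096), 3.0, device="cuda")
dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)  # pinned => truly async copy

dst.copy_(src, non_blocking=True)  # returns before the copy necessarily completes
# Reading `dst` here may see whatever was previously in the buffer.
torch.cuda.synchronize()           # only after this are the copied values guaranteed visible
assert bool((dst == 3.0).all())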

The main code:

import logging
import time

from transformers import LlamaForCausalLM, LlamaTokenizer

# `args`, `DTYPE` and `device` come from the surrounding setup code (argument
# parsing and device/dtype selection), which is not shown here.
# `OffloadLlamaModel` is defined in the next code block.
tokenizer = LlamaTokenizer.from_pretrained(args.model_dir)
model = LlamaForCausalLM.from_pretrained(
    args.model_dir,
    low_cpu_mem_usage=True,
    torch_dtype=DTYPE
).eval()

origin_llama_model = model.get_decoder()
model.set_decoder(
    OffloadLlamaModel(origin_llama_model, device=device, num_slices=args.num_slices)
)
del origin_llama_model
model.lm_head.cuda()  # move model.lm_head to GPU

prompt = "Give me some suggestions on how to lose weight."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device=device)

logging.info("Generating response...")
s = time.time()
generate_ids = model.generate(
    input_ids,
    do_sample=False,
    num_beams=1,
    max_length=200
)

The OffloadLlamaModel code:

import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    _expand_mask,
    _make_causal_mask,
)


class DecodeOutput(object):
    def __init__(self, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
                 output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns):

        super().__init__()
        self.elements = [
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        ]

    def items(self):
        return self.elements

    def cuda(self):
        # Move tensor-like elements to GPU; leave plain Python values (bools, tuples) untouched.
        self.elements = [item.cuda() if callable(getattr(item, 'cuda', None)) else item
                         for item in self.elements]
        return self

    def cpu(self):
        # Move tensor-like elements back to CPU; leave plain Python values untouched.
        self.elements = [item.cpu() if callable(getattr(item, 'cpu', None)) else item
                         for item in self.elements]
        return self

    def __str__(self):
        return "DecodeOutput(" + str(self.elements) + ")"

    def __getitem__(self, index: int):
        return self.elements[index]


class WrappedLlamaDecoderLayer(nn.Module):
    def __init__(self, index: int, decoder: LlamaDecoderLayer):
        super(WrappedLlamaDecoderLayer, self).__init__()
        self.idx = index
        self.decoder = decoder

    def forward(self, inputs: DecodeOutput):
        # unpack all parameters
        [hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
         output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns] = inputs.items()

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[self.idx] if past_key_values is not None else None

        # note: removed code like 'if self.gradient_checkpointing and self.training', so only for inference
        layer_outputs = self.decoder(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        outputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        )

        return outputs


class OffloadLlamaModel(nn.Module):
    def __init__(self, llama_model: LlamaModel, device=torch.device('cuda'), offload_device=torch.device("cpu"),
                 num_slices=3, checkpoint_activation=False, num_microbatches=1):
        logging.info("OffloadLlamaModel Initializing.")
        super(OffloadLlamaModel, self).__init__()
        self.config = llama_model.config
        self.padding_idx = llama_model.padding_idx
        self.vocab_size = llama_model.vocab_size

        self.embed_tokens = llama_model.embed_tokens.cuda()

        logging.info("Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.")
        _sequential = nn.Sequential()
        for idx, decoder in enumerate(llama_model.layers):
            _sequential.add_module("layer_%d" % idx, WrappedLlamaDecoderLayer(idx, decoder))

        self.layers = OffloadModel(
            model=_sequential,
            device=device,
            offload_device=offload_device,
            num_slices=num_slices,
            checkpoint_activation=checkpoint_activation,
            num_microbatches=num_microbatches,
        )

        for sid, slc in enumerate(self.layers.model_slices):
            logging.debug(
                f"Shard {sid:d} holds WrappedLlamaDecodeLayer [{','.join(str(m.idx) for m in slc.model_shard)}]"
            )

        self.norm = llama_model.norm.cuda()

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        inputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        )
        layer_outputs = self.layers.forward(inputs)

        [hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
         output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns] = layer_outputs.items()

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

I found that the model generated different responses when using different num_slices settings, even with the random seed fixed.
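
For context, a per-layer comparison of this kind can be computed roughly as follows (a minimal sketch; the helper and tensor names are mine, not taken from the script above):

import torch
import torch.nn.functional as F

@torch.no_grad()
def hidden_state_diff(native_out: torch.Tensor, offload_out: torch.Tensor) -> float:
    # Flatten (batch, seq, hidden) -> (batch * seq, hidden) so pairwise_distance
    # compares the hidden vectors position by position, then average.
    a = native_out.float().flatten(0, -2)
    b = offload_out.float().flatten(0, -2)
    return F.pairwise_distance(a, b).mean().item()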

The pairwise_distance of each decoder layer's outputs, between the original model and the offloaded model, looked like this:

2023-10-27 17:46:28,544 - INFO: Loading LLaMA model and tokenizer.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.98s/it]
2023-10-27 17:46:47,599 - INFO: Running model natively.
2023-10-27 17:47:15,826 - INFO: Running OffloadModel using given num_slices setting.
2023-10-27 17:47:15,826 - INFO: OffloadLlamaModel Initializing.
2023-10-27 17:47:16,049 - INFO: Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.
2023-10-27 17:47:16,052 - INFO: This model has 12688.18M parameters, aiming for 6344.09M parameters per shard
2023-10-27 17:47:39,404 - INFO: Shard 0 holds 6344.09M parameters
2023-10-27 17:47:39,405 - INFO: Shard 1 holds 6344.09M parameters
2023-10-27 17:47:39,412 - DEBUG: Shard 0 holds WrappedLlamaDecodeLayer [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]
2023-10-27 17:47:39,412 - DEBUG: Shard 1 holds WrappedLlamaDecodeLayer [20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39]
Embedding is same: True
RMSNorm is same: True
Checking layers:
Layer 00 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 01 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 02 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 03 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 04 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 05 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 06 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 07 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 08 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 09 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 10 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 11 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 12 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 13 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 14 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 15 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 16 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 17 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 18 obj is same: True, attention diff: 0.7080, hidden_state diff: 0.0001
Layer 19 obj is same: True, attention diff: 0.7104, hidden_state diff: 13.1719
Layer 20 obj is same: True, attention diff: 0.0465, hidden_state diff: 19.6094
Layer 21 obj is same: True, attention diff: 0.0314, hidden_state diff: 19.7031
Layer 22 obj is same: True, attention diff: 0.0166, hidden_state diff: 19.9062
Layer 23 obj is same: True, attention diff: 0.0146, hidden_state diff: 20.4062
Layer 24 obj is same: True, attention diff: 0.0164, hidden_state diff: 20.8125
Layer 25 obj is same: True, attention diff: 0.0153, hidden_state diff: 21.2500
Layer 26 obj is same: True, attention diff: 0.0143, hidden_state diff: 21.8281
Layer 27 obj is same: True, attention diff: 0.0151, hidden_state diff: 22.4375
Layer 28 obj is same: True, attention diff: 0.0112, hidden_state diff: 22.9844
Layer 29 obj is same: True, attention diff: 0.0150, hidden_state diff: 23.4531
Layer 30 obj is same: True, attention diff: 0.0098, hidden_state diff: 24.0781
Layer 31 obj is same: True, attention diff: 0.0129, hidden_state diff: 24.6562
Layer 32 obj is same: True, attention diff: 0.0098, hidden_state diff: 25.2656
Layer 33 obj is same: True, attention diff: 0.0164, hidden_state diff: 25.8750
Layer 34 obj is same: True, attention diff: 0.0106, hidden_state diff: 26.5000
Layer 35 obj is same: True, attention diff: 0.0133, hidden_state diff: 27.2188
Layer 36 obj is same: True, attention diff: 0.0166, hidden_state diff: 28.0000
Layer 37 obj is same: True, attention diff: 0.0179, hidden_state diff: 28.9688
Layer 38 obj is same: True, attention diff: 0.7056, hidden_state diff: 30.0312
Layer 39 obj is same: True, attention diff: 0.6108, hidden_state diff: 83.8750
['<s>me a examples for how to improve weight and\n']
['<s>me a examples on how to improve weight fast I']

Once I manually set non_blocking=False, all of the above diffs disappeared.
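
For anyone hitting the same thing, the remedy amounts to either making the copy blocking, or explicitly synchronizing with the stream that issued it before the data is consumed. A rough sketch of both options in plain PyTorch (not fairscale's actual code):

import torch

param_gpu = torch.randn(4096, 4096, device="cuda")

# Option 1: a blocking copy, so the CPU-side data is valid as soon as .to() returns.
param_cpu = param_gpu.to("cpu", non_blocking=False)

# Option 2: keep the copy asynchronous, but make the consuming stream wait for it.
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    reloaded = param_cpu.pin_memory().to("cuda", non_blocking=True)
torch.cuda.current_stream().wait_stream(copy_stream)  # order later work after the copy
result = reloaded.sum()  # safe: runs on the current stream, after the copy has finished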
