
unable to create cuda shared memory handle when using multiprocessing to send multiple requests #7101

Open
justanhduc opened this issue Apr 11, 2024 · 6 comments
Labels: bug, module: clients

Comments

justanhduc commented Apr 11, 2024

Description
I use multiprocessing to send multiple requests to the Triton server through the Python client. When I use CUDA shared memory, even with only one process, it fails with an initialization error:

Traceback (most recent call last):
  File "/mnt/zfs/duc_nguyen/miniconda3/envs/mlvton/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/mnt/zfs/duc_nguyen/miniconda3/envs/mlvton/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/zfs/duc_nguyen/projects/ml-vton/triton_modules/models/v1/gp_hrvton_gen_cudashm.py", line 103, in warmup
    output = self.inference_request(*inputs)
  File "/mnt/zfs/duc_nguyen/projects/ml-vton/triton_modules/models/v1/gp_hrvton_gen_cudashm.py", line 33, in inference_request
    input0_shm_handle = cudashm.create_shared_memory_region(
  File "/mnt/zfs/duc_nguyen/miniconda3/envs/mlvton/lib/python3.8/site-packages/tritonclient/utils/cuda_shared_memory/__init__.py", line 140, in create_shared_memory_region
    raise CudaSharedMemoryException(
tritonclient.utils.cuda_shared_memory._utils.CudaSharedMemoryException: unable to create cuda shared memory handle

The CUDA shm implementation closely follows the provided example, and it works perfectly when multiprocessing is not used.

Triton Information

Triton Docker 22.12

To Reproduce

        # Copy the two inputs into CUDA shared-memory regions and register
        # them with the server under unique names.
        input_0_np = np.asarray(input_0, dtype=np.float32)
        input_1_np = np.asarray(input_1, dtype=np.float32)
        input0_byte_size = input_0_np.size * input_0_np.itemsize
        input0_data = str(uuid.uuid4())
        input0_shm_handle = cudashm.create_shared_memory_region(
            input0_data, input0_byte_size, 0
        )
        cudashm.set_shared_memory_region(input0_shm_handle, [input_0_np])
        self.triton_client.register_cuda_shared_memory(
            input0_data, cudashm.get_raw_handle(input0_shm_handle), 0, input0_byte_size
        )
        infer_input_0 = InferInput("input_0", input_0_np.shape, "FP32")
        infer_input_0.set_shared_memory(input0_data, input0_byte_size)

        input1_byte_size = input_1_np.size * input_1_np.itemsize
        input1_data = str(uuid.uuid4())
        input1_shm_handle = cudashm.create_shared_memory_region(
            input1_data, input1_byte_size, 0
        )
        cudashm.set_shared_memory_region(input1_shm_handle, [input_1_np])
        self.triton_client.register_cuda_shared_memory(
            input1_data, cudashm.get_raw_handle(input1_shm_handle), 0, input1_byte_size
        )
        infer_input_1 = InferInput("input_1", input_1_np.shape, "FP32")
        infer_input_1.set_shared_memory(input1_data, input1_byte_size)

        # Allocate and register a CUDA shared-memory region for the output.
        output = np.empty((1, 3, 1024, 768), dtype=np.float32)
        output_byte_size = output.size * output.itemsize
        output_data = str(uuid.uuid4())
        output_shm_handle = cudashm.create_shared_memory_region(
            output_data, output_byte_size, 0
        )
        self.triton_client.register_cuda_shared_memory(
            output_data, cudashm.get_raw_handle(output_shm_handle), 0, output_byte_size
        )
        infer_output = InferRequestedOutput("output_0", binary_data=True)
        infer_output.set_shared_memory(output_data, output_byte_size)

        # Run inference and read the result back from the output region.
        response = self.triton_client.infer(
            self.model_name,
            inputs=[infer_input_0, infer_input_1],
            request_id=str(uuid.uuid4()),
            model_version=self.model_version,
            outputs=[infer_output]
        )
        output_response = response.get_output("output_0")
        if output_response is not None:
            output = cudashm.get_contents_as_numpy(
                output_shm_handle,
                utils.triton_to_np_dtype(output_response["datatype"]),
                output_response["shape"],
            )
        else:
            import sys
            print("output_0 is missing in the response.")
            sys.exit(1)
            
        # Unregister the regions from the server and free them.
        self.triton_client.unregister_cuda_shared_memory(input0_data)
        self.triton_client.unregister_cuda_shared_memory(input1_data)
        self.triton_client.unregister_cuda_shared_memory(output_data)
        cudashm.destroy_shared_memory_region(output_shm_handle)
        cudashm.destroy_shared_memory_region(input0_shm_handle)
        cudashm.destroy_shared_memory_region(input1_shm_handle)

This is how inference with CUDA shm is implemented.

    model = TritonModel()
    num_requests = 1

    # Create a list of processes
    processes = []
    for _ in range(num_requests):
        # Create a process for each request
        p = multiprocessing.Process(target=model.warmup, args=(True,))
        processes.append(p)
        p.start()

    # Wait for all processes to finish
    for p in processes:
        p.join()

The multiprocessing part is implemented as above; model.warmup sends random inputs to the server. The error above occurs as soon as the child process hits input0_shm_handle = cudashm.create_shared_memory_region(input0_data, input0_byte_size, 0).
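
To isolate where the failure comes from, a stripped-down child process that only allocates a CUDA shared-memory region (no Triton server, no model code) might be enough to trigger the same exception. The sketch below is only an approximation, not my actual script: the region names and the 1 KiB size are arbitrary, and the parent-side allocation is just a guess at how to mimic a CUDA context already existing before the fork.

    import multiprocessing
    import uuid

    import tritonclient.utils.cuda_shared_memory as cudashm


    def allocate_in_child():
        # Only allocate and free a region; nothing is registered with a server.
        handle = cudashm.create_shared_memory_region(str(uuid.uuid4()), 1024, 0)
        cudashm.destroy_shared_memory_region(handle)
        print("allocation in child process succeeded")


    if __name__ == "__main__":
        # Touch CUDA in the parent first so the forked child inherits the
        # context (a guess at the trigger; drop these lines to compare).
        parent_handle = cudashm.create_shared_memory_region(str(uuid.uuid4()), 1024, 0)

        p = multiprocessing.Process(target=allocate_in_child)
        p.start()
        p.join()

        cudashm.destroy_shared_memory_region(parent_handle)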

Expected behavior
I expect the code to run normally.

Please let me know what I'm missing here. Thanks in advance!

Update 1: The program cannot get/set the CUDA device in the subprocess. It dies at prev_device = call_cuda_function(cudart.cudaGetDevice) inside create_shared_memory_region.
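
For reference, the same get-device-after-fork behaviour can be probed without the Triton client at all. Below is a minimal sketch, assuming the cuda-python package (cuda.cudart) that the client wraps is installed; the parent does a cudaMalloc only to force CUDA context creation before forking.

    import multiprocessing

    from cuda import cudart


    def check_device_in_child():
        # In cuda-python, cudaGetDevice returns (error_code, device_id).
        err, device = cudart.cudaGetDevice()
        print("child cudaGetDevice:", err, device)


    if __name__ == "__main__":
        # Force CUDA context creation in the parent before forking.
        err, ptr = cudart.cudaMalloc(1024)
        print("parent cudaMalloc:", err)

        p = multiprocessing.Process(target=check_device_in_child)
        p.start()
        p.join()

        cudart.cudaFree(ptr)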

Tabrizian added the bug and module: clients labels on Apr 19, 2024
Tabrizian (Member) commented:

Thanks for reporting this issue. I have filed an internal issue for further investigation.

tanmayv25 (Contributor) commented Apr 22, 2024

@justanhduc 22.12 is a very old release. We have since changed our client library to use CUDA Python for the CUDA shared memory handle implementation.

Can you try upgrading to our latest client and let us know if you still see this issue?
That said, we still use the cudaGetDevice call, so you may run into this issue again.
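
In the meantime, if the failure is caused by forking after the parent process has already initialized CUDA, one thing you could try is forcing the 'spawn' start method and constructing the model (and therefore the client and any CUDA state) inside each child process. This is only a rough sketch against your TritonModel/warmup code, not a verified fix:

    import multiprocessing

    # TritonModel is your class; adjust the import path to your project layout.
    from gp_hrvton_gen_cudashm import TritonModel


    def run_warmup():
        # Build the model inside the child so no CUDA context is inherited
        # from the parent via fork().
        model = TritonModel()
        model.warmup(True)


    if __name__ == "__main__":
        # 'spawn' starts each child with a fresh interpreter instead of fork().
        multiprocessing.set_start_method("spawn", force=True)

        processes = [multiprocessing.Process(target=run_warmup) for _ in range(1)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()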

justanhduc (Author) commented:

Hi @tanmayv25. Thanks for the pointer. I used the latest server Docker image (24.03) and the latest client, and I'm still facing this issue. Could you look into it further?

lkomali commented Apr 26, 2024

Hi @justanhduc
I tried to reproduce the issue with a repro script on Triton 24.03. With p = 1, I did not get any error and the script ran successfully.
Attaching the repro script for reference.
cuda_shm_repo.zip

tanmayv25 (Contributor) commented:

@justanhduc Can you help us reproduce the issue?

justanhduc (Author) commented:

Hi @lkomali @tanmayv25. Sorry, I didn't see the notification earlier. I will run your code and see if I can reproduce the error.
