Text generation inference integration #12

Open
wants to merge 60 commits into main
Conversation

@andrewramsay (Collaborator) commented May 3, 2024

This PR replaces the original custom LLM implementation in OAT with HuggingFace's Text Generation Inference (TGI) framework.

The core change here is refactoring OAT's architecture around LLM requests.

In the current main version of OAT, online services make RPCs to llm_functionalities, which then calls the configured alpaca_llm model loaded into its container. This means llm_functionalities requires a GPU, so the other OAT services have to be prepared for it to be unavailable if the system is running somewhere without sufficient GPU resources.

With TGI integrated, llm_functionalities becomes a relatively thin wrapper around calls to a TGI endpoint using the InferenceClient API. This allows most of the error/timeout handling code around LLM calls in the current codebase to be removed and concentrated in llm_functionalities, since all LLM requests are routed through there before reaching the TGI endpoint.
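Roughly, the wrapper now boils down to something like this (the endpoint URL and function names here are illustrative rather than the exact OAT code):

```python
from huggingface_hub import InferenceClient

# Illustrative endpoint URL; the real one comes from docker-compose.yml
client = InferenceClient(model="http://tgi:8080")

def call_model(prompt: str, max_new_tokens: int = 128) -> str:
    # Every OAT LLM request funnels through a call like this, so error and
    # timeout handling can live here instead of in each calling service.
    return client.text_generation(prompt, max_new_tokens=max_new_tokens)
```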

Edited to include features since the PR was created:

  • It's now possible to load a local model into TGI by placing the files under shared/file_system/downloads/llm_functionalities/<model_name> and then setting MODEL_ID to /models/<model_name>
  • Updated the Docker volume configuration for tgi to distinguish between local models and downloaded models
  • llm_functionalities still downloads the alpaca_llm model but stores it in the location corresponding to the tgi Docker volume for local models
  • This allows us to keep using the model everything is currently built around (the hardcoded prompts make it tricky to swap in other model-specific prompts without more refactoring)
  • 747050d: llm_functionalities now connects to the TGI endpoint differently. Originally it entered a loop attempting to connect and gave up after some number of retries. Now it only attempts to connect when the first LLM request arrives from OAT (returning empty responses while the endpoint is unavailable), and it should also handle the endpoint going down and coming back up while OAT is running. It also includes a fix for the case where the endpoint's hostname is unresolvable, which was causing a long timeout
  • There are a few LLM-specific tests in shared/tests/integration_tests/test_llm.py. They don't cover all the LLM components, but it's a start at least (some of them are awkward to trigger)
  • 62bb76f: there seemed to be a problem with parsing some types of LLM responses due to the way the prompt was constructed; I've tried to fix this

Testing

To run the new LLM tests:

  1. docker compose up (this should download the alpaca_llm folder to shared/file_system/downloads/tgi/local/)
  2. docker compose run tester --runslow -vk test_llm should run the 3 LLM tests

- Remove existing requirements and replace with huggingface_hub
- Use oat_common as a base since we don't need CUDA support now
I've replaced the existing model-loading code with the creation of an
`InferenceClient` object using the endpoint URL defined in the
docker-compose.yml file.

Creating the object doesn't trigger any connection, so the service
currently submits a simple query to check if the TGI endpoint is actually
available. The endpoint might not be available immediately (e.g. if it's
still downloading or loading a model), so there's a basic retry setup for
now, but it might need some more thought put into it.
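A rough sketch of that check, with the retry count and delay as placeholder values:

```python
import time
from huggingface_hub import InferenceClient

def wait_for_tgi(url: str, retries: int = 10, delay: float = 5.0):
    # Creating the client doesn't connect, so probe the endpoint with a
    # tiny generation request until it responds or we run out of retries.
    client = InferenceClient(model=url)
    for _ in range(retries):
        try:
            client.text_generation("ping", max_new_tokens=1)
            return client
        except Exception:
            time.sleep(delay)
    return None
```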

The `call_model` and `batch_call_model` methods are updated to call the
`.text_generation` method on the `InferenceClient`. The batch method
submits requests in parallel using a ThreadPoolExecutor; TGI doesn't offer
a batch-specific endpoint, but according to the docs it should batch the
requests internally.
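A minimal sketch of the batch path (names are indicative, not the exact OAT method signatures):

```python
from concurrent.futures import ThreadPoolExecutor

def batch_call_model(client, prompts, max_new_tokens=128):
    # TGI has no batch endpoint, but it batches concurrent requests
    # internally, so we simply submit them in parallel.
    if not prompts:
        return []
    with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
        futures = [
            pool.submit(client.text_generation, p, max_new_tokens=max_new_tokens)
            for p in prompts
        ]
        return [f.result() for f in futures]
```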
- Add some env vars with default values to the tgi service definition to
  allow easy control of some options that might need to be changed
  depending on model/hardware
- Add a wrapper script as the entrypoint for the tgi container to allow
  passing in extra CLI parameters using a TGI_PARAMS env var
These 2 env vars can be used to adjust the number of retries
llm_functionalities will make when attempting to connect to the TGI
endpoint, and the delay between successive retries.
There seems to be a bug in the `InferenceClient.summarization` method in
the recent officially released versions; it's fixed in the current
development version.
Use the new env vars from `docker-compose.yml` to control the connection
attempts to the TGI endpoint
This just adds TGI equivalents for the `generate_summary` and
`generate_summaries` methods that pass requests through to the
`InferenceClient.summarization` method
(https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.summarization)
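Roughly (the return type of `summarization` differs between huggingface_hub versions, hence the defensive check):

```python
def generate_summary(client, text: str) -> str:
    result = client.summarization(text)
    # Recent huggingface_hub versions return an object with a summary_text
    # attribute; older ones return a plain string.
    return result if isinstance(result, str) else result.summary_text
```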
This is currently used by one of the LLM tests to check if a particular
step in the policy was activated to produce the current response.
Adding `delete_taskmap` and `delete_session` methods. These aren't much
use for the online system; they're used in some new tests.
This is to make some new tests simpler to implement, not relevant for
the online system.

- Define new `delete_session` and `delete_taskmap` RPCs
- Add `delete_session` and `delete_taskmap` methods to `DynamoDB`
- Add a `delete` method to `ProtoDB` which those 2 methods call
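As a very rough sketch of the shape of the new `delete` helper (assuming a boto3 DynamoDB table with a simple primary key named `id`; OAT's actual ProtoDB key schema may differ):

```python
import boto3

class ProtoDB:
    def __init__(self, table_name: str):
        self.table = boto3.resource("dynamodb").Table(table_name)

    def delete(self, key: str) -> None:
        # Hypothetical key schema: a single partition key called "id".
        self.table.delete_item(Key={"id": key})
```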
This commit removes most of the timeout-handling and exception-checking
for the LLM code in `functionalities`.

Previously the system couldn't assume `llm_functionalities` would be
available, so there had to be error handling for failed RPC calls within
the various `llm_*` components in `functionalities`. There also had to
be timeout handling in some cases to prevent lengthy LLM calls from
delaying system responses.

In the TGI-enabled version of OAT, `llm_functionalities` has no GPU
requirements any more so we can probably assume it's always going to be
available like the other services. That also means we can just do
the timeout/error handling in `llm_functionalities` since all calls to
TGI will be routed through there.

This ultimately means the code around the RPCs to `llm_functionalities`
can be simplified to remove the existing timeout and exception handling.
More removal of timeout and error handling around RPCs to
`llm_functionalities` (now handled when `llm_functionalities` makes
calls to TGI).
Various changes:
- Remove special handling of summary requests
- Define a default timeout for TGI calls
- Update `_check_connectivity` to return True/False rather than throw an
  exception, to allow for returning empty responses to the client
  instead
- Add a `_call_with_timeout` wrapper method for the two `call_model`
  methods to allow them to submit TGI requests with a timeout applied
  (see the sketch after this list)
- Timeouts are taken from the `ModelRequest` objects; if not set there,
  the default value is used
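A sketch of what `_call_with_timeout` might look like (the default timeout value and helper shape here are assumptions, not the exact implementation):

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

DEFAULT_TIMEOUT = 5.0  # placeholder; the real default lives in llm_functionalities

def _call_with_timeout(fn, *args, timeout=None, **kwargs):
    # Run the TGI call in a worker thread and give up after `timeout`
    # seconds, returning None so the caller can send an empty response.
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(fn, *args, **kwargs)
    try:
        return future.result(timeout=timeout or DEFAULT_TIMEOUT)
    except FutureTimeout:
        return None
    finally:
        pool.shutdown(wait=False)
```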
Python strings are immutable and don't support assignments like
`foo[1] = "."`, so this would have thrown a TypeError if it was ever
executed.
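In other words, the string has to be rebuilt rather than assigned into:

```python
text = "hello"
# text[1] = "."                   # TypeError: 'str' object does not support item assignment
text = text[:1] + "." + text[2:]  # rebuild the string instead
```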
- Remove old TGI_CONNECTION_* env vars
- Add a volume to the TGI container to load local models from
- Update comments
- Adjust default SHM size
- Set the default MODEL_ID to be the local Alpaca model
Instead of reloading the session from the database we can just check the
JSON object returned by the orchestrator, since it should include some
of the changed text.
The previous version of this class entered a loop on startup to
periodically check if the TGI endpoint address was connectable. If it
was able to make a connection it assumed that the connection would remain
valid indefinitely.

This version should be a bit more flexible in that it now only attempts
to connect when LLM requests are triggered. It should also handle the
endpoint becoming unavailable and then available again because the first
request sent after it comes back up should create a new client object
for future requests.

The other important change here is the use of dnspython to test if the
TGI endpoint hostname is resolvable. I found that if this isn't true
(e.g. if you launch OAT with a remote TGI endpoint that is still
starting up), the DNS resolution process takes 10+ seconds to time out.
This seems to be difficult to handle using the Python stdlib, but the
dnspython package makes it very simple to set a timeout on the
resolution process. This allows the per-request connection checking to
work without excessive delays.
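A sketch of the per-request check (hostname, port and helper name are illustrative):

```python
import dns.exception
import dns.resolver
from huggingface_hub import InferenceClient

def _get_client(hostname: str, port: int = 80):
    # Fail fast if the TGI hostname doesn't resolve yet, instead of waiting
    # 10+ seconds for the default resolution timeout.
    try:
        dns.resolver.resolve(hostname, "A", lifetime=1.0)
    except dns.exception.DNSException:
        return None
    return InferenceClient(model=f"http://{hostname}:{port}")
```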
While I was trying to add some LLM tests, I found that this class didn't
seem to be set up to parse the model responses correctly. The original
prompt asks it to produce output in the format:

> {"step_text": "<model output>"}

and then ends the prompt with:

> {"step_text:": "

The Alpaca model we've been using does successfully complete this with
something like:

> one two three"}

but the problem is that the `extract_response` method assumes the
generated text will be a complete parseable JSON string, not a partial
one. This leads the method to return an empty dict, as if the model had
failed to generate any text, when it will normally have generated
something valid.

I've changed the prompt so it will output the text without any extra
formatting, removed `extract_response` because it's now not required,
and removed `process_response_text` because it seemed to be already
unused.
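Concretely, the old behaviour looked something like this (illustrative):

```python
import json

generated = 'one two three"}'   # what the model actually returns for the partial prompt

try:
    step_text = json.loads(generated)   # fails: not a complete JSON document
except json.JSONDecodeError:
    step_text = {}                      # old code treated this as "no text generated"

# With the reworked prompt the model returns the step text directly,
# so there's nothing to parse.
```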
If the input string was empty, these methods would throw an IndexError
when calling `re.search` (because of trying to do `""[-1]`)
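The guard is just an early return on empty input, roughly (hypothetical helper name and pattern):

```python
import re

def _last_char_matches(text: str, pattern: str = r"[.!?]") -> bool:
    # ""[-1] raises IndexError, so bail out early on empty input.
    if not text:
        return False
    return re.search(pattern, text[-1]) is not None
```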
The prompt for this type of request expects a JSON-compatible string
response with two fields. However, the prompt already includes the first
part of the expected string (`"{\"name\"`), and the parsing fails because
it's run only on the response, which is an incomplete JSON string.

This just adjusts the parsing method to add the section of the response
format included in the prompt to the start of the actual response,
allowing the parsing to function as intended.
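Roughly (the prefix is the one quoted above; the method name and fallback are illustrative):

```python
import json

PROMPT_PREFIX = '{"name"'  # the part of the JSON already included in the prompt

def parse_response(generated: str) -> dict:
    # The model only completes the JSON, so glue the prompt's prefix back on
    # before parsing.
    try:
        return json.loads(PROMPT_PREFIX + generated)
    except json.JSONDecodeError:
        return {}
```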
Keeping this in llm_functionalities allows the `Downloader` class to
continue working as normal, even though it's technically downloading
files for use by TGI
This helps make it clear which files are local and which have been
downloaded from huggingface.co
The connection will be attempted for the first time when the first LLM
request is received instead
@andrewramsay andrewramsay marked this pull request as ready for review May 23, 2024 16:29
@SophieFischer SophieFischer self-requested a review May 27, 2024 12:30