LLM fine-tuning #1350
Conversation
fn insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
    let id_value = Spi::get_one_with_args::<i64>(
        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
Did we include a migration for this table somewhere? We need to make sure it's created on all databases running PostgresML.
Yes, need to add the following three to our migration once we freeze on the version number.
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';
CREATE TABLE IF NOT EXISTS pgml.logs (
id SERIAL PRIMARY KEY,
model_id BIGINT,
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
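A minimal sketch of how these three statements could be applied idempotently from a migration script. The `apply_migrations` helper and the recording cursor are illustrative; PostgresML's actual migration mechanism may differ, and in production the cursor would come from a real database driver such as psycopg.

```python
MIGRATIONS = [
    "ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';",
    "ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';",
    """CREATE TABLE IF NOT EXISTS pgml.logs (
        id SERIAL PRIMARY KEY,
        model_id BIGINT,
        project_id BIGINT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        logs JSONB
    );""",
]

def apply_migrations(cursor, statements=MIGRATIONS):
    # Every statement is guarded by IF NOT EXISTS, so re-running the
    # migration against an already-upgraded database is safe.
    for sql in statements:
        cursor.execute(sql)

# Demo with a stand-in cursor that records what would be executed.
class _RecordingCursor:
    def __init__(self):
        self.executed = []
    def execute(self, sql):
        self.executed.append(sql)

cur = _RecordingCursor()
apply_migrations(cur)
```

Because each statement is self-guarding, no separate "has this migration run?" bookkeeping is needed for these three changes.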
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
maturin==1.4.0
Don't think you need maturin inside PostgresML deployments. This may be a "leak" from the pypgrx extension venv.
@@ -803,7 +803,7 @@ fn tune(
     project_name: &str,
     task: default!(Option<&str>, "NULL"),
     relation_name: default!(Option<&str>, "NULL"),
-    y_column_name: default!(Option<&str>, "NULL"),
+    _y_column_name: default!(Option<&str>, "NULL"),
Why the underscore? Is it because it's not used?
That's correct.
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from peft import LoraConfig, get_peft_model
from pypgrx import print_info, insert_logs
Need to make sure we import this conditionally (only for fine-tuning) and include it in requirements.linux.txt. I didn't see a macOS build for this, and for the M1/M2 architecture we've been doing releases manually from our Macs (GitHub Actions doesn't have M1 builders).
This makes me think we should start cross-compiling soon. Rust supports this pretty well; maturin may need a patch.
I couldn't get fine-tuning to work on macOS. It keeps crashing. How about I check for the operating system and bail out if it is macOS?
requirements.linux.txt is updated with trl and peft.
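The OS check proposed above could look something like this. The function name and error message are illustrative, not taken from the PR:

```python
import platform

def ensure_finetuning_supported(system=None):
    """Bail out early on macOS, where fine-tuning currently crashes."""
    # Allow the caller to inject a value for testing; default to the host OS.
    system = system or platform.system()
    if system == "Darwin":
        raise RuntimeError("LLM fine-tuning is not supported on macOS yet.")
    return True
```

Calling this at the top of the fine-tuning entry point turns a hard crash into a clear error message.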
logs["step"] = state.global_step
logs["max_steps"] = state.max_steps
logs["timestamp"] = str(datetime.now())
print_info(json.dumps(logs))
If you use print(), this will appear in Postgres logs. It won't be pretty, but we can add a function that formats it correctly.
I will add indent in json.dumps() to pretty print.
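For reference, the difference between the default one-line output and `json.dumps(..., indent=...)` that the reply proposes (the log values are made up):

```python
import json

logs = {"step": 10, "max_steps": 100, "loss": 1.23}

compact = json.dumps(logs)           # single line, hard to read in a log file
pretty = json.dumps(logs, indent=2)  # one key per line, nested and indented

print(pretty)
```

The indented form spreads each key onto its own line, which reads much better when tailing Postgres logs.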
trainable_model_params += param.numel()

# Calculate and print the number and percentage of trainable parameters
print_info(f"Trainable model parameters: {trainable_model_params}")
@kczimm This will require us to use the main thread for ML workloads in our cloud.
A PR with that is close. What's the reason we need main thread here?
We need logging visibility during fine tuning.
Thanks to a commit by @levkk, we should be able to log from any thread.
Force-pushed from 189e9f0 to 4bbca96.
#######################

class PGMLCallback(TrainerCallback):
I wouldn't be opposed to this functionality living in its own file like tune.py, since transformers is getting a bit beefy.
transformers.py is hardcoded in several places. Moving the fine-tuning code to tune.py needs more refactoring and testing. Will revisit this in the next iteration. #1378
        self.model_id = model_id

    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
Why throw away total_flos?
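The pattern in question is `dict.pop` with a default, which removes the key if present and does nothing (no `KeyError`) if it is absent:

```python
logs = {"loss": 0.5, "total_flos": 1.2e15}

_ = logs.pop("total_flos", None)  # removes the key and returns its value
_ = logs.pop("total_flos", None)  # key already gone: returns None, no KeyError
```

This makes the callback safe regardless of whether the trainer included `total_flos` in a given log event.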
}

#[pyfunction]
fn print_info(info: String) -> PyResult<String> {
I think this would be more reusable as log(level, msg)
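The proposed `log(level, msg)` interface could be sketched like this. This is a hypothetical Python-side sketch only; the real function would be a Rust `#[pyfunction]` forwarding to Postgres logging via pgrx, and the level names are assumptions:

```python
LEVELS = ("debug", "info", "warning", "error")

def log(level, msg):
    # Hypothetical sketch of the proposed API: one entry point for all
    # levels instead of a dedicated print_info() per level.
    if level not in LEVELS:
        raise ValueError(f"unknown log level: {level!r}")
    return f"[{level.upper()}] {msg}"

log("info", "Trainable model parameters: 1000000")
```

A single dispatcher like this keeps the Python/Rust boundary to one function even as more levels are needed.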
else:
    self.model_name = hyperparameters.pop("model_name")

if "token" in hyperparameters:
Isn't this a model init param, not a hyperparam, like many other things in this list? Maybe hyperparams covers everything?
That's correct. Moved all the parameters to hyperparams.
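The split the thread discusses, separating model-init parameters out of a single hyperparameters dict, can be done with the same `pop` pattern as the snippet above. A sketch; the key names and values here are illustrative:

```python
def split_params(hyperparameters):
    # Keys consumed at model construction time; everything left over is
    # treated as a training hyperparameter. The key list is illustrative.
    init_keys = ("model_name", "token")
    init_params = {
        k: hyperparameters.pop(k) for k in init_keys if k in hyperparameters
    }
    return init_params, hyperparameters

init, train = split_params(
    {"model_name": "meta-llama/Llama-2-7b-hf", "token": "hf_xxx", "learning_rate": 1e-4}
)
```

Keeping everything in one user-facing dict (as the reply settled on) and splitting it internally avoids a second top-level argument.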
y_train,
x_test,
y_test,
Ok::<std::option::Option<()>, i64>(Some(())) // this return type is nonsense
:)
let text1_column_value = dataset_args
    .0
    .get("text1_column")
do we require these column names?
Yes, for text pair classification (natural language inference, QNLI, etc.) we need three columns: text1, text2, and the class.
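A Python-side sketch of that required-column check, mirroring the Rust `dataset_args` lookups. The exact key names (especially `class_column`) and the error wording are assumptions:

```python
def get_text_pair_columns(dataset_args):
    # Text pair classification needs exactly these three column names.
    required = ("text1_column", "text2_column", "class_column")
    missing = [c for c in required if c not in dataset_args]
    if missing:
        raise ValueError(f"missing dataset args: {missing}")
    return tuple(dataset_args[c] for c in required)

cols = get_text_pair_columns(
    {"text1_column": "premise", "text2_column": "hypothesis", "class_column": "label"}
)
```

Failing fast with the full list of missing keys beats discovering one absent column at a time mid-training.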
let system_column_value = dataset_args
    .0
    .get("system_column")
How standard are these names these days?
For the conversation task, system, user, and assistant have become standard keys.
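One common layout using those roles is a list of role-keyed messages, as popularized by chat-model APIs (the message contents here are made up):

```python
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is PostgresML?"},
    {"role": "assistant", "content": "A Postgres extension that runs ML in the database."},
]

roles = [message["role"] for message in conversation]
```

Mapping dataset columns onto these three roles is what the `system_column` / `user_column`-style lookups above enable.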
Ok(info)
}

/// A Python module implemented in Rust.
#[pymodule]
Since this crate is interdependent, what if we moved this whole pymodule into the main pgml-extension crate, under bindings/python/mod.rs instead of publishing it as a separate crate?
pgml-extension/.gitignore
Outdated
@@ -14,3 +14,5 @@
 .DS_Store

+# venv
+pgml-venv
newline
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
new line
Example: https://github.com/postgresml/postgresml/tree/santi-llm-fine-tuning?tab=readme-ov-file#llm-fine-tuning
Refactored TextDataSet to handle different NLP tasks
Three tasks: text classification, text pair classification, conversation
PEFT/LoRA for conversation task
Pypgrx for callbacks to print info statements and insert logs into pgml.logs table
New tasks have to be added to pgml.task:
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';
A new pgml.logs table has to be added.
Note: Training is initialized using a previous run and model from HF Hub.