
LLM fine-tuning #1350

Merged 36 commits into master from santi-llm-fine-tuning on Mar 26, 2024
Conversation

@santiatpml (Contributor) commented Mar 4, 2024

  • Example: https://github.com/postgresml/postgresml/tree/santi-llm-fine-tuning?tab=readme-ov-file#llm-fine-tuning

  • Refactored TextDataSet to handle different NLP tasks

  • Three tasks: text classification, text pair classification, and conversation

  • PEFT/LoRA support for the conversation task

  • pypgrx for callbacks to print info statements and insert logs into the pgml.logs table

  • New tasks have to be added to the pgml.task enum:
    ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
    ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';

  • A new pgml.logs table has to be added:

CREATE TABLE pgml.logs (
    id SERIAL PRIMARY KEY,
    model_id BIGINT,
    project_id BIGINT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    logs JSONB
);
  • Text classification
SELECT pgml.tune(
    'financial_phrasebank_sentiment',
    task => 'text-classification',
    relation_name => 'pgml.financial_phrasebank_view',
    model_name => 'distilbert-base-uncased',
    test_size => 0.2,
    test_sampling => 'last',
    hyperparams => '{
        "training_args" : {
          "learning_rate": 2e-5,
          "per_device_train_batch_size": 16,
          "per_device_eval_batch_size": 16,
          "num_train_epochs": 10,
          "weight_decay": 0.01,
          "hub_token" : "token",
          "push_to_hub" : true
        },
        "dataset_args" : { 
          "text_column" : "sentence", 
          "class_column" : "class" 
        }
    }'
);
  • Text pair classification.
    Note: Training is initialized using a previous run and model from HF Hub.
SELECT pgml.tune(
    'glue_mrpc_nli_2',
    task => 'text_pair_classification',
    relation_name => 'pgml.glue_view',
    model_name => 'santiadavani/glue_mrpc_nli_2',
    test_size => 0.5,
    test_sampling => 'last',
    hyperparams => '{
        "training_args" : {
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 16,
            "num_train_epochs": 1,
            "weight_decay": 0.01
        },
        "dataset_args" : { "text1_column" : "sentence1", "text2_column" : "sentence2", "class_column" : "class" }
    }'
);
  • Conversation
SELECT pgml.tune(
    'alpaca-gpt4-conversation-llama2-7b-chat',
    task => 'conversation',
    relation_name => 'pgml.chat_sample',
    model_name => 'meta-llama/Llama-2-7b-chat-hf',
    test_size => 0.8,
    test_sampling => 'last',
    hyperparams => '{
        "training_args" : {
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 4,
            "per_device_eval_batch_size": 4,
            "num_train_epochs": 1,
            "weight_decay": 0.01,
            "hub_token" : "read_write_token", 
            "push_to_hub" : true,
            "optim" : "adamw_bnb_8bit",
            "gradient_accumulation_steps" : 4,
            "gradient_checkpointing" : true
        },
        "dataset_args" : { "system_column" : "instruction", "user_column" : "input", "assistant_column" : "output" },
        "lora_config" : {"r": 2, "lora_alpha" : 4, "lora_dropout" : 0.05, "bias": "none", "task_type": "CAUSAL_LM"},
        "load_in_8bit" : false,
        "token" : "read_token"
    }'
);
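
For reference, this is roughly how the "lora_config" hyperparams above feed into PEFT. This is a minimal sketch rather than the PR's actual code; the model name is simply the one from the example above.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Mirrors the "lora_config" hyperparams in the pgml.tune() call above.
lora_config = LoraConfig(
    r=2,                  # rank of the low-rank adapter matrices
    lora_alpha=4,         # scaling factor applied to the adapter output
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap the base model so only the LoRA adapter weights are trained.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()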

fn insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {

    let id_value = Spi::get_one_with_args::<i64>(
        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
Contributor:

Did we include a migration for this table somewhere? We need to make sure it's created on all databases running PostgresML.

santiatpml (author):

Yes, we need to add the following three statements to our migration once we freeze the version number.

ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';
CREATE TABLE IF NOT EXISTS pgml.logs (
    id SERIAL PRIMARY KEY,
    model_id BIGINT,
    project_id BIGINT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    logs JSONB
);

MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
maturin==1.4.0
Contributor:

Don't think you need maturin inside PostgresML deployments. This may be a "leak" from the pypgrx extension venv.

@@ -803,7 +803,7 @@ fn tune(
project_name: &str,
task: default!(Option<&str>, "NULL"),
relation_name: default!(Option<&str>, "NULL"),
y_column_name: default!(Option<&str>, "NULL"),
_y_column_name: default!(Option<&str>, "NULL"),
Contributor:

Why the underscore? Is it because it's not used?

santiatpml (author):

That's correct.

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from peft import LoraConfig, get_peft_model
from pypgrx import print_info, insert_logs
Contributor:

We need to make sure we import this conditionally (only for fine-tuning) and include it in requirements.linux.txt. I didn't see a macOS build for this, and for the M1/M2 architecture we've been doing releases manually from our Macs (GitHub Actions doesn't have M1 builders).

This makes me think we should start cross-compiling soon. Rust supports this pretty well; maturin may need a patch.

santiatpml (author):

I couldn't get fine-tuning to work on macOS; it keeps crashing. How about I check for the operating system and bail out if it is macOS?
requirements.linux.txt is updated with trl and peft.
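
As a sketch of that check (illustrative only, not necessarily the PR's final code), the guard plus the conditional import could look like:

import platform

def ensure_finetuning_supported():
    # Fine-tuning is Linux-only for now; bail out early on macOS.
    if platform.system() == "Darwin":
        raise RuntimeError("LLM fine-tuning is not supported on macOS deployments.")

def load_finetuning_deps():
    # Import trl/peft only when a fine-tuning task is actually requested.
    ensure_finetuning_supported()
    from trl import SFTTrainer
    from peft import LoraConfig, get_peft_model
    return SFTTrainer, LoraConfig, get_peft_model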

logs["step"] = state.global_step
logs["max_steps"] = state.max_steps
logs["timestamp"] = str(datetime.now())
print_info(json.dumps(logs))
Contributor:

If you use print(), this will appear in the Postgres logs. It won't be pretty, but we can add a function that formats it correctly.

santiatpml (author):

I will add an indent in json.dumps() to pretty-print it.
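
Putting the pieces quoted in this review together, the callback plausibly ends up looking something like this. This is a sketch based on the snippets in this PR; the indent value is illustrative.

import json
from datetime import datetime
from transformers import TrainerCallback
from pypgrx import print_info, insert_logs  # the pypgrx bindings added in this PR

class PGMLCallback(TrainerCallback):
    def __init__(self, project_id, model_id):
        self.project_id = project_id
        self.model_id = model_id

    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = dict(logs or {})
        logs.pop("total_flos", None)  # dropped in the PR; see the question further down
        logs["step"] = state.global_step
        logs["max_steps"] = state.max_steps
        logs["timestamp"] = str(datetime.now())
        print_info(json.dumps(logs, indent=2))                          # pretty-printed in Postgres logs
        insert_logs(self.project_id, self.model_id, json.dumps(logs))   # persisted to pgml.logs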

trainable_model_params += param.numel()

# Calculate and print the number and percentage of trainable parameters
print_info(f"Trainable model parameters: {trainable_model_params}")
Contributor:

@kczimm This will require us to use the main thread for ML workloads in our cloud.

Contributor:

A PR with that is close. What's the reason we need main thread here?

Contributor:

We need logging visibility during fine tuning.

Contributor:

Thanks to a commit by @levkk, we should be able to log from any thread.

#######################


class PGMLCallback(TrainerCallback):
Contributor:

I wouldn't be opposed to this functionality living in its own file like tune.py, since transformers.py is getting a bit beefy.

santiatpml (author):

transformers.py is hardcoded in several places. It needs some more refactoring and testing to move the fine-tuning code to tune.py. Will revisit this in the next iteration. #1378

self.model_id = model_id

def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
Contributor:

Why throw away total_flos?

}

#[pyfunction]
fn print_info(info: String) -> PyResult<String> {
Contributor:

I think this would be more reusable as log(level, msg)

else:
self.model_name = hyperparameters.pop("model_name")

if "token" in hyperparameters:
Contributor:

Isn't this a model init param, not a hyperparam, like many other things in this list? Maybe hyperparams covers everything?

santiatpml (author):

That's correct. Moved all the parameters to hyperparams.
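
As a small sketch of that split (with illustrative keys taken from the examples above), non-training settings get popped out of hyperparams before TrainingArguments is built:

# Everything arrives in one "hyperparams" JSON blob from pgml.tune().
hyperparams = {
    "training_args": {"learning_rate": 2e-5, "num_train_epochs": 1},
    "token": "read_token",       # Hub auth for model init, not a training hyperparameter
    "load_in_8bit": False,       # model loading option
}

token = hyperparams.pop("token", None)
load_in_8bit = hyperparams.pop("load_in_8bit", False)
training_args = hyperparams.pop("training_args", {})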


y_train,
x_test,
y_test,
Ok::<std::option::Option<()>, i64>(Some(())) // this return type is nonsense
Contributor:

:)


let text1_column_value = dataset_args
.0
.get("text1_column")
Contributor:

Do we require these column names?

santiatpml (author):

Yes. For text pair classification (natural language inference, QNLI, etc.), we need three columns: text1, text2, and the class.
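
A minimal sketch of why all three are needed (assumed code, with the tokenizer name only for illustration): the two sentences are encoded together as one pair and the class becomes the label.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def encode_pair(row):
    # Column names match the dataset_args in the glue_mrpc example above.
    encoding = tokenizer(row["sentence1"], row["sentence2"], truncation=True)
    encoding["label"] = row["class"]
    return encoding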


let system_column_value = dataset_args
.0
.get("system_column")
Contributor:

How standard are these names these days?

santiatpml (author):

For the conversation task, system, user, and assistant have become standard keys.
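
A sketch of how the column mapping above turns a row into a conversation (assumed code; the chat-template call is shown as a comment since the exact formatting depends on the tokenizer):

def row_to_messages(row):
    # Column names come from the dataset_args in the conversation example above.
    return [
        {"role": "system", "content": row["instruction"]},
        {"role": "user", "content": row["input"]},
        {"role": "assistant", "content": row["output"]},
    ]

# With a chat-capable tokenizer, the messages can then be flattened into training text:
# text = tokenizer.apply_chat_template(row_to_messages(row), tokenize=False)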

Ok(info)
}
/// A Python module implemented in Rust.
#[pymodule]
Contributor:

Since this crate is interdependent, what if we moved this whole pymodule into the main pgml-extension crate, under bindings/python/mod.rs instead of publishing it as a separate crate?

@@ -14,3 +14,5 @@
.DS_Store


# venv
pgml-venv
Contributor:

newline

@santiatpml requested a review from levkk on March 26, 2024 at 19:54
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
Contributor:

new line

@santiatpml merged commit f75114b into master on Mar 26, 2024
1 check passed
@santiatpml deleted the santi-llm-fine-tuning branch on March 26, 2024 at 20:31