[Finetune] replace fine-tuning DefaultTrainer with transformers.Trainer #204

Merged
merged 19 commits into intel:main from the replace-trainer branch, May 13, 2024

Conversation

@harborn (Contributor) commented Apr 25, 2024

Refactoring: replace the fine-tuning DefaultTrainer with transformers.Trainer.

This update will:

  1. disable DefaultTrainer, which contains only a very small subset of training functionality.
  2. enable full argument support for the different training tasks.

@harborn changed the title from "[finetune] replace fine-tuning DefaultTrainer with transformers.Trainer" to "[Finetune] replace fine-tuning DefaultTrainer with transformers.Trainer" on Apr 25, 2024
try:
    common.logger.info("trainer prepare start")
    model.training = True
    trainer.prepare(model, tokenizer, datasets, optimizer, accelerator)
Collaborator:

Does this change remove the prepare function in default_trainer.py?

Comment on lines -205 to -210
if accelerate_mode == "FSDP":
    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(
            offload_to_cpu=False, rank0_only=False
        ),
Collaborator:

How does transformers.Trainer distinguish between FSDP and DeepSpeed?

Contributor Author:

For FSDP training, we should set the corresponding TrainingArguments options: point fsdp_config to a JSON config file (fsdp_config.json) and enable auto_wrap through the fsdp option.
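For illustration, a minimal sketch of passing these options through transformers.TrainingArguments; the "full_shard" strategy and the fsdp_config.json path are assumed values, not the PR's exact settings:

from transformers import TrainingArguments

# Minimal FSDP setup via TrainingArguments (illustrative values).
training_args = TrainingArguments(
    output_dir="./output",
    fsdp="full_shard auto_wrap",      # enable FSDP with automatic module wrapping
    fsdp_config="fsdp_config.json",   # JSON file holding the remaining FSDP options
)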

Comment on lines 130 to 141
args = {
    "output_dir": config["General"]["output_dir"],
    "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"],
    "save_strategy": save_strategy,
    "bf16": config["Training"]["mixed_precision"] == "bf16",
    "num_train_epochs": config["Training"]["epochs"],
    "per_device_train_batch_size": config["Training"]["batch_size"],
    "per_device_eval_batch_size": config["Training"]["batch_size"],
    "learning_rate": config["Training"]["learning_rate"],
    "logging_steps": config["Training"]["logging_steps"],
    "lr_scheduler_type": config["Training"]["lr_scheduler"],
    "weight_decay": config["Training"]["weight_decay"],
    "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"],
}
Contributor:

Can you add the max_train_steps parameter? Otherwise, the UI will not be able to demo a fine-tuning task in a short time.

if max_train_step != 0:
    finetune_config["Training"]["max_train_steps"] = max_train_step

In addition, can other lr_scheduler parameters such as num_warmup_steps be supported?

Contributor Author:

Updated. In our yaml file the option is max_train_steps, while for TrainingArguments the option is max_steps.
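A minimal sketch of that mapping, building on the args dict from the diff above (illustrative, not the PR's exact code):

# Map the yaml option "max_train_steps" onto TrainingArguments' "max_steps".
# TrainingArguments treats -1 as "no explicit step limit".
max_train_steps = config["Training"].get("max_train_steps", 0)
args["max_steps"] = max_train_steps if max_train_steps else -1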

Comment on lines 342 to 339
# accelerate_env_vars = get_accelerate_environment_variable(config)
# runtime_env["env_vars"].update(accelerate_env_vars)
Contributor:
Is get_accelerate_environment_variable no longer needed? Let us remove this function.

Contributor Author:

Yes, it is no longer needed; removed.

@@ -176,13 +173,22 @@ def group_texts(examples):
            desc=f"Grouping texts in chunks of {block_size}",
        )

        return tokenized_datasets

    def convert_dataset(self, tokenizer, dataset):
Contributor:

general_processer is just generating a dataloader; it may be better to call it 'prepare' or 'prepare_dataloader'. Please also align the function names in other files, such as the pretrain modules.

Contributor Author:

Updated!

train_dataloader, eval_dataloader = self.dataprocesser.prepare(tokenizer, dataset)
train_dataloader, eval_dataloader = self.dataprocesser.convert_dataset(tokenizer, dataset)
Contributor:
If we no longer use default_trainer, should this file be removed?

Contributor Author:

default_trainer.py may still be used later. If we are sure the file can be removed, I will delete it in a later change.

Comment on lines +253 to 260
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
Contributor:

How is resuming fine-tuning from a checkpoint supported?

Contributor Author:

Checkpointing and saving of model results are controlled by arguments in training_args; the relevant options include save_only_model, save_strategy, save_steps, and output_dir. Together, these options cover the different needs for saving and loading models and checkpoints. For more details, see https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/trainer#transformers.TrainingArguments
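For illustration only, a minimal sketch (not the PR's code) of how these options combine with Trainer.train's resume_from_checkpoint argument:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",        # checkpoints are written under this directory
    save_strategy="steps",        # save a checkpoint every save_steps steps
    save_steps=500,
    save_only_model=False,        # also keep optimizer/scheduler state so training can resume
)

# Later, resume from the latest checkpoint found in output_dir
# (a specific checkpoint path can be passed instead of True).
trainer.train(resume_from_checkpoint=True)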

Comment on lines -196 to -201
optimizer = common.optimizer.Optimizer.registory.get("DefaultOptimizer")()(
    model,
    config={
        "name": config["Training"]["optimizer"],
        "config": {"lr": config["Training"]["learning_rate"]},
    },
Contributor:
Why remove optimizer?

Contributor Author:

The optimizer will be created inside transformers.Trainer or optimum.habana.transformers.GaudiTrainer.
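A minimal sketch of what this relies on (the variable names here are placeholders, not the PR's code): when the optimizers argument is left at its default, Trainer builds the optimizer and scheduler itself.

from transformers import Trainer

# optimizers defaults to (None, None); in that case Trainer calls its own
# create_optimizer()/create_scheduler() before the training loop, using
# learning_rate, weight_decay, etc. from TrainingArguments.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers=(None, None),
)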

Comment on lines +253 to +259
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
Contributor:

I think we should pass an optimizer into the Trainer according to the user's configuration.

Contributor Author:

The optimizer's parameters are part of TrainingArguments, and the optimizer is created before the epoch/step loop runs in Trainer.train().
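For reference, a sketch of the optimizer-related knobs that TrainingArguments exposes (the values here are illustrative, not the PR's defaults):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    optim="adamw_torch",      # which optimizer implementation Trainer should build
    learning_rate=5e-5,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
)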

Contributor:

But I don’t see the optimizer configured in TrainingArguments. How does the user set optimizer parameters?

args.update({"use_lazy_mode": config["Training"]["hpu_execution_mode"] == "lazy"})
args.update({"pipelining_fwd_bwd": True})
args.update({"throughput_warmup_steps": 3})
args.update({"adam_epsilon": 1e-8})
Contributor:

Are the above three configs hard-coded? Are these values also used in our previous implementation?

Contributor Author:

I removed these hard-coded values. I am considering whether those options should be added to our yaml config file.
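If those options were added to the yaml file, the mapping could look like the following sketch; the key names and defaults are assumptions, not part of this PR:

# Hypothetical yaml keys, with the previously hard-coded values as defaults.
args.update({
    "pipelining_fwd_bwd": config["Training"].get("pipelining_fwd_bwd", True),
    "throughput_warmup_steps": config["Training"].get("throughput_warmup_steps", 3),
    "adam_epsilon": config["Training"].get("adam_epsilon", 1e-8),
})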

def convert_to_training_args(cls, config):
    device = config["Training"]["device"]
    accelerate_mode = config["Training"]["accelerate_mode"]
    checkpoint_dir = config["General"]["checkpoint_dir"]
Contributor:
It seems this is not set.

Contributor Author (@harborn, May 8, 2024):

This option in our yaml config file is used for saving checkpoint files, but with Trainer the checkpoint files are saved to output_dir, so checkpoint_dir seems meaningless. I will change this option to save_strategy, which controls how checkpoint files are saved.
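A small sketch of the resulting mapping; which yaml section holds save_strategy is an assumption here, not confirmed by the PR:

# Read save_strategy from the yaml config ("no", "epoch", or "steps"); Trainer
# writes both checkpoints and the final model under output_dir.
save_strategy = config["General"].get("save_strategy", "no")
args["save_strategy"] = save_strategy
args["output_dir"] = config["General"]["output_dir"]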

"logging_steps": config["Training"]["logging_steps"],
"lr_scheduler_type": config["Training"]["lr_scheduler"],
"weight_decay": config["Training"]["weight_decay"],
"gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"],
Contributor:
Do we have default values for all these configurations in config? Previously we wrote config["Training"].get("gradient_accumulation_steps", 1)

Contributor Author:

Yes, this option has a default value of 1 in finetune_config.py.
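As a hypothetical illustration of such a default (the actual finetune_config.py may be structured differently):

from pydantic import BaseModel

# Hypothetical config model: gradient_accumulation_steps defaults to 1,
# so the yaml file may omit it.
class Training(BaseModel):
    gradient_accumulation_steps: int = 1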

@harborn force-pushed the replace-trainer branch 3 times, most recently from 55e7968 to 4adf98a, on May 10, 2024 02:22
Comment on lines 334 to 336
"CCL_ZE_IPC_EXCHANGE": "sockets",
"CCL_WORKER_COUNT": str(ccl_worker_count),
"CCL_LOG_LEVEL": "info",
Contributor:

Why are these CCL configurations no longer needed?

@KepingYan (Contributor):
LGTM

@harborn merged commit 3523011 into intel:main on May 13, 2024
25 checks passed