import dataclasses
import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

import datasets
import torch
import torch.nn as nn
from accelerate.state import PartialState
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollator,
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from transformers.trainer import _is_peft_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction, seed_worker
from transformers.utils import is_datasets_available
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.import_utils import is_peft_available
from trl.trainer.utils import (
    ConstantLengthDataset,
    DataCollatorForCompletionOnlyLM,
    RichProgressCallback,
    neftune_post_forward_hook,
    peft_module_casting_to_bf16,
    trl_sanitze_kwargs_for_tagging,
)

if is_peft_available():
    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

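
# Illustrative only: a minimal `formatting_func` of the kind `CustomTrainer` accepts. The prompt
# template below is an assumption, not something defined in this module. Because
# `_prepare_non_packed_dataloader` maps with `batched=False` and `input_columns=["text"]`, the
# function receives the raw value of the "text" column for a single example and must return a
# *list* of formatted strings, as the sanity check in that method requires.
def example_formatting_func(text):
    return [f"### Input:\n{text}\n\n### Response:\n"]
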
class CustomTrainer(Trainer):
    # Tags attached to the model card via `add_model_tags` below; mirrors the `_tag_names`
    # attribute defined on TRL's `SFTTrainer`, which this class is adapted from.
    _tag_names = ["trl", "sft"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[TrainingArguments] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: int = 1000,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
        eval_packing: Optional[bool] = None,
    ):
        if model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError(
                "You passed `model_init_kwargs` to the CustomTrainer, but your model is already instantiated."
            )

        if isinstance(model, str):
            # `model` is a model id, so honour `model_init_kwargs` when building the model.
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
            raise ValueError(
                "You passed a `DataCollatorForCompletionOnlyLM` to the CustomTrainer. This is not compatible with the `packing` argument."
            )

        if is_peft_available() and peft_config is not None:
            if not isinstance(peft_config, PeftConfig):
                raise ValueError(
                    "If you want to use the PeftModel, you need to pass a PeftConfig object to the "
                    f"CustomTrainer, but you passed a {type(peft_config)}."
                )

            if not isinstance(model, PeftModel):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )
                gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
                is_sharded_qlora = False

                # Detect sharded QLoRA (4-bit weights offloaded to CPU); such models must not be
                # passed through `prepare_model_for_kbit_training` again.
                if getattr(model, "is_loaded_in_4bit", False):
                    for _, param in model.named_parameters():
                        if param.__class__.__name__ == "Params4bit":
                            is_sharded_qlora = param.data.device.type == "cpu"
                            break
                if getattr(model, "is_loaded_in_8bit", False) or (
                    getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
                ):
                    prepare_model_kwargs = {
                        "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
                    }

                    if _support_gc_kwargs:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                    if args is not None:
                        args = dataclasses.replace(args, gradient_checkpointing=False)
                elif getattr(args, "gradient_checkpointing", False) and (
                    "use_reentrant" not in gradient_checkpointing_kwargs
                    or gradient_checkpointing_kwargs["use_reentrant"]
                ):
                    # With reentrant gradient checkpointing, the input embeddings must require
                    # grads for backprop to reach the adapter layers.
                    if hasattr(model, "enable_input_require_grads"):
                        model.enable_input_require_grads()
                    else:

                        def make_inputs_require_grad(module, input, output):
                            output.requires_grad_(True)

                        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

                model = get_peft_model(model, peft_config)
                if (
                    args is not None
                    and args.bf16
                    and getattr(model, "is_loaded_in_4bit", False)
                    and not is_sharded_qlora
                ):
                    peft_module_casting_to_bf16(model)

        if tokenizer is None:
            raise ValueError("Please provide a tokenizer; the CustomTrainer does not create one automatically.")
        if getattr(tokenizer, "pad_token", None) is None:
            # Fall back to the EOS token so the language-modeling collator can pad batches.
            tokenizer.pad_token = tokenizer.eos_token

        if max_seq_length is None:
            # Cap the default so tokenizers with a huge `model_max_length` placeholder do not blow up memory.
            max_seq_length = min(tokenizer.model_max_length, 1024)
            warnings.warn(
                f"You didn't pass a `max_seq_length` argument to the CustomTrainer, so it defaults to {max_seq_length}."
            )

        self.dataset_num_proc = dataset_num_proc
        self.dataset_batch_size = dataset_batch_size

        self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")

        if neftune_noise_alpha is not None and self._trainer_supports_neftune:
            args.neftune_noise_alpha = neftune_noise_alpha
            warnings.warn(
                "You passed a `neftune_noise_alpha` argument to the CustomTrainer; it will override the value set in `TrainingArguments`."
            )
        elif not self._trainer_supports_neftune:
            self.neftune_noise_alpha = neftune_noise_alpha

        if formatting_func is None and dataset_text_field is None:
            # Try to infer a formatting function from known dataset schemas (e.g. ChatML);
            # returns None if the dataset format is not recognised.
            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

        if not packing:
            if dataset_text_field is None and formatting_func is None:
                raise ValueError(
                    "You passed `packing=False` to the CustomTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
                )

            if data_collator is None:
                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # Pre-process the datasets on the local main process first so the other processes reuse the cache.
        with PartialState().local_main_process_first():
            if dataset_kwargs is None:
                dataset_kwargs = {}
            if train_dataset is not None:
                train_dataset = self._prepare_dataset(
                    train_dataset,
                    tokenizer,
                    packing,
                    dataset_text_field,
                    max_seq_length,
                    formatting_func,
                    num_of_sequences,
                    chars_per_token,
                    remove_unused_columns=args.remove_unused_columns if args is not None else True,
                    **dataset_kwargs,
                )
            if eval_dataset is not None:
                _multiple = isinstance(eval_dataset, dict)
                _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}

                eval_packing = packing if eval_packing is None else eval_packing

                for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
                    _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
                        _eval_dataset,
                        tokenizer,
                        eval_packing,
                        dataset_text_field,
                        max_seq_length,
                        formatting_func,
                        num_of_sequences,
                        chars_per_token,
                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
                        **dataset_kwargs,
                    )
                if not _multiple:
                    eval_dataset = _eval_datasets["singleton"]

        if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
            warnings.warn(
                "You passed a tokenizer with `padding_side` not equal to `right` to the CustomTrainer. This can lead to unexpected behaviour due to "
                "overflow issues when training a model in half precision. Consider setting `tokenizer.padding_side = 'right'` before training."
            )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Tag the model for the Hub model card when the model supports it.
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if self.args.max_steps > 0 and packing:
            warnings.warn(
                "You passed `packing=True` to the CustomTrainer and are training with the `max_steps` strategy. The dataset will be iterated until `max_steps` is reached."
            )
            self.train_dataset.infinite = True
        elif self.args.max_steps == -1 and packing:
            self.train_dataset.infinite = False

        if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks):
            for callback in self.callback_handler.callbacks:
                # Remove the default PrinterCallback so its output does not interleave with the rich progress bar.
                if callback.__class__.__name__ == "PrinterCallback":
                    self.callback_handler.pop_callback(callback)

    def _prepare_dataset(
        self,
        dataset,
        tokenizer,
        packing,
        dataset_text_field,
        max_seq_length,
        formatting_func,
        num_of_sequences,
        chars_per_token,
        remove_unused_columns=True,
        append_concat_token=True,
        add_special_tokens=True,
    ):
        if dataset is None:
            raise ValueError("The dataset should not be None")

        # If the dataset is already a torch-style dataset (including a ConstantLengthDataset),
        # pass it through untouched.
        if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
            return dataset

        # This trainer always prepares a non-packed, tokenized dataset; `packing`, `num_of_sequences`,
        # `chars_per_token`, and `append_concat_token` are accepted for signature compatibility but
        # are not used here.
        return self._prepare_non_packed_dataloader(
            tokenizer,
            dataset,
            dataset_text_field,
            max_seq_length,
            formatting_func,
            add_special_tokens,
            remove_unused_columns,
        )

    def _prepare_non_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        # Tokenize one example at a time. Because the map call below uses `batched=False` with
        # `input_columns=["text"]`, `element` is the raw value of the "text" column for a single example.
        def tokenize(element):
            outputs = tokenizer(
                element if not use_formatting_func else formatting_func(element),
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )

            if use_formatting_func and not self._dataset_sanity_checked:
                if not isinstance(formatting_func(element), list):
                    raise ValueError(
                        "The `formatting_func` should return a list of processed strings; returning anything else can lead to silent bugs."
                    )
                else:
                    self._dataset_sanity_checked = True

            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

        signature_columns = ["input_ids", "labels", "attention_mask"]

        extra_columns = list(set(dataset.column_names) - set(signature_columns))

        if not remove_unused_columns and len(extra_columns) > 0:
            warnings.warn(
                "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and lead to errors. If you want to "
                f"inspect the other dataset columns (in this case {extra_columns}), subclass `DataCollatorForLanguageModeling` (if you use the default collator) or write your own data collator that handles the unused columns."
            )

        # Note: the column name "text" is hard-coded here, so the dataset must contain a "text"
        # column regardless of `dataset_text_field`.
        tokenized_dataset = dataset.map(
            tokenize,
            batched=False,
            remove_columns=["text"],
            num_proc=self.dataset_num_proc,
            batch_size=self.dataset_batch_size,
            input_columns=["text"],
        )

        return tokenized_dataset

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. It must implement `__len__`. Note that, unlike the
                base `Trainer`, this override does not automatically drop columns rejected by `model.forward()`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # Reuse the cached dataloader when persistent workers are enabled, to avoid re-spawning workers.
        if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
            return self.accelerator.prepare(self._eval_dataloader)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            self._eval_dataloader = eval_dataloader

        return self.accelerator.prepare(eval_dataloader)

    def compute_loss(self, model, inputs, return_outputs=False):
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)

        # Save past state if it exists (used by models with a `past`/cache mechanism).
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # Causal LMs need the labels shifted by one position before label smoothing.
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # The model may return a tuple instead of a ModelOutput, so index defensively.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
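

# The block below is an illustrative usage sketch, not part of the trainer itself. The model id,
# dataset contents, training arguments, and LoRA hyperparameters are placeholder assumptions;
# swap in your own. It assumes a dataset with a "text" column, which is what
# `_prepare_non_packed_dataloader` expects.
if __name__ == "__main__":
    from peft import LoraConfig

    model_name = "facebook/opt-125m"  # placeholder model id
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Tiny in-memory dataset with the required "text" column.
    train_dataset = Dataset.from_dict({"text": ["Hello world.", "Supervised fine-tuning example."]})

    trainer = CustomTrainer(
        model=model,
        args=TrainingArguments(output_dir="./custom-trainer-demo", per_device_train_batch_size=1, max_steps=5),
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        dataset_text_field="text",
        max_seq_length=512,
        peft_config=LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM"),
    )
    trainer.train()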