Hey,
I'm fine-tuning models with Hugging Face Transformers, using a few custom callbacks to measure loss and run MCQ evals. Training is fast, but evaluation is slow, and I'm wondering whether there are any obvious tricks I've left on the table. Here's my code:
# Load dataset
with log_step("Load dataset"):
logging.info(f"data_file={args.data_file}")
ds_splits = load_and_prepare_datasets(args.data_file, eval_split_percentage=args.eval_split_percentage, seed=args.seed)
logging.info(f"splits={{ {', '.join(f'{k}: {len(v)}' for k, v in ds_splits.items())} }}")
# Load tokenizer
with log_step("Load tokenizer"):
logging.info(f"model_name_or_path={args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
# Some tokenizers don't have a pad token -- fall back to eos_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Tokenize dataset
logging.info("Tokenizing dataset...")
tokenized_splits = {}
for split_name, ds in ds_splits.items():
with log_step(f"Tokenize [{split_name}]"):
tokenized = ds.map(
lambda examples: tokenize_function(examples, tokenizer),
batched=True,
# Remove original columns like 'text' so downstream only
# sees tokenized fields (e.g., input_ids, attention_mask).
# Keeping 'text' here leads to TypeError when concatenating.
remove_columns=ds.column_names,
desc=f"Tokenizing {split_name}",
)
logging.info(f"{split_name}: {len(tokenized)} examples")
tokenized_splits[split_name] = tokenized
lm_splits = tokenized_splits
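# Illustrative note (column names assumed, not checked against this dataset): because
# remove_columns=ds.column_names drops the raw fields, each tokenized split exposes only
# the tokenizer outputs, e.g.
#   lm_splits["train"].column_names  # -> ["input_ids", "attention_mask"]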
# Load multiple evaluation datasets if provided
with log_step("Load extra evaluation datasets"):
extra_eval_raw = load_eval_datasets(args.eval_data_files)
extra_eval_tokenized = {}
# Tokenize each eval dataset
for name, ds in extra_eval_raw.items():
with log_step(f"Tokenize extra eval set [{name}]"):
tok = ds.map(
lambda ex: tokenize_function(ex, tokenizer),
batched=True,
remove_columns=ds.column_names,
desc=f"Tokenizing {name}",
)
extra_eval_tokenized[name] = tok
# Load multiple MCQ evaluation datasets if provided
with log_step("Load extra MCQ evaluation datasets"):
if args.mcq_eval_data_files:
mcq_eval_sets = load_mcq_eval_datasets(args.mcq_eval_data_files)
# Optionally limit samples (for quick tests)
if args.max_train_samples:
lm_splits["train"] = lm_splits["train"].select(range(min(len(lm_splits["train"]), args.max_train_samples)))
if args.max_eval_samples and "validation" in lm_splits:
lm_splits["validation"] = lm_splits["validation"].select(range(min(len(lm_splits["validation"]), args.max_eval_samples)))
# Load model
with log_step("Load model"):
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
# attn_implementation='flash_attention_2'
)
# Resize tokenizer embeddings if we added tokens
# model.resize_token_embeddings(len(tokenizer))
try:
n_params = sum(p.numel() for p in model.parameters())
logging.info(f"Model parameters: {n_params:,}")
except Exception:
pass
# Let W&B watch the model (if enabled)
if getattr(args, "report_to", "none") == "wandb":
try:
import wandb # type: ignore
wandb.watch(model, log="all", log_freq=100)
except Exception:
pass
# Data collator
# For causal LM, DataCollatorForLanguageModeling with mlm=False returns labels=input_ids
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
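# Illustrative sketch (values assumed, not from this run): with mlm=False the collator copies
# input_ids into labels and masks padded positions with -100, e.g.
#   batch = data_collator([{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}])
#   batch["labels"]  # tensor([[5, 6, 7], [8, 9, -100]]) assuming right padding
# Note: since pad_token was set to eos_token above, EOS positions are masked the same way.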
# Training args
training_args = TrainingArguments(
output_dir=args.output_dir,
overwrite_output_dir=True,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
eval_strategy="steps" if "validation" in lm_splits else "no",
eval_steps=500 if "validation" in lm_splits else None,
save_strategy="no" if (getattr(args, "disable_checkpoints", False) or getattr(args, "disable_checkpoints_including_last", False)) else "epoch",
save_only_model=bool(getattr(args, "save_only_model", False)),
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_steps=args.warmup_steps,
lr_scheduler_type=args.lr_scheduler_type,
# Log loss after every optimizer step
logging_strategy="steps",
logging_steps=1,
logging_first_step=True,
logging_dir=os.path.join(args.output_dir, "logs"),
bf16=use_bf16,
run_name=args.wandb_run_name if getattr(args, "wandb_run_name", None) else None,
dataloader_num_workers=4,
push_to_hub=args.push_to_hub,
seed=args.seed,
report_to=args.report_to,
)
# Trainer
logging.info("Preparing Trainer...")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=lm_splits["train"],
eval_dataset=lm_splits.get("validation", None),
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=None,
)
class InitialEvalCallback(TrainerCallback):
def __init__(
self,
trainer,
eval_train: bool = True,
eval_validation: bool = True,
extra_eval_sets: Dict[str, Dataset] = None,
):
self.trainer = trainer
self.eval_train = eval_train
self.eval_validation = eval_validation
self.extra_eval_sets = extra_eval_sets or {}
self._done = False
def on_train_begin(self, args, state, control, **kwargs):
if self._done:
return control
try:
logging.info("[Callback] Computing initial metrics at step 0…")
# ---- TRAIN SET ----
if self.eval_train and self.trainer.train_dataset is not None:
init_train = self.trainer.evaluate(
eval_dataset=self.trainer.train_dataset,
metric_key_prefix="initial_train",
)
self.trainer.log_metrics("initial_train", init_train)
self.trainer.save_metrics("initial_train", init_train)
if "initial_train_loss" in init_train:
logging.info(f"Initial train loss: {init_train['initial_train_loss']:.4f}")
try:
self.trainer.log({"loss": float(init_train["initial_train_loss"])})
except Exception:
pass
# ---- VALIDATION SET ----
if self.eval_validation and self.trainer.eval_dataset is not None:
init_val = self.trainer.evaluate(
eval_dataset=self.trainer.eval_dataset,
metric_key_prefix="initial_validation",
)
self.trainer.log_metrics("initial_validation", init_val)
self.trainer.save_metrics("initial_validation", init_val)
if "initial_validation_loss" in init_val:
logging.info(f"Initial validation loss: {init_val['initial_validation_loss']:.4f}")
try:
self.trainer.log({"eval_loss": float(init_val["initial_validation_loss"])})
except Exception:
pass
# ---- EXTRA EVAL DATASETS ----
for name, dataset in self.extra_eval_sets.items():
logging.info(f"[Callback] Evaluating extra dataset '{name}' at step 0…")
# Use `name` directly as the metric prefix (not f"initial_{name}") so keys match later epochs
metrics = self.trainer.evaluate(
eval_dataset=dataset,
metric_key_prefix=name,
)
# Log to disk (keep as is)
self.trainer.log_metrics(name, metrics)
self.trainer.save_metrics(name, metrics)
# Explicitly log through the trainer so W&B receives keys like "nyt_loss" at step 0
self.trainer.log(metrics)
except Exception as e:
logging.warning(f"[Callback] Could not compute/log initial metrics: {e}")
finally:
self._done = True
return control
class EvalAggregateLossCallback(TrainerCallback):
"""
At end of each epoch, run Trainer.evaluate() on each extra dataset
and log the aggregated loss.
"""
def __init__(self, trainer, eval_sets):
self.trainer = trainer
self.eval_sets = eval_sets
def on_epoch_end(self, args, state, control, **kwargs):
for name, dataset in self.eval_sets.items():
metrics = self.trainer.evaluate(
eval_dataset=dataset,
metric_key_prefix=name
)
# Log through trainer (handles W&B, JSON logs, etc.)
self.trainer.log_metrics(name, metrics)
self.trainer.save_metrics(name, metrics)
# Pretty print
loss_key = f"{name}_loss"
if loss_key in metrics:
logging.info(f"[Epoch {state.epoch}] {name} loss = {metrics[loss_key]:.4f}")
return control
class MCQEvalCallback(TrainerCallback):
"""
At end of each epoch, evaluate MCQ datasets (raw JSONL datasets that have
'query' and 'choices' fields). For each row, compute log-likelihood of
each choice given the prefix (query). Writes a JSONL file per dataset
containing original fields plus "choice_logliks": [float,...] and
"choice_token_details": [[{token_id: int, token_str: str, logprob: float}, ...], ...].
Efficient design:
- Uses the Trainer's model in-place (no copying).
- Batches all choices for the same query into a single forward pass (padded).
- Runs inference with torch.no_grad and model.eval().
"""
def __init__(self, trainer, mcq_datasets: Dict[str, Dataset], tokenizer, output_dir: str, max_choices_batch: int, mcq_eval_debug: bool = False):
self.trainer = trainer
self.mcq_datasets = mcq_datasets
self.tokenizer = tokenizer
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.max_choices_batch = max_choices_batch
self._initial_eval_done = False  # ensures the initial eval runs only once
self.mcq_eval_debug = mcq_eval_debug
def _compute_logliks_for_query(self, model, device, query: str, choices: list) -> Tuple[List[float], List[List[Dict]]]:
"""
Given a query (string) and a list of choice strings, return:
1. A list of total log-likelihoods (floats) where each value is log p(choice_tokens | query_tokens).
2. A list of lists, where the outer list corresponds to each choice, and the inner list
contains dicts with token_id, token_str, and logprob for each token in the choice.
"""
# Prepare combined inputs: query + choice.
# We tokenize with add_special_tokens=True so models that use a BOS token get their BOS prefix.
inputs = [query + " " + c for c in choices]  # a separating space can help tokenization at the boundary
enc = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=False, add_special_tokens=True)
input_ids = enc["input_ids"].to(device) # (B, L)
attention_mask = enc.get("attention_mask", None)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
# compute query token length once (with special tokens); assumes the query tokenizes to the same prefix inside "query + ' ' + choice"
q_enc = self.tokenizer(query, add_special_tokens=True)
qlen = len(q_enc["input_ids"])
B, L = input_ids.shape
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits # (B, L, V)
# shift logits/labels: logits[:-1] predict tokens 1..L-1
shift_logits = logits[:, :-1, :].contiguous() # (B, L-1, V)
shift_labels = input_ids[:, 1:].contiguous() # (B, L-1)
# compute log softmax once
log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1) # (B, L-1, V)
result_lls = []
# token-wise details (logprob, id, str), collected only in debug mode
if self.mcq_eval_debug:
result_token_details = []
for b in range(B):
# actual unpadded length
if attention_mask is None:
# if no mask was returned, assume entire length is used
seq_len = L
else:
seq_len = int(attention_mask[b].sum().item())
# query tokens occupy indices 0..qlen-1
# choice tokens occupy indices qlen .. seq_len-1 (inclusive)
if qlen >= seq_len:
# no choice tokens present (or query uses whole sequence) => undefined conditional prob
result_lls.append(float("-1e9"))
if self.mcq_eval_debug:
result_token_details.append([])
continue
# in shifted labels coordinates, prediction for token at original index t is at index (t-1)
start_label_idx = qlen - 1 # corresponds to predicting token at index qlen
end_label_idx = seq_len - 2 # corresponds to predicting token at index seq_len-1
if start_label_idx > end_label_idx:
# no tokens to sum
result_lls.append(float("-1e9"))
if self.mcq_eval_debug:
result_token_details.append([])
continue
# gather the log probs for the true labels at those positions
label_positions = torch.arange(start_label_idx, end_label_idx + 1, device=device)
# shift_labels[b] length is L-1; index into it at label_positions
token_ids_tensor = shift_labels[b].index_select(0, label_positions) # (num_choice_tokens,)
# Step 1 — pick the relevant positions where tokens corresponding to choice are
# Step 2 — pick the log-prob of the correct token at each selected time
# Step 3 — remove the singleton dimension
token_logps_tensor = log_probs[b].index_select(0, label_positions).gather(1, token_ids_tensor.unsqueeze(-1)).squeeze(-1)
if self.mcq_eval_debug:
# Store the token-wise log probs, IDs, and strings
token_ids_list = token_ids_tensor.tolist()
token_logps_list = token_logps_tensor.tolist()
# Decode token IDs to strings. We only decode the *choice* tokens here,
# not the whole padded sequence.
token_strings = self.tokenizer.convert_ids_to_tokens(token_ids_list)
token_details = []
for token_id, token_str, logprob in zip(token_ids_list, token_strings, token_logps_list):
# Use tokenizer.decode([token_id]) for a cleaner string representation
# that handles special prefix characters (like 'Ġ' for GPT-2-style tokenizers).
token_details.append({
"token_id": token_id,
"token_str": self.tokenizer.decode([token_id], skip_special_tokens=False),
"logprob": logprob
})
result_token_details.append(token_details)
# ---------------------------------------------
# sum log probs
ll = float(token_logps_tensor.sum().item())
result_lls.append(ll)
if self.mcq_eval_debug:
return result_lls, result_token_details
else:
return result_lls, None
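# Worked example of the index bookkeeping above (no padding, qlen = 3, seq_len = 5):
#   input_ids[b]    = [q0, q1, q2, c0, c1]
#   shift_labels[b] = [q1, q2, c0, c1]                       # positions 0..3
#   choice labels sit at positions qlen-1 .. seq_len-2 = 2..3 -> [c0, c1]
# so the summed log-probs are exactly log p(c0 | q) + log p(c1 | q, c0).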
def _run_mcq_evaluation(self, model, device, epoch_str: str):
"""Common logic to run MCQ evaluation and save results."""
model.eval()
for name, ds in self.mcq_datasets.items():
out_folder = self.output_dir / "mcq_evals" / name
out_file = out_folder / f"{epoch_str}.jsonl"
out_folder.mkdir(parents=True, exist_ok=True)
logging.info(f"[MCQEval] Writing MCQ logliks for dataset '{name}' -> {out_file}")
with out_file.open("w", encoding="utf-8") as fout:
# iterate dataset rows
for row_idx, row in enumerate(ds):
# Ensure row is a dictionary for consistent access
if not isinstance(row, dict):
row = dict(row)
query = row.get("query")
choices = row.get("choices")
if query is None or choices is None:
logging.warning(f"[MCQEval] Row {row_idx} missing 'query' or 'choices' — skipping")
continue
if not isinstance(choices, (list, tuple)):
logging.warning(f"[MCQEval] Row {row_idx} 'choices' not a list — skipping")
continue
# Batch all choices for this query; if there are very many choices, chunk them
all_lls = []
if self.mcq_eval_debug:
all_token_details = []  # token-wise details, collected only in debug mode
i = 0
while i < len(choices):
batch_choices = choices[i : i + self.max_choices_batch]
# token_details is None unless mcq_eval_debug is set
lls, token_details = self._compute_logliks_for_query(model, device, query, batch_choices)
all_lls.extend(lls)
if self.mcq_eval_debug:
all_token_details.extend(token_details)
i += self.max_choices_batch
# Save a copy of the row with 'choice_logliks' (and, in debug mode, 'choice_token_details') appended
out_row = dict(row)
out_row["choice_logliks"] = all_lls
if self.mcq_eval_debug:
out_row["choice_token_details"] = all_token_details  # token-wise log-probs / IDs / strings
fout.write(json.dumps(out_row, ensure_ascii=False) + "\n")
model.train()
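# Example of one emitted JSONL row (field values illustrative, not from a real run):
#   {"query": "...", "choices": ["yes", "no"],
#    "choice_logliks": [-3.21, -5.87],
#    "choice_token_details": [[{"token_id": 123, "token_str": " yes", "logprob": -3.21}], ...]}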
# on_train_begin and on_epoch_end both delegate to _run_mcq_evaluation.
def on_train_begin(self, args, state, control, **kwargs):
"""Run MCQ evaluation before training starts, if not already done."""
if self._initial_eval_done or not self.mcq_datasets:
return control
logging.info("[MCQEval] Computing initial MCQ metrics at step 0...")
# Find model device
try:
device = next(self.trainer.model.parameters()).device
except StopIteration:
device = torch.device("cpu")
try:
# Use a descriptive string for the initial (pre-training) eval file name
self._run_mcq_evaluation(self.trainer.model, device, epoch_str="epoch_0")
except Exception as e:
logging.error(f"[MCQEval] Initial evaluation failed: {e}")
finally:
self._initial_eval_done = True
return control
def on_epoch_end(self, args, state, control, **kwargs):
"""Run MCQ evaluation at the end of each epoch."""
if not self.mcq_datasets:
return control
# Find model device
try:
device = next(self.trainer.model.parameters()).device
except StopIteration:
device = torch.device("cpu")
epoch_str = f"epoch_{int(state.epoch)}"
logging.info(f"[MCQEval] Running MCQ evaluation at end of {epoch_str}...")
try:
# Call the common function
self._run_mcq_evaluation(self.trainer.model, device, epoch_str=epoch_str)
except Exception as e:
logging.error(f"[MCQEval] Evaluation at epoch {state.epoch} failed: {e}")
return control
trainer.add_callback(
InitialEvalCallback(
trainer,
eval_train=True,
eval_validation=("validation" in lm_splits),
extra_eval_sets=extra_eval_tokenized,
)
)
trainer.add_callback(
EvalAggregateLossCallback(
trainer,
eval_sets=extra_eval_tokenized,
)
)
if args.mcq_eval_data_files:
trainer.add_callback(
MCQEvalCallback(
trainer=trainer,
mcq_datasets=mcq_eval_sets,
tokenizer=tokenizer,
output_dir=args.output_dir,
max_choices_batch=100,
mcq_eval_debug=args.mcq_eval_debug,
)
)
# Train
logging.info("Starting training...")
if args.disable_tqdm:
trainer.args.disable_tqdm = True
train_result = trainer.train()
# Save final model unless checkpointing is disabled
if not getattr(args, "disable_checkpoints_including_last", False):
trainer.save_model()  # Trainer.save_model() also saves the tokenizer
else:
logging.info("Skipping final save_model() due to --disable_checkpoints_including_last")
# Save training metrics & state
metrics = train_result.metrics
metrics["train_samples"] = len(lm_splits["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluate final model on train and validation (if exists)
logging.info("Final evaluation on training dataset…")
final_train_metrics = trainer.evaluate(eval_dataset=lm_splits["train"], metric_key_prefix="final_train")
trainer.log_metrics("final_train", final_train_metrics)
trainer.save_metrics("final_train", final_train_metrics)
if "final_train_loss" in final_train_metrics:
logging.info(f"Final train loss: {final_train_metrics['final_train_loss']:.4f}")
if "validation" in lm_splits:
logging.info("Final evaluation on validation dataset…")
final_val_metrics = trainer.evaluate(eval_dataset=lm_splits["validation"], metric_key_prefix="final_validation")
trainer.log_metrics("final_validation", final_val_metrics)
trainer.save_metrics("final_validation", final_val_metrics)
logging.info(f"Done. Model & tokenizer saved to {args.output_dir}")
if args.push_to_hub:
logging.info("Pushed to hub (if logged in).")
if __name__ == "__main__":
main()