How do I speed up my callbacks and reduce the stall before they start?

Hey

I am fine-tuning models with HF Transformers, using a few custom callbacks to measure loss and run MCQ evals. I notice that training is very fast but eval is slow, and I’m wondering if there are any obvious tricks I’ve left on the table. Here’s my code:



    # Load dataset
    with log_step("Load dataset"):
        logging.info(f"data_file={args.data_file}")
        ds_splits = load_and_prepare_datasets(args.data_file, eval_split_percentage=args.eval_split_percentage, seed=args.seed)
        logging.info(f"splits={{ {', '.join(f'{k}: {len(v)}' for k, v in ds_splits.items())} }}")

    # Load tokenizer
    with log_step("Load tokenizer"):
        logging.info(f"model_name_or_path={args.model_name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    # Some tokenizers don't have pad token -- set it to eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize dataset
    logging.info("Tokenizing dataset...")
    tokenized_splits = {}
    for split_name, ds in ds_splits.items():
        with log_step(f"Tokenize [{split_name}]"):
            tokenized = ds.map(
                lambda examples: tokenize_function(examples, tokenizer),
                batched=True,
                # Remove original columns like 'text' so downstream only
                # sees tokenized fields (e.g., input_ids, attention_mask).
            # Keeping 'text' here leads to a TypeError when concatenating.
                remove_columns=ds.column_names,
                desc=f"Tokenizing {split_name}",
            )
            logging.info(f"{split_name}: {len(tokenized)} examples")
            tokenized_splits[split_name] = tokenized

    lm_splits = tokenized_splits

    # Load multiple evaluation datasets if provided
    with log_step("Load extra evaluation datasets"):
        extra_eval_raw = load_eval_datasets(args.eval_data_files)

    extra_eval_tokenized = {}

    # Tokenize each eval dataset
    for name, ds in extra_eval_raw.items():
        with log_step(f"Tokenize extra eval set [{name}]"):
            tok = ds.map(
                lambda ex: tokenize_function(ex, tokenizer),
                batched=True,
                remove_columns=ds.column_names,
                desc=f"Tokenizing {name}",
            )
        extra_eval_tokenized[name] = tok

    # Load multiple MCQ evaluation datasets if provided
    with log_step("Load extra MCQ evaluation datasets"):
        if args.mcq_eval_data_files:
            mcq_eval_sets = load_mcq_eval_datasets(args.mcq_eval_data_files)

    # Optionally limit samples (for quick tests)
    if args.max_train_samples:
        lm_splits["train"] = lm_splits["train"].select(range(min(len(lm_splits["train"]), args.max_train_samples)))
    if args.max_eval_samples and "validation" in lm_splits:
        lm_splits["validation"] = lm_splits["validation"].select(range(min(len(lm_splits["validation"]), args.max_eval_samples)))

    # Load model
    with log_step("Load model"):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            # attn_implementation="flash_attention_2",
        )
    # Resize tokenizer embeddings if we added tokens
    # model.resize_token_embeddings(len(tokenizer))
    try:
        n_params = sum(p.numel() for p in model.parameters())
        logging.info(f"Model parameters: {n_params:,}")
    except Exception:
        pass

    # Let W&B watch the model (if enabled)
    if getattr(args, "report_to", "none") == "wandb":
        try:
            import wandb  # type: ignore
            wandb.watch(model, log="all", log_freq=100)
        except Exception:
            pass

    # Data collator
    # For causal LM, DataCollatorForLanguageModeling with mlm=False returns labels=input_ids
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Training args
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        eval_strategy="steps" if "validation" in lm_splits else "no",
        eval_steps=500 if "validation" in lm_splits else None,
        save_strategy="no" if (getattr(args, "disable_checkpoints", False) or getattr(args, "disable_checkpoints_including_last", False)) else "epoch",
        save_only_model=True if getattr(args, "save_only_model", False) else False,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        lr_scheduler_type=args.lr_scheduler_type,
        # Log loss after every optimizer step
        logging_strategy="steps",
        logging_steps=1,
        logging_first_step=True,
        logging_dir=os.path.join(args.output_dir, "logs"),
        bf16=use_bf16,
        run_name=args.wandb_run_name if getattr(args, "wandb_run_name", None) else None,
        dataloader_num_workers=4,
        push_to_hub=args.push_to_hub,
        seed=args.seed,
        report_to=args.report_to,
    )

    # Trainer
    logging.info("Preparing Trainer...")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_splits["train"],
        eval_dataset=lm_splits.get("validation", None),
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=None,
    )

    class InitialEvalCallback(TrainerCallback):
        def __init__(
            self,
            trainer,
            eval_train: bool = True,
            eval_validation: bool = True,
            extra_eval_sets: Dict[str, Dataset] = None,
        ):
            self.trainer = trainer
            self.eval_train = eval_train
            self.eval_validation = eval_validation
            self.extra_eval_sets = extra_eval_sets or {}
            self._done = False

        def on_train_begin(self, args, state, control, **kwargs):
            if self._done:
                return control

            try:
                logging.info("[Callback] Computing initial metrics at step 0…")

                # ---- TRAIN SET ----
                if self.eval_train and self.trainer.train_dataset is not None:
                    init_train = self.trainer.evaluate(
                        eval_dataset=self.trainer.train_dataset,
                        metric_key_prefix="initial_train",
                    )
                    self.trainer.log_metrics("initial_train", init_train)
                    self.trainer.save_metrics("initial_train", init_train)

                    if "initial_train_loss" in init_train:
                        logging.info(f"Initial train loss: {init_train['initial_train_loss']:.4f}")
                        try:
                            self.trainer.log({"loss": float(init_train["initial_train_loss"])})
                        except Exception:
                            pass

                # ---- VALIDATION SET ----
                if self.eval_validation and self.trainer.eval_dataset is not None:
                    init_val = self.trainer.evaluate(
                        eval_dataset=self.trainer.eval_dataset,
                        metric_key_prefix="initial_validation",
                    )
                    self.trainer.log_metrics("initial_validation", init_val)
                    self.trainer.save_metrics("initial_validation", init_val)

                    if "initial_validation_loss" in init_val:
                        logging.info(f"Initial validation loss: {init_val['initial_validation_loss']:.4f}")
                        try:
                            self.trainer.log({"eval_loss": float(init_val["initial_validation_loss"])})
                        except Exception:
                            pass

                # ---- EXTRA EVAL DATASETS ----
                for name, dataset in self.extra_eval_sets.items():
                    logging.info(f"[Callback] Evaluating extra dataset '{name}' at step 0…")

                    # CRITICAL CHANGE: Use `name` directly, not `initial_{name}`
                    metrics = self.trainer.evaluate(
                        eval_dataset=dataset,
                        metric_key_prefix=name, 
                    )

                    # Log to disk (keep as is)
                    self.trainer.log_metrics(name, metrics)
                    self.trainer.save_metrics(name, metrics)

                    # CRITICAL CHANGE: Explicitly log to W&B 
                    # This pushes keys like "nyt_loss" at step 0
                    self.trainer.log(metrics)

            except Exception as e:
                logging.warning(f"[Callback] Could not compute/log initial metrics: {e}")

            finally:
                self._done = True

            return control


    class EvalAggregateLossCallback(TrainerCallback):
        """
        At end of each epoch, run Trainer.evaluate() on each extra dataset
        and log the aggregated loss.
        """
        def __init__(self, trainer, eval_sets):
            self.trainer = trainer
            self.eval_sets = eval_sets

        def on_epoch_end(self, args, state, control, **kwargs):
            for name, dataset in self.eval_sets.items():
                metrics = self.trainer.evaluate(
                    eval_dataset=dataset,
                    metric_key_prefix=f"{name}"
                )

                # Log through trainer (handles W&B, JSON logs, etc.)
                self.trainer.log_metrics(name, metrics)
                self.trainer.save_metrics(name, metrics)

                # Pretty print
                loss_key = f"{name}_loss"
                if loss_key in metrics:
                    logging.info(f"[Epoch {state.epoch}] {name} loss = {metrics[loss_key]:.4f}")

            return control

    class MCQEvalCallback(TrainerCallback):
            """
            At end of each epoch, evaluate MCQ datasets (raw JSONL datasets that have
            'query' and 'choices' fields). For each row, compute log-likelihood of
            each choice given the prefix (query). Writes a JSONL file per dataset
            containing original fields plus "choice_logliks": [float,...] and 
            "choice_token_details": [[{token_id: int, token_str: str, logprob: float}, ...], ...].

            Efficient design:
            - Uses the Trainer's model in-place (no copying).
            - Batches all choices for the same query into a single forward pass (padded).
            - Runs inference with torch.no_grad and model.eval().
            """
            def __init__(self, trainer, mcq_datasets: Dict[str, Dataset], tokenizer, output_dir: str, max_choices_batch: int, mcq_eval_debug: bool = False):
                self.trainer = trainer
                self.mcq_datasets = mcq_datasets
                self.tokenizer = tokenizer
                self.output_dir = Path(output_dir)
                self.output_dir.mkdir(parents=True, exist_ok=True)
                self.max_choices_batch = max_choices_batch
                self._initial_eval_done = False # New flag to prevent re-running initial eval
                self.mcq_eval_debug = mcq_eval_debug

            # --- MODIFIED: Added return type for token_details ---
            def _compute_logliks_for_query(self, model, device, query: str, choices: list) -> Tuple[List[float], List[List[Dict]]]:
                """
                Given a query (string) and a list of choice strings, return:
                1. A list of total log-likelihoods (floats) where each value is log p(choice_tokens | query_tokens).
                2. A list of lists, where the outer list corresponds to each choice, and the inner list
                contains dicts with token_id, token_str, and logprob for each token in the choice.
                """
                # Prepare combined inputs: query + choice.
                # We intentionally pass add_special_tokens=True so models that use a BOS token get their BOS prefix.

                inputs = [query + " " + c for c in choices]  # adding a space to separate can help tokenization
                enc = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=False, add_special_tokens=True)
                input_ids = enc["input_ids"].to(device)  # (B, L)
                attention_mask = enc.get("attention_mask", None)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)

                # compute query token length once (use tokenizer with special tokens)
                q_enc = self.tokenizer(query, add_special_tokens=True)
                qlen = len(q_enc["input_ids"])
                B, L = input_ids.shape

                with torch.no_grad():
                    outputs = model(input_ids, attention_mask=attention_mask)
                    logits = outputs.logits  # (B, L, V)

                    # shift logits/labels: logits[:-1] predict tokens 1..L-1
                    shift_logits = logits[:, :-1, :].contiguous()  # (B, L-1, V)
                    shift_labels = input_ids[:, 1:].contiguous()  # (B, L-1)

                    # compute log softmax once
                    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)  # (B, L-1, V)

                    result_lls = []
                    # --- NEW: for token-wise details (logprob, id, str) ---
                    if self.mcq_eval_debug:
                        result_token_details = [] 

                    for b in range(B):
                        # actual unpadded length
                        if attention_mask is None:
                            # if no mask was returned, assume entire length is used
                            seq_len = L
                        else:
                            seq_len = int(attention_mask[b].sum().item())
                            
                        # query tokens occupy indices 0..qlen-1
                        # choice tokens occupy indices qlen .. seq_len-1 (inclusive)
                        if qlen >= seq_len:
                            # no choice tokens present (or query uses whole sequence) => undefined conditional prob
                            result_lls.append(float("-1e9"))
                            if self.mcq_eval_debug:
                                result_token_details.append([])
                            continue

                        # in shifted labels coordinates, prediction for token at original index t is at index (t-1)
                        start_label_idx = qlen - 1  # corresponds to predicting token at index qlen
                        end_label_idx = seq_len - 2  # corresponds to predicting token at index seq_len-1
                        
                        if start_label_idx > end_label_idx:
                            # no tokens to sum
                            result_lls.append(float("-1e9"))
                            if self.mcq_eval_debug:
                                result_token_details.append([])
                            continue

                        # gather the log probs for the true labels at those positions
                        label_positions = torch.arange(start_label_idx, end_label_idx + 1, device=device)
                        # shift_labels[b] length is L-1; index into it at label_positions
                        token_ids_tensor = shift_labels[b].index_select(0, label_positions)  # (num_choice_tokens,)
                        
                        # Step 1 — pick the relevant positions where tokens corresponding to choice are
                        # Step 2 — pick the log-prob of the correct token at each selected time
                        # Step 3 — remove the singleton dimension
                        token_logps_tensor = log_probs[b].index_select(0, label_positions).gather(1, token_ids_tensor.unsqueeze(-1)).squeeze(-1)

                        if self.mcq_eval_debug:
                            # Store the token-wise log probs, IDs, and strings
                            token_ids_list = token_ids_tensor.tolist()
                            token_logps_list = token_logps_tensor.tolist()
                            
                            # --- NEW: Decode token IDs for token strings ---
                            # We can't batch decode a tensor efficiently here because we only want to decode
                            # the *choice* tokens, not the whole padded sequence.
                            token_strings = self.tokenizer.convert_ids_to_tokens(token_ids_list)
                            
                            token_details = []
                            for token_id, token_str, logprob in zip(token_ids_list, token_strings, token_logps_list):
                                # Use tokenizer.decode([token_id]) for a cleaner string representation 
                                # that handles special prefix characters (like 'Ġ' for GPT-2 tokenizers).
                                token_details.append({
                                    "token_id": token_id,
                                    "token_str": self.tokenizer.decode([token_id], skip_special_tokens=False),
                                    "logprob": logprob
                                })
                            
                            result_token_details.append(token_details)
                            # ---------------------------------------------

                        # sum log probs
                        ll = float(token_logps_tensor.sum().item())
                        result_lls.append(ll)

                if self.mcq_eval_debug:
                    return result_lls, result_token_details
                else:
                    return result_lls, None

            # --- MODIFIED: Updated to handle new return from _compute_logliks_for_query ---
            def _run_mcq_evaluation(self, model, device, epoch_str: str):
                """Common logic to run MCQ evaluation and save results."""
                model.eval()

                for name, ds in self.mcq_datasets.items():
                    out_folder = self.output_dir / f"mcq_evals/{name}/"
                    out_file = out_folder / f"{epoch_str}.jsonl"
                    os.makedirs(out_folder, exist_ok=True)
                    logging.info(f"[MCQEval] Writing MCQ logliks for dataset '{name}' -> {out_file}")
                    
                    with out_file.open("w", encoding="utf-8") as fout:
                        # iterate dataset rows
                        for row_idx, row in enumerate(ds):
                            # Accept either 'query' or 'text' for prefix
                            query = None
                            # Ensure row is a dictionary for consistent access
                            if not isinstance(row, dict):
                                row = dict(row)
                                
                            query = row.get("query")
                            choices = row.get("choices")

                            if query is None or choices is None:
                                logging.warning(f"[MCQEval] Row {row_idx} missing 'query' or 'choices' — skipping")
                                continue

                            if not isinstance(choices, (list, tuple)):
                                logging.warning(f"[MCQEval] Row {row_idx} 'choices' not a list — skipping")
                                continue

                            # Batch all choices for this query; if there are very many choices, chunk them
                            all_lls = []    
                            if self.mcq_eval_debug:
                                all_token_details = [] # --- NEW: for token-wise details ---
                            i = 0
                            while i < len(choices):
                                batch_choices = choices[i : i + self.max_choices_batch]
                                
                                # --- MODIFIED: Capture both return values ---
                                lls, token_details = self._compute_logliks_for_query(model, device, query, batch_choices)
                                
                                all_lls.extend(lls)
                                if self.mcq_eval_debug:
                                    all_token_details.extend(token_details) # --- NEW: extend token-wise list ---
                                i += self.max_choices_batch

                            # Save a copy of the row with appended 'choice_logliks' and 'choice_token_details'
                            out_row = dict(row)
                            out_row["choice_logliks"] = all_lls
                            if self.mcq_eval_debug:
                                out_row["choice_token_details"] = all_token_details # --- NEW: Save token-wise log-probabilities/IDs/strings ---
                            fout.write(json.dumps(out_row, ensure_ascii=False) + "\n")

                model.train()

            # on_train_begin and on_epoch_end remain the same, just calling _run_mcq_evaluation
            # which now saves the required data.
            
            def on_train_begin(self, args, state, control, **kwargs):
                """Run MCQ evaluation before training starts, if not already done."""
                if self._initial_eval_done or not self.mcq_datasets:
                    return control
                
                logging.info("[MCQEval] Computing initial MCQ metrics at step 0...")

                # Find model device
                try:
                    device = next(self.trainer.model.parameters()).device
                except StopIteration:
                    device = torch.device("cpu")

                try:
                    # Use a descriptive string for the initial epoch file name
                    self._run_mcq_evaluation(self.trainer.model, device, epoch_str="epoch_0")
                except Exception as e:
                    logging.error(f"[MCQEval] Initial evaluation failed: {e}")
                finally:
                    self._initial_eval_done = True
                    
                return control

            def on_epoch_end(self, args, state, control, **kwargs):
                """Run MCQ evaluation at the end of each epoch."""
                if not self.mcq_datasets:
                    return control
                    
                # Find model device
                try:
                    device = next(self.trainer.model.parameters()).device
                except StopIteration:
                    device = torch.device("cpu")
                    
                epoch_str = f"epoch_{int(state.epoch)}"
                logging.info(f"[MCQEval] Running MCQ evaluation at end of {epoch_str}...")

                try:
                    # Call the common function
                    self._run_mcq_evaluation(self.trainer.model, device, epoch_str=epoch_str)
                except Exception as e:
                    logging.error(f"[MCQEval] Evaluation at epoch {state.epoch} failed: {e}")
                    
                return control

    trainer.add_callback(
        InitialEvalCallback(
            trainer,
            eval_train=True,
            eval_validation=("validation" in lm_splits),
            extra_eval_sets=extra_eval_tokenized,
        )
    )

    trainer.add_callback(
        EvalAggregateLossCallback(
            trainer,
            eval_sets=extra_eval_tokenized,
        )
    )

    if args.mcq_eval_data_files:
        trainer.add_callback(
            MCQEvalCallback(
                trainer=trainer,
                mcq_datasets=mcq_eval_sets,
                tokenizer=tokenizer,
                output_dir=args.output_dir,
                max_choices_batch=100,
                mcq_eval_debug=args.mcq_eval_debug,
            )
        )

    # Train
    logging.info("Starting training...")
    if args.disable_tqdm:
        trainer.args.disable_tqdm = True
    train_result = trainer.train()
    # Save final model unless checkpointing is disabled
    if not getattr(args, "disable_checkpoints_including_last", False):
        trainer.save_model()  # Saves the tokenizer too for Trainer.save_model
    else:
        logging.info("Skipping final save_model() due to --disable_checkpoints_including_last")

    # Save training metrics & state
    metrics = train_result.metrics
    metrics["train_samples"] = len(lm_splits["train"])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Evaluate final model on train and validation (if exists)
    logging.info("Final evaluation on training dataset…")
    final_train_metrics = trainer.evaluate(eval_dataset=lm_splits["train"], metric_key_prefix="final_train")
    trainer.log_metrics("final_train", final_train_metrics)
    trainer.save_metrics("final_train", final_train_metrics)
    if "final_train_loss" in final_train_metrics:
        logging.info(f"Final train loss: {final_train_metrics['final_train_loss']:.4f}")

    if "validation" in lm_splits:
        logging.info("Final evaluation on validation dataset…")
        final_val_metrics = trainer.evaluate(eval_dataset=lm_splits["validation"], metric_key_prefix="final_validation")
        trainer.log_metrics("final_validation", final_val_metrics)
        trainer.save_metrics("final_validation", final_val_metrics)

    logging.info(f"Done. Model & tokenizer saved to {args.output_dir}")
    if args.push_to_hub:
        logging.info("Pushed to hub (if logged in).")


if __name__ == "__main__":
    main()

Hmm…?


Your eval is slow because you are doing a lot of full-dataset evaluation inside callbacks, and on multi-GPU you are almost certainly doing that work once per rank.

You can speed this up a lot without losing signal by:

  1. Evaluating on subsets in callbacks.
  2. Running heavy eval less often.
  3. Making heavy callbacks run on the main process only.
  4. Simplifying your TrainingArguments eval behavior.

I will map this directly to your code.


1. Where your time is actually going

1.1 Trainer.evaluate is a full pass every time

Trainer.evaluate always runs a full forward pass over the provided eval_dataset and iterates the entire dataloader. There is no built-in “evaluate only first N batches” option. (Hugging Face)

So in your code:

  • At on_train_begin you call trainer.evaluate on:

    • Full train
    • Full validation
    • Every extra_eval_tokenized[name]
  • At every on_epoch_end you call trainer.evaluate on every extra eval dataset again.

Plus, TrainingArguments has:

eval_strategy="steps",
eval_steps=500,

So Trainer also evaluates on the validation set every 500 steps.

This is exactly the pattern people report when “evaluation is much slower than training” in HF discussions and StackOverflow. (Stack Overflow)
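
If you want to see this cost directly, you can time one evaluate call per dataset before training. This is a small throwaway sketch: timed_evaluate is a hypothetical helper, and trainer / extra_eval_tokenized are the objects from your script.

import time

def timed_evaluate(trainer, dataset, prefix):
    # Time one full evaluation pass to see what each eval dataset costs.
    start = time.perf_counter()
    metrics = trainer.evaluate(eval_dataset=dataset, metric_key_prefix=prefix)
    print(f"{prefix}: {time.perf_counter() - start:.1f}s for {len(dataset)} examples")
    return metrics

for name, ds in extra_eval_tokenized.items():
    timed_evaluate(trainer, ds, name)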

1.2 Multi-GPU: callbacks likely run once per rank

Trainer uses Accelerate internally. Callbacks are not automatically limited to rank 0: unless you explicitly gate them, on_train_begin and on_epoch_end run on every process. Accelerate’s docs show PartialState.is_main_process as the correct way to restrict code to one process. (Hugging Face)

So for MCQ and extra eval:

  • Every GPU rank iterates all rows.
  • Every GPU rank may write JSONL files.

On 4 GPUs, that is roughly 4× the eval work.
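
If you want to confirm the per-rank duplication on your setup, a throwaway probe callback (hypothetical, not part of your script) makes it visible in the logs:

import logging

from accelerate import PartialState
from transformers import TrainerCallback

class RankProbeCallback(TrainerCallback):
    # Throwaway probe: logs which process executes each hook in a multi-GPU job.
    def on_epoch_end(self, args, state, control, **kwargs):
        ps = PartialState()
        logging.info(f"[probe] on_epoch_end ran on rank {ps.process_index} of {ps.num_processes}")
        return control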

1.3 InitialEvalCallback explains the long stall at the start

Your InitialEvalCallback.on_train_begin:

  • Runs multiple full evaluate calls before the first optimization step.
  • Does not subsample.
  • Does not gate on rank.

That is a big, fixed stall before trainer.train() actually starts stepping.


2. General strategy

You do not need to change your high-level design. Just adjust how much work each eval does and how often it runs.

Four levers:

  1. Subsample datasets in callbacks using Dataset.select. (Hugging Face)
  2. Reduce frequency of heavy callbacks (every N epochs, not every epoch and not at step 0 on full data).
  3. Gate heavy callbacks to main process using PartialState.is_main_process. (Hugging Face)
  4. Simplify eval_strategy to avoid redundant builtin evals. (GitHub)

I will show concrete changes in this order:

  • TrainingArguments knobs.
  • A small maybe_subsample_dataset helper.
  • Faster InitialEvalCallback.
  • Faster EvalAggregateLossCallback.
  • Multi-GPU aware MCQEvalCallback.

3. TrainingArguments tweaks

Right now:

eval_strategy="steps",
eval_steps=500,

And your callbacks do their own heavy eval.

Better:

  • Either let Trainer handle only the main validation eval, and keep callbacks light.
  • Or turn Trainer eval off and own everything in callbacks.

A simple and common pattern for large runs:

training_args = TrainingArguments(
    # ...
    eval_strategy="epoch" if "validation" in lm_splits else "no",
    eval_steps=None,
    prediction_loss_only=True,  # loss only, no stored logits
)

  • eval_strategy="epoch" gives you one full validation eval per epoch. This is documented in TrainingArguments as one of the main modes and is cheaper than frequent step-based eval. (GitHub)
  • prediction_loss_only=True avoids storing full prediction arrays, which reduces some eval overhead and memory when you do not use compute_metrics. (Hugging Face)

Keep per_device_eval_batch_size as large as the GPU allows; larger eval batches mean fewer steps per evaluation. (Hugging Face)
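
For example (the 4× factor is only an illustration; tune it to what your GPU memory allows):

training_args = TrainingArguments(
    # ...
    per_device_train_batch_size=args.per_device_train_batch_size,
    # eval keeps no activations for the backward pass, so a larger batch usually fits
    per_device_eval_batch_size=4 * args.per_device_train_batch_size,
    eval_strategy="epoch" if "validation" in lm_splits else "no",
    prediction_loss_only=True,
)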


4. Subsampling helper

Add this once and reuse it:

from typing import Optional
from datasets import Dataset

def maybe_subsample_dataset(ds: Optional[Dataset], max_samples: Optional[int]) -> Optional[Dataset]:
    if ds is None or max_samples is None:
        return ds
    if max_samples <= 0:
        return ds
    try:
        n = len(ds)
    except TypeError:
        # IterableDataset: length unknown, skip subsampling
        return ds
    if n <= max_samples:
        return ds
    # Hugging Face datasets.select selects rows by indices and is cheap
    return ds.select(range(max_samples))

This uses Dataset.select, which HF recommends for making small subsets. (Hugging Face)
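
Usage then looks like this (the 1024 cap is illustrative):

val_sub = maybe_subsample_dataset(lm_splits.get("validation"), 1024)
if val_sub is not None:
    metrics = trainer.evaluate(eval_dataset=val_sub, metric_key_prefix="val_subset")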


5. Fast InitialEvalCallback: subset and main process

Current behavior:

  • Full train, full val, full extra eval at step 0.
  • On all ranks.

Replace with:

  • Train eval at step 0 optional, and on a subset only.
  • Validation and extra eval on subsets.
  • Run on the main process only with PartialState.

from accelerate import PartialState
from transformers import TrainerCallback, Trainer
from typing import Dict, Optional
from datasets import Dataset

class InitialEvalCallback(TrainerCallback):
    def __init__(
        self,
        trainer: Trainer,
        eval_train: bool = False,  # default off
        eval_validation: bool = True,
        extra_eval_sets: Optional[Dict[str, Dataset]] = None,
        max_train_samples: Optional[int] = None,
        max_eval_samples: Optional[int] = 1024,
    ):
        self.trainer = trainer
        self.eval_train = eval_train
        self.eval_validation = eval_validation
        self.extra_eval_sets = extra_eval_sets or {}
        self.max_train_samples = max_train_samples
        self.max_eval_samples = max_eval_samples
        self._done = False
        self._state = PartialState()

    def on_train_begin(self, args, state, control, **kwargs):
        if self._done:
            return control

        # only the main process does the work; the other ranks just meet the
        # barrier below so all processes stay in sync
        if not self._state.is_main_process:
            self._state.wait_for_everyone()
            return control

        try:
            logging.info("[InitialEval] computing initial metrics on subsets")

            if self.eval_train and self.trainer.train_dataset is not None:
                train_sub = maybe_subsample_dataset(self.trainer.train_dataset, self.max_train_samples)
                metrics = self.trainer.evaluate(eval_dataset=train_sub, metric_key_prefix="initial_train")
                self.trainer.log_metrics("initial_train", metrics)
                self.trainer.save_metrics("initial_train", metrics)

            if self.eval_validation and self.trainer.eval_dataset is not None:
                val_sub = maybe_subsample_dataset(self.trainer.eval_dataset, self.max_eval_samples)
                metrics = self.trainer.evaluate(eval_dataset=val_sub, metric_key_prefix="initial_validation")
                self.trainer.log_metrics("initial_validation", metrics)
                self.trainer.save_metrics("initial_validation", metrics)

            for name, ds in self.extra_eval_sets.items():
                ds_sub = maybe_subsample_dataset(ds, self.max_eval_samples)
                metrics = self.trainer.evaluate(eval_dataset=ds_sub, metric_key_prefix=f"initial_{name}")
                self.trainer.log_metrics(name, metrics)
                self.trainer.save_metrics(name, metrics)

        except Exception as e:
            logging.warning(f"[InitialEval] could not compute initial metrics: {e}")
        finally:
            self._done = True

        self._state.wait_for_everyone()
        return control

Effect:

  • Initial stall is now one small eval per dataset, not full passes.
  • In multi-GPU, this runs once, not per rank. PartialState.is_main_process is the standard pattern for “do this only once” operations with Accelerate. (Hugging Face)
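
Registering it looks like before, just with the new knobs (the sample caps are illustrative):

trainer.add_callback(
    InitialEvalCallback(
        trainer,
        eval_train=False,  # skip the full-train eval at step 0
        eval_validation=("validation" in lm_splits),
        extra_eval_sets=extra_eval_tokenized,
        max_eval_samples=1024,
    )
)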

6. Fast EvalAggregateLossCallback: every N epochs, subset

Current behavior:

  • Every epoch, full pass over every extra eval dataset.

Add:

  • every_n_epochs knob.
  • max_eval_samples knob.

from transformers import TrainerCallback, TrainerState, TrainerControl

class EvalAggregateLossCallback(TrainerCallback):
    def __init__(
        self,
        trainer,
        eval_sets: Dict[str, Dataset],
        every_n_epochs: int = 1,
        max_eval_samples: Optional[int] = 2048,
    ):
        self.trainer = trainer
        self.eval_sets = dict(eval_sets)
        self.every_n_epochs = max(1, every_n_epochs)
        self.max_eval_samples = max_eval_samples

    def _should_run_for_epoch(self, epoch: Optional[float]) -> bool:
        if epoch is None:
            return False
        epoch_idx = int(epoch)
        return (epoch_idx % self.every_n_epochs) == 0

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if not self._should_run_for_epoch(state.epoch):
            return control

        for name, ds in self.eval_sets.items():
            ds_sub = maybe_subsample_dataset(ds, self.max_eval_samples)
            metrics = self.trainer.evaluate(eval_dataset=ds_sub, metric_key_prefix=name)
            self.trainer.log_metrics(name, metrics)
            self.trainer.save_metrics(name, metrics)

            loss_key = f"{name}_loss"
            if loss_key in metrics:
                logging.info(f"[Epoch {int(state.epoch)}] {name} loss = {metrics[loss_key]:.4f}")

        return control

Typical values:

  • Debug: every_n_epochs=1, max_eval_samples=512.
  • Normal: every_n_epochs=3, max_eval_samples=2048.

This matches the “evaluate on subset during training” pattern discussed on HF forums. (Hugging Face Forums)
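
Registration in your script then becomes (values illustrative):

trainer.add_callback(
    EvalAggregateLossCallback(
        trainer,
        eval_sets=extra_eval_tokenized,
        every_n_epochs=3,
        max_eval_samples=2048,
    )
)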


7. Multi-GPU aware MCQEvalCallback

Your MCQ callback is independent of Trainer.evaluate, which is good. The main issues:

  • Runs at on_train_begin and at every on_epoch_end on full MCQ datasets.
  • Probably runs once per rank and writes files from all processes.

Use:

  • A small config object for knobs.
  • PartialState to run MCQ only on main.
  • max_examples_per_dataset to bound work.

from dataclasses import dataclass
from accelerate import PartialState
from transformers import TrainerCallback
from pathlib import Path
from typing import Dict, Optional

@dataclass
class MCQEvalConfig:
    run_on_train_begin: bool = True
    run_every_n_epochs: int = 1        # 0 disables per-epoch MCQ
    max_examples_per_dataset: Optional[int] = 512
    max_choices_batch: int = 100
    mcq_eval_debug: bool = False

class MCQEvalCallback(TrainerCallback):
    def __init__(
        self,
        trainer,
        mcq_datasets: Dict[str, Dataset],
        tokenizer,
        output_dir: str,
        cfg: MCQEvalConfig,
    ):
        self.trainer = trainer
        self.mcq_datasets = mcq_datasets
        self.tokenizer = tokenizer
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.cfg = cfg
        self._initial_done = False
        self._state = PartialState()

    def _should_run_epoch(self, epoch: Optional[float]) -> bool:
        if not self.cfg.run_every_n_epochs:
            return False
        if epoch is None:
            return False
        epoch_idx = int(epoch)
        return (epoch_idx % self.cfg.run_every_n_epochs) == 0

    def on_train_begin(self, args, state, control, **kwargs):
        if self._initial_done or not self.cfg.run_on_train_begin:
            return control
        if not self.mcq_datasets:
            return control
        if not self._state.is_main_process:
            # non-main ranks only meet the barrier so all processes stay in sync
            self._state.wait_for_everyone()
            return control

        logging.info("[MCQEval] initial MCQ eval on limited examples")
        device = next(self.trainer.model.parameters()).device
        self._run_mcq(device, epoch_str="epoch_0")
        self._initial_done = True
        self._state.wait_for_everyone()
        return control

    def on_epoch_end(self, args, state, control, **kwargs):
        if not self._should_run_epoch(state.epoch):
            return control
        if not self.mcq_datasets:
            return control
        if not self._state.is_main_process:
            self._state.wait_for_everyone()
            return control

        epoch_str = f"epoch_{int(state.epoch)}"
        logging.info(f"[MCQEval] MCQ eval at {epoch_str} on limited examples")
        device = next(self.trainer.model.parameters()).device
        self._run_mcq(device, epoch_str)
        self._state.wait_for_everyone()
        return control

    def _run_mcq(self, device, epoch_str: str):
        model = self.trainer.model
        model.eval()

        for name, ds in self.mcq_datasets.items():
            out_dir = self.output_dir / "mcq_evals" / name
            out_dir.mkdir(parents=True, exist_ok=True)
            out_file = out_dir / f"{epoch_str}.jsonl"

            with out_file.open("w", encoding="utf-8") as f:
                for row_idx, row in enumerate(ds):
                    if (
                        self.cfg.max_examples_per_dataset is not None
                        and row_idx >= self.cfg.max_examples_per_dataset
                    ):
                        break

                    # your current query/choices + _compute_logliks_for_query logic here

        model.train()

Again, the pattern “do logging or custom evaluation only on main process using PartialState” is what Accelerate recommends for experiment trackers and other side effects. (Hugging Face)
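
Hooking it up in your main script then looks like this (the epoch interval and example cap are illustrative):

if args.mcq_eval_data_files:
    trainer.add_callback(
        MCQEvalCallback(
            trainer=trainer,
            mcq_datasets=mcq_eval_sets,
            tokenizer=tokenizer,
            output_dir=args.output_dir,
            cfg=MCQEvalConfig(
                run_on_train_begin=True,
                run_every_n_epochs=2,
                max_examples_per_dataset=512,
                max_choices_batch=100,
                mcq_eval_debug=args.mcq_eval_debug,
            ),
        )
    )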


8. What stalls will remain and where they belong

After these changes:

  • Initial stall: one short eval over small subsets at on_train_begin (or you can disable it entirely).
  • Periodic stall: subset eval every N epochs for extra LM datasets.
  • MCQ: bounded cost (max rows per dataset) and only on main process.

The final full eval at the end on full train and full validation is fine; that happens once and is expected to be heavier.

If you later want full-dataset MCQ benchmarks or full LM benchmarks, run them outside training in a separate evaluation script that loads saved checkpoints. HF’s evaluate and evaluator utilities are designed for that pattern. (Hugging Face)
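
A minimal sketch of that offline pattern, assuming a saved checkpoint directory and a tokenized eval set prepared the same way as in your training script (the checkpoint path and full_validation_set are placeholders):

# standalone script: evaluate a saved checkpoint outside the training run
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

ckpt = "output/checkpoint-1500"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
eval_args = TrainingArguments(output_dir="offline_eval", per_device_eval_batch_size=32, report_to="none")
trainer = Trainer(model=model, args=eval_args, tokenizer=tokenizer, data_collator=collator)

# full_validation_set: the full tokenized eval dataset, built as in the training script
metrics = trainer.evaluate(eval_dataset=full_validation_set, metric_key_prefix="offline_validation")
print(metrics)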


Short bullet summary

  • Trainer.evaluate always runs on the full dataset; you are calling it many times in callbacks and via eval_strategy="steps". (Hugging Face)
  • Use eval_strategy="epoch" or "no" and prediction_loss_only=True to reduce builtin eval overhead. (GitHub)
  • Subsample eval datasets in callbacks with Dataset.select plus max_*_samples knobs. (Hugging Face)
  • Run extra eval and MCQ only every N epochs, not every epoch. (Hugging Face Forums)
  • Use accelerate.PartialState.is_main_process so heavy callbacks (MCQ, extra eval) run only once per job, not once per rank.