|
|
|
|
|
|
|
|
""" |
|
|
Pretrain Veronica-Polymorphic from scratch (clean mixture: FinePDFs / DCLM / FineWeb-Edu). |
|
|
|
|
|
Basic example: |
|
|
python veronica-polymorphic/scripts/train_veronica.py \ |
|
|
--config veronica-polymorphic/configs/veronica-pretrain-12L.json \ |
|
|
--dataset_paths data/mix_optimal_50_30_20_2048 \ |
|
|
--output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-2048 \ |
|
|
--per_device_train_batch_size 4 \ |
|
|
--gradient_accumulation_steps 4 \ |
|
|
--learning_rate 2e-4 \ |
|
|
--label_smoothing 0.01 \ |
|
|
--rep_alpha 0.0 \ |
|
|
--max_steps 60000 \ |
|
|
--max_seq_len 2048 |
|
|
|
|
|
You can point separate runs at datasets packed to different sequence lengths (e.g., 512 / 1024 / 2048) to implement a length curriculum, as sketched below.
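
For example, two separate runs at different packed lengths (the 512 dataset path is a
placeholder for a pack built at that length):

python veronica-polymorphic/scripts/train_veronica.py \
--config veronica-polymorphic/configs/veronica-pretrain-12L.json \
--dataset_paths data/mix_optimal_50_30_20_512 \
--output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-512 \
--max_seq_len 512

python veronica-polymorphic/scripts/train_veronica.py \
--config veronica-polymorphic/configs/veronica-pretrain-12L.json \
--dataset_paths data/mix_optimal_50_30_20_2048 \
--output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-2048 \
--max_seq_len 2048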
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import glob |
|
|
import json |
|
|
import math |
|
|
import argparse |
|
|
import random |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from datasets import load_from_disk |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
TrainerCallback, |
|
|
CONFIG_MAPPING, |
|
|
MODEL_FOR_CAUSAL_LM_MAPPING, |
|
|
LogitsProcessorList, |
|
|
NoRepeatNGramLogitsProcessor, |
|
|
RepetitionPenaltyLogitsProcessor, |
|
|
) |
|
|
|
|
|
|
|
|
from veronica.configuration_veronica import VeronicaConfig |
|
|
from veronica.modeling_veronica import VeronicaForCausalLM |
|
|
from veronica.modeling_components import Fp32LayerNorm |
|
|
|
|
|
# Register the custom architecture with the HF Auto mappings so that AutoConfig /
# AutoModelForCausalLM can resolve model_type "veronica" (e.g. when reloading checkpoints).
CONFIG_MAPPING.register("veronica", VeronicaConfig)
|
|
MODEL_FOR_CAUSAL_LM_MAPPING.register(VeronicaConfig, VeronicaForCausalLM) |
|
|
|
|
|
|
|
|
os.environ.setdefault("TORCH_COMPILE_USE_CUDAGRAPHS", "0") |
|
|
os.environ.setdefault("TORCHINDUCTOR_DISABLE_CUDAGRAPHS", "1") |
|
|
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_latest_checkpoint(run_dir: str) -> Optional[str]: |
|
|
ckpts = glob.glob(os.path.join(run_dir, "checkpoint-*")) |
|
|
if not ckpts: |
|
|
return None |
|
|
ckpts.sort(key=lambda p: int(re.search(r"checkpoint-(\d+)", p).group(1))) |
|
|
return ckpts[-1] |
|
|
|
|
|
|
|
|
def build_tokenizer(candidates: List[str], save_dir: str) -> AutoTokenizer: |
|
|
""" |
|
|
    Try to load an existing tokenizer from the candidate paths, falling back to gpt2
    if none can be loaded. Missing special tokens (eos/pad/bos) are then added, and the
    tokenizer is saved to `save_dir` and reloaded from there.
|
|
""" |
|
|
tok = None |
|
|
for p in candidates: |
|
|
if os.path.exists(p): |
|
|
try: |
|
|
tok = AutoTokenizer.from_pretrained(p, use_fast=True) |
|
|
print(f"[tokenizer] loaded from {p}") |
|
|
break |
|
|
except Exception: |
|
|
pass |
|
|
if tok is None: |
|
|
print("[tokenizer] fallback: gpt2") |
|
|
tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True) |
|
|
|
|
|
specials: Dict[str, str] = {} |
|
|
if tok.eos_token is None: |
|
|
specials["eos_token"] = "<|eos|>" |
|
|
if tok.pad_token is None: |
|
|
specials["pad_token"] = "<|pad|>" |
|
|
if tok.bos_token is None: |
|
|
specials["bos_token"] = "<|bos|>" |
|
|
|
|
|
if specials: |
|
|
tok.add_special_tokens(specials) |
|
|
|
|
|
tok.save_pretrained(save_dir) |
|
|
tok = AutoTokenizer.from_pretrained(save_dir, use_fast=True) |
|
|
base_vocab = tok.vocab_size |
|
|
effective_vocab = len(tok) |
|
|
print( |
|
|
f"[tokenizer] base_vocab={base_vocab} added={effective_vocab - base_vocab} " |
|
|
f"effective_vocab={effective_vocab} eos={tok.eos_token_id} " |
|
|
f"pad={tok.pad_token_id} bos={tok.bos_token_id}" |
|
|
) |
|
|
return tok |
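
# Illustrative usage (the save directory is a placeholder):
#   tok = build_tokenizer(["veronica-polymorphic/tokenizer", "gpt2"], "runs/tokenizer_out")
#   print(len(tok), tok.eos_token_id, tok.pad_token_id)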
|
|
|
|
|
|
|
|
def load_cfg_with_vocab(cfg_path: str, tok: AutoTokenizer) -> VeronicaConfig: |
|
|
""" |
|
|
Load the config and adapt it to the tokenizer vocabulary. |
|
|
    The model is intentionally UN-TIED (lm_head is not shared with wte).
|
|
""" |
|
|
with open(cfg_path, "r", encoding="utf-8") as f: |
|
|
d = json.load(f) |
|
|
cfg = VeronicaConfig(**d) |
|
|
cfg.model_type = "veronica" |
|
|
cfg.vocab_size = int(len(tok)) |
|
|
|
|
|
return cfg |
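
# Example: if the tokenizer ends up with 50258 entries (e.g. gpt2 plus an added pad
# token), cfg.vocab_size is forced to 50258 regardless of the value in the JSON config.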
|
|
|
|
|
|
|
|
def init_model_from_config(cfg: VeronicaConfig, tok: AutoTokenizer) -> VeronicaForCausalLM: |
|
|
model = VeronicaForCausalLM(cfg) |
|
|
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # bf16 weights train fine without a grad scaler. With fp16 the Trainer's AMP
    # GradScaler needs fp32 master weights, so keep the model in fp32 there and let
    # autocast handle mixed precision (a fully fp16-cast model breaks grad unscaling).
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(dtype=dtype, device=device)
|
|
|
|
|
effective_vocab = len(tok) |
|
|
emb = model.get_input_embeddings().weight |
|
|
head = model.lm_head.weight |
|
|
|
|
|
|
|
|
if emb.shape[0] != effective_vocab or head.shape[0] != effective_vocab: |
|
|
old_vocab = emb.shape[0] |
|
|
print(f"[model] resize_token_embeddings: {old_vocab} -> {effective_vocab}") |
|
|
model.resize_token_embeddings(effective_vocab) |
|
|
with torch.no_grad(): |
|
|
new_emb = model.get_input_embeddings().weight |
|
|
new_head = model.lm_head.weight |
|
|
mean_emb = new_emb[:old_vocab].mean(dim=0, keepdim=True) |
|
|
mean_head = new_head[:old_vocab].mean(dim=0, keepdim=True) |
|
|
if effective_vocab > old_vocab: |
|
|
new_emb[old_vocab:] = mean_emb |
|
|
new_head[old_vocab:] = mean_head |
|
|
|
|
|
|
|
|
for m in model.modules(): |
|
|
if isinstance(m, Fp32LayerNorm): |
|
|
m.ln.to(dtype=torch.float32) |
|
|
|
|
|
model.config.use_cache = False |
|
|
n_params = sum(p.numel() for p in model.parameters()) |
|
|
print(f"[model] params={n_params:,} vocab={effective_vocab}") |
|
|
return model |
|
|
|
|
|
|
|
|
def load_mix_dataset(path: str): |
|
|
""" |
|
|
Load a packed dataset (train/validation) from disk. |
|
|
Expected HuggingFace formats: a DatasetDict with 'train' and 'validation', |
|
|
or a single Dataset that gets split 99/1. |
|
|
""" |
|
|
ds = load_from_disk(path) |
|
|
if isinstance(ds, dict) and "train" in ds and "validation" in ds: |
|
|
return ds["train"], ds["validation"] |
|
|
split = ds.train_test_split(test_size=0.01, seed=42) |
|
|
return split["train"], split["test"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class CausalCollator: |
|
|
tokenizer: AutoTokenizer |
|
|
mask_runs: bool = False |
|
|
run_len: int = 4 |
|
|
max_seq_len: Optional[int] = None |
|
|
|
|
|
def _mask_degenerate_runs(self, labels: torch.Tensor): |
|
|
""" |
|
|
Mask degenerate runs (e.g., '____', '....') with length >= run_len. |
|
|
Mostly legacy; can be left off with a clean dataset. |
|
|
""" |
|
|
try: |
|
|
id_us = self.tokenizer.encode("_", add_special_tokens=False)[0] |
|
|
id_dot = self.tokenizer.encode(".", add_special_tokens=False)[0] |
|
|
except Exception: |
|
|
return |
|
|
B, T = labels.size() |
|
|
for b in range(B): |
|
|
cnt_u = cnt_d = 0 |
|
|
for t in range(T): |
|
|
tok = int(labels[b, t].item()) |
|
|
if tok == id_us: |
|
|
cnt_u += 1 |
|
|
cnt_d = 0 |
|
|
elif tok == id_dot: |
|
|
cnt_d += 1 |
|
|
cnt_u = 0 |
|
|
else: |
|
|
cnt_u = cnt_d = 0 |
|
|
if cnt_u >= self.run_len or cnt_d >= self.run_len: |
|
|
labels[b, t] = -100 |
|
|
|
|
|
def _crop(self, ids: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
If max_seq_len is set and the sequence is longer, |
|
|
crop a random window of length max_seq_len. |
|
|
""" |
|
|
if self.max_seq_len is None: |
|
|
return ids |
|
|
L = ids.size(0) |
|
|
if L <= self.max_seq_len: |
|
|
return ids |
|
|
start = random.randint(0, L - self.max_seq_len) |
|
|
end = start + self.max_seq_len |
|
|
return ids[start:end] |
|
|
|
|
|
def __call__(self, features): |
|
|
ids_list = [] |
|
|
for f in features: |
|
|
ids = torch.tensor(f["input_ids"], dtype=torch.long) |
|
|
ids = self._crop(ids) |
|
|
ids_list.append(ids) |
|
|
|
|
|
        # An explicit None check (rather than `or`) keeps a legitimate pad_token_id of 0.
        pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )
|
|
ids = torch.nn.utils.rnn.pad_sequence(ids_list, batch_first=True, padding_value=pad_id) |
|
|
attn = torch.where(ids == pad_id, 0, 1) |
|
|
|
|
|
labels = ids.clone() |
|
|
labels[labels == pad_id] = -100 |
|
|
if self.mask_runs: |
|
|
self._mask_degenerate_runs(labels) |
|
|
|
|
|
return {"input_ids": ids, "attention_mask": attn, "labels": labels} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SMOKE_PROMPTS = [ |
|
|
"The world we live in today is", |
|
|
"Understanding complex ideas requires", |
|
|
"Human intelligence differs from artificial intelligence because", |
|
|
"A good system design is based on", |
|
|
"In the middle of every difficulty lies", |
|
|
"Once upon a time, there was a scientist who", |
|
|
] |
|
|
|
|
|
|
|
|
class RouterAndSmokeCallback(TrainerCallback): |
|
|
def __init__(self, tok: AutoTokenizer): |
|
|
self.tok = tok |
|
|
|
|
|
def on_log(self, args, state, control, **kwargs): |
|
|
model = kwargs.get("model", None) |
|
|
if model is None: |
|
|
return |
|
|
try: |
|
|
if hasattr(model, "router_alpha_mean") and model.router_alpha_mean is not None: |
|
|
alpha = model.router_alpha_mean.detach().float().cpu() |
|
|
p = alpha / alpha.sum() |
|
|
ent = -(p * (p.clamp_min(1e-9)).log()).sum() |
|
|
ent_norm = float(ent / math.log(len(p))) |
|
|
print(f"[router] alpha={alpha.tolist()} entropy_norm={ent_norm:.4f}") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
def on_evaluate(self, args, state, control, **kwargs): |
|
|
model = kwargs.get("model", None) |
|
|
if model is None: |
|
|
return |
|
|
model.eval() |
|
|
dev = next(model.parameters()).device |
|
|
|
|
|
prompt = random.choice(SMOKE_PROMPTS) |
|
|
ids = self.tok(prompt, return_tensors="pt").to(dev) |
|
|
|
|
|
processors = LogitsProcessorList([ |
|
|
NoRepeatNGramLogitsProcessor(3), |
|
|
RepetitionPenaltyLogitsProcessor(1.1), |
|
|
]) |
|
|
|
|
|
with torch.no_grad(): |
|
|
out = model.generate( |
|
|
**ids, |
|
|
max_new_tokens=64, |
|
|
do_sample=False, |
|
|
logits_processor=processors, |
|
|
eos_token_id=self.tok.eos_token_id, |
|
|
                pad_token_id=(
                    self.tok.pad_token_id
                    if self.tok.pad_token_id is not None
                    else self.tok.eos_token_id
                ),
|
|
use_cache=True, |
|
|
) |
|
|
txt = self.tok.decode(out[0], skip_special_tokens=True) |
|
|
completion = txt[len(prompt):].strip() if txt.startswith(prompt) else txt |
|
|
print(f"\n[SMOKE] {prompt} → {completion}\n") |
|
|
model.train() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RouterScheduleCallback(TrainerCallback): |
|
|
""" |
|
|
Linearly schedule router_tau and router_aux_weight between start and end of training. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
tau_start: float, |
|
|
tau_end: float, |
|
|
aux_start: float, |
|
|
aux_end: float, |
|
|
total_steps: int, |
|
|
tau_freeze_steps: int = 0, |
|
|
force_prob: float = 0.0, |
|
|
force_warmup_steps: int = 0, |
|
|
): |
|
|
self.tau_start = float(tau_start) |
|
|
self.tau_end = float(tau_end) |
|
|
self.aux_start = float(aux_start) |
|
|
self.aux_end = float(aux_end) |
|
|
self.total_steps = max(int(total_steps), 1) |
|
|
self.tau_freeze_steps = max(int(tau_freeze_steps), 0) |
|
|
self.force_prob = float(force_prob) |
|
|
self.force_warmup_steps = max(int(force_warmup_steps), 0) |
|
|
|
|
|
def _interp(self, start: float, end: float, step: int, span: int) -> float: |
|
|
t = min(max(step, 0), span) |
|
|
alpha = t / float(max(span, 1)) |
|
|
return (1.0 - alpha) * start + alpha * end |
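
    # Worked example with the script defaults (tau 1.6 -> 1.1, total_steps 60000,
    # tau_freeze_steps 4000): at global step 32000 the tau interpolation runs over
    # step 28000 of a 56000-step span (alpha = 0.5), giving tau = 1.35, while the aux
    # weight interpolates over the full span (alpha ~ 0.53), giving aux_w ~ 0.0087.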
|
|
|
|
|
def on_step_begin(self, args, state, control, **kwargs): |
|
|
model = kwargs.get("model", None) |
|
|
if model is None: |
|
|
return |
|
|
step = state.global_step |
|
|
|
|
|
if step < self.tau_freeze_steps: |
|
|
new_tau = self.tau_start |
|
|
else: |
|
|
rem_step = step - self.tau_freeze_steps |
|
|
rem_span = max(self.total_steps - self.tau_freeze_steps, 1) |
|
|
new_tau = self._interp(self.tau_start, self.tau_end, rem_step, rem_span) |
|
|
|
|
|
|
|
|
new_aux = self._interp(self.aux_start, self.aux_end, step, self.total_steps) |
|
|
|
|
|
|
|
|
if hasattr(model, "config"): |
|
|
model.config.router_tau = new_tau |
|
|
model.config.router_aux_weight = new_aux |
|
|
|
|
|
|
|
|
for block in getattr(model, "blocks", []): |
|
|
if hasattr(block, "mlp"): |
|
|
|
|
|
block.mlp.router_tau = new_tau |
|
|
block.mlp.force_func = -1 |
|
|
|
|
|
|
|
|
if step < self.force_warmup_steps and self.force_prob > 0.0: |
|
|
if random.random() < self.force_prob: |
|
|
for block in getattr(model, "blocks", []): |
|
|
if hasattr(block, "mlp") and hasattr(block.mlp, "num_funcs"): |
|
|
k = block.mlp.num_funcs |
|
|
block.mlp.force_func = random.randint(0, max(k - 1, 0)) |
|
|
|
|
|
if step % 1000 == 0: |
|
|
print( |
|
|
f"[router-sched] step={step} tau={new_tau:.4f} aux_w={new_aux:.5f} " |
|
|
f"freeze<= {self.tau_freeze_steps} force_p={self.force_prob:.3f} warmup<= {self.force_warmup_steps}" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VeronicaTrainer(Trainer): |
|
|
def __init__(self, *args, label_smoothing: float = 0.0, rep_alpha: float = 0.0, **kwargs): |
|
|
super().__init__(*args, **kwargs) |
|
|
self.label_smoothing = float(label_smoothing) |
|
|
self.rep_alpha = float(rep_alpha) |
|
|
|
|
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs): |
|
|
labels = inputs.get("labels") |
|
|
if labels is None: |
|
|
raise ValueError("compute_loss called without labels") |
|
|
model_inputs = {k: v for k, v in inputs.items() if k != "labels"} |
|
|
|
|
|
outputs = model(**model_inputs) |
|
|
logits = outputs.logits |
|
|
|
|
|
ignore_index = -100 |
|
|
|
|
|
shift_logits = logits[:, :-1, :].contiguous() |
|
|
shift_labels = labels[:, 1:].contiguous() |
|
|
|
|
|
valid_mask = (shift_labels != ignore_index) |
|
|
safe_labels = shift_labels.clone() |
|
|
safe_labels[~valid_mask] = 0 |
|
|
|
|
|
log_probs = F.log_softmax(shift_logits, dim=-1) |
|
|
nll_full = -log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) |
|
|
nll_loss = nll_full[valid_mask].mean() |
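        # The optional smoothing below mixes the NLL with a uniform term; this equals
        # cross-entropy against a target of (1 - eps) * one_hot + eps * uniform.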
|
|
|
|
|
if self.label_smoothing > 0.0: |
|
|
smooth_full = -log_probs.mean(dim=-1) |
|
|
smooth_loss = smooth_full[valid_mask].mean() |
|
|
ce_loss = (1.0 - self.label_smoothing) * nll_loss + self.label_smoothing * smooth_loss |
|
|
else: |
|
|
ce_loss = nll_loss |
|
|
|
|
|
total_loss = ce_loss |
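        # The optional term below penalizes immediate token repetition: wherever the
        # target token equals the previous target token, rep_alpha times the probability
        # the model assigns to that repeat is added to the loss.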
|
|
|
|
|
|
|
|
if self.rep_alpha > 0.0: |
|
|
labels_prev = labels[:, :-1] |
|
|
labels_next = shift_labels |
|
|
valid_prev = (labels_prev != ignore_index) |
|
|
same_mask = valid_prev & valid_mask & (labels_prev == labels_next) |
|
|
if same_mask.any(): |
|
|
rep_logp = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) |
|
|
rep_p = rep_logp[same_mask].exp() |
|
|
total_loss = total_loss + self.rep_alpha * rep_p.mean() |
|
|
|
|
|
|
|
|
        # Router auxiliary term exposed by the model, if any. Note the sign below: the
        # clamped aux value is subtracted from the loss (scaled by router_aux_weight),
        # i.e. it is treated as a bonus the optimizer should increase (e.g. routing entropy).
        aux_loss = getattr(model, "_last_router_aux", None)
|
|
if aux_loss is not None and hasattr(model, "config"): |
|
|
aux_w = float(getattr(model.config, "router_aux_weight", 0.0)) |
|
|
if aux_w > 0: |
|
|
if not torch.is_tensor(aux_loss): |
|
|
aux_loss = torch.as_tensor(aux_loss, device=logits.device, dtype=logits.dtype) |
|
|
|
|
|
total_loss = total_loss - aux_w * aux_loss.clamp_min(0.0) |
|
|
|
|
|
return (total_loss, outputs) if return_outputs else total_loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("--config", type=str, required=True) |
|
|
parser.add_argument("--dataset_paths", type=str, required=True) |
|
|
parser.add_argument("--output_dir", type=str, required=True, default="veronica-polymorphic/runs/veronica-pretrain") |
|
|
|
|
|
parser.add_argument( |
|
|
"--tokenizer_candidates", |
|
|
type=str, |
|
|
nargs="*", |
|
|
default=["veronica-polymorphic/tokenizer", "gpt2"], |
|
|
) |
|
|
parser.add_argument( |
|
|
"--tokenizer_out", |
|
|
type=str, |
|
|
default="veronica-polymorphic/tokenizer_vmix", |
|
|
) |
|
|
|
|
|
parser.add_argument("--per_device_train_batch_size", type=int, default=4) |
|
|
parser.add_argument("--per_device_eval_batch_size", type=int, default=4) |
|
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=4) |
|
|
parser.add_argument("--max_steps", type=int, default=60000) |
|
|
parser.add_argument("--learning_rate", type=float, default=2e-4) |
|
|
parser.add_argument("--warmup_ratio", type=float, default=0.02) |
|
|
parser.add_argument("--weight_decay", type=float, default=0.1) |
|
|
parser.add_argument("--eval_steps", type=int, default=1000) |
|
|
parser.add_argument("--save_steps", type=int, default=1000) |
|
|
parser.add_argument("--logging_steps", type=int, default=100) |
|
|
parser.add_argument("--label_smoothing", type=float, default=0.01) |
|
|
parser.add_argument("--rep_alpha", type=float, default=0.0) |
|
|
parser.add_argument("--mask_degenerate_runs", action="store_true") |
|
|
parser.add_argument("--seed", type=int, default=42) |
|
|
|
|
|
parser.add_argument( |
|
|
"--resume_from", |
|
|
type=str, |
|
|
default=None, |
|
|
help="Checkpoint to resume from (e.g., .../checkpoint-22000)", |
|
|
) |
|
|
|
|
|
parser.add_argument( |
|
|
"--max_seq_len", |
|
|
type=int, |
|
|
default=None, |
|
|
help="Maximum window length (e.g., 512, 1024, 2048). If None, uses the full dataset sequence.", |
|
|
) |
|
|
|
|
|
|
|
|
parser.add_argument("--router_tau_start", type=float, default=1.6) |
|
|
parser.add_argument("--router_tau_end", type=float, default=1.1) |
|
|
parser.add_argument("--router_aux_start", type=float, default=0.005) |
|
|
parser.add_argument("--router_aux_end", type=float, default=0.012) |
|
|
parser.add_argument("--router_tau_freeze_steps", type=int, default=4000, |
|
|
help="Keep tau constant for the first N steps to avoid early specialization.") |
|
|
parser.add_argument("--router_force_prob", type=float, default=0.05, |
|
|
help="Per-step probability to force a single branch during warmup.") |
|
|
parser.add_argument("--router_force_warmup_steps", type=int, default=3000, |
|
|
help="Apply random branch forcing only within these initial steps.") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
tok = build_tokenizer(args.tokenizer_candidates, args.tokenizer_out) |
|
|
|
|
|
|
|
|
cfg = load_cfg_with_vocab(args.config, tok) |
|
|
cfg.router_tau = args.router_tau_start |
|
|
cfg.router_aux_weight = args.router_aux_start |
|
|
|
|
|
model = init_model_from_config(cfg, tok) |
|
|
|
|
|
|
|
|
model.eval() |
|
|
with torch.no_grad(): |
|
|
dummy = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device) |
|
|
out = model(input_ids=dummy, labels=dummy) |
|
|
loss_model = out.loss.item() |
|
|
|
|
|
logits = out.logits |
|
|
shift_logits = logits[:, :-1, :].contiguous() |
|
|
shift_labels = dummy[:, 1:].contiguous() |
|
|
loss_manual = F.cross_entropy( |
|
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
|
shift_labels.view(-1) |
|
|
).item() |
|
|
|
|
|
print(f"[diag] loss_model_forward={loss_model:.4f} loss_manual_shift={loss_manual:.4f}") |
|
|
model.train() |
|
|
|
|
|
|
|
|
train_ds, val_ds = load_mix_dataset(args.dataset_paths) |
|
|
collator = CausalCollator( |
|
|
tokenizer=tok, |
|
|
mask_runs=args.mask_degenerate_runs, |
|
|
max_seq_len=args.max_seq_len, |
|
|
) |
|
|
|
|
|
|
|
|
resume_ckpt = args.resume_from or find_latest_checkpoint(args.output_dir) |
|
|
if resume_ckpt: |
|
|
print(f"🟢 Resuming from: {resume_ckpt}") |
|
|
else: |
|
|
print("⚪ No checkpoint: training from scratch.") |
|
|
|
|
|
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() |
|
|
|
|
|
train_args = TrainingArguments( |
|
|
output_dir=args.output_dir, |
|
|
run_name=os.path.basename(args.output_dir.rstrip("/")), |
|
|
        num_train_epochs=1_000,  # effectively ignored: max_steps below controls the run length
|
|
max_steps=args.max_steps, |
|
|
per_device_train_batch_size=args.per_device_train_batch_size, |
|
|
per_device_eval_batch_size=args.per_device_eval_batch_size, |
|
|
gradient_accumulation_steps=args.gradient_accumulation_steps, |
|
|
learning_rate=args.learning_rate, |
|
|
warmup_ratio=args.warmup_ratio, |
|
|
weight_decay=args.weight_decay, |
|
|
lr_scheduler_type="cosine", |
|
|
logging_steps=args.logging_steps, |
|
|
eval_steps=args.eval_steps, |
|
|
save_steps=args.save_steps, |
|
|
eval_strategy="steps", |
|
|
save_total_limit=5, |
|
|
bf16=use_bf16, |
|
|
fp16=(torch.cuda.is_available() and not use_bf16), |
|
|
gradient_checkpointing=True, |
|
|
report_to=["tensorboard"], |
|
|
dataloader_num_workers=2, |
|
|
seed=args.seed, |
|
|
        label_smoothing_factor=0.0,  # smoothing is applied in VeronicaTrainer.compute_loss instead
|
|
max_grad_norm=1.0, |
|
|
save_safetensors=False, |
|
|
) |
|
|
|
|
|
callbacks: List[TrainerCallback] = [ |
|
|
RouterAndSmokeCallback(tok), |
|
|
RouterScheduleCallback( |
|
|
tau_start=args.router_tau_start, |
|
|
tau_end=args.router_tau_end, |
|
|
aux_start=args.router_aux_start, |
|
|
aux_end=args.router_aux_end, |
|
|
total_steps=args.max_steps, |
|
|
tau_freeze_steps=args.router_tau_freeze_steps, |
|
|
force_prob=args.router_force_prob, |
|
|
force_warmup_steps=args.router_force_warmup_steps, |
|
|
), |
|
|
] |
|
|
|
|
|
trainer = VeronicaTrainer( |
|
|
model=model, |
|
|
args=train_args, |
|
|
train_dataset=train_ds, |
|
|
eval_dataset=val_ds, |
|
|
tokenizer=tok, |
|
|
data_collator=collator, |
|
|
callbacks=callbacks, |
|
|
label_smoothing=args.label_smoothing, |
|
|
rep_alpha=args.rep_alpha, |
|
|
) |
|
|
|
|
|
|
|
|
effective_vocab = len(tok) |
|
|
emb = model.get_input_embeddings().weight |
|
|
head = model.lm_head.weight |
|
|
    assert emb.shape[0] == effective_vocab == head.shape[0], "Vocab size mismatch between tokenizer, embeddings, and lm_head"
|
|
|
|
|
|
|
|
trainer.train(resume_from_checkpoint=resume_ckpt) |
|
|
trainer.save_state() |
|
|
trainer.save_model(args.output_dir) |
|
|
tok.save_pretrained(args.output_dir) |
|
|
with open(os.path.join(args.output_dir, "config.final.json"), "w", encoding="utf-8") as f: |
|
|
json.dump(model.config.to_dict(), f, indent=2) |
|
|
print("✅ Pretraining completed/saved.") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|