| | |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from model import MiniGPT |
| | from datasets import load_dataset |
| | from dataloader import TinyLLMDataset |
| | from torch.utils.data import DataLoader |
| | import os |
| | from torch.nn.utils.rnn import pad_sequence |
| | from tokenizer import load_tokenizer |
| | from utils import print_gpu_memory |
| | import time |
| | from torch.optim.lr_scheduler import OneCycleLR |
| |
|
| |
|
| |
|
| | |
# --- Hyperparameters ---
block_size = 128     # context length in tokens per training sample
batch_size = 32
max_iters = 100000   # LR-schedule budget (OneCycleLR total_steps below)
eval_interval = 100  # generate a text sample every N steps
learning_rate = 1e-3  # peak LR for OneCycleLR
embed_dim = 256
n_heads = 32         # NOTE(review): 256 / 32 = 8-dim heads — unusually narrow; confirm intended
n_layers = 20
device = 'cuda' if torch.cuda.is_available() else 'cpu'
| |
|
| | |
| | |
# Load the French TinyStories corpus; only the train split is used below.
dt = load_dataset("iproskurina/TinyStories-French")
texts = dt["train"]["french-tinystories"]

# Pre-built tokenizer: vocab maps (stoi/itos), encode/decode callables, and
# the id used for padding (also the ignore_index of the loss).
stoi, itos, encode, decode, pad_token_id = load_tokenizer("tokenizer_wtw_tinystories.json")
vocab_size = len(stoi)
| |
|
| |
|
# --- Model / optimizer construction, with optional resume ---
resume_path = "checkpoints/model_step_best.pt"

model = MiniGPT(
    vocab_size=vocab_size,
    block_size=block_size,
    embed_dim=embed_dim,
    depth=n_layers,
    heads=n_heads
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Resume from the best checkpoint if one exists. map_location is required so
# a checkpoint saved on GPU can still be restored on a CPU-only machine
# (and vice versa); the original load crashed in that case.
if os.path.exists(resume_path):
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_iter = checkpoint["step"] + 1
    print(f"Reprise à l'étape {start_iter}")
else:
    start_iter = 0
| |
|
| |
|
| |
|
def collate_fn(batch, pad_id=None):
    """Pad a batch of (input, target) token sequences to a common length.

    Args:
        batch: iterable of ``(x, y)`` pairs of 1-D LongTensors.
        pad_id: padding token id. Defaults to the module-level
            ``pad_token_id`` from the tokenizer, preserving the original
            call signature used by the DataLoaders.

    Returns:
        ``(xs_padded, ys_padded)``: two ``(B, T_max)`` tensors, padded on
        the right with ``pad_id``.
    """
    if pad_id is None:
        pad_id = pad_token_id
    xs, ys = zip(*batch)
    xs_padded = pad_sequence(xs, batch_first=True, padding_value=pad_id)
    ys_padded = pad_sequence(ys, batch_first=True, padding_value=pad_id)
    return xs_padded, ys_padded
| |
|
| |
|
| |
|
# Work on a 10k-story subset; the last 10% is held out for validation.
corpus = texts[:10000]
cut = int(0.9 * len(corpus))
train_sentences = corpus[:cut]
val_sentences = corpus[cut:]

train_dataset = TinyLLMDataset(train_sentences, block_size, encode)
val_dataset = TinyLLMDataset(val_sentences, block_size, encode)

# drop_last keeps every batch at exactly batch_size; only training shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn,
)
| |
|
| |
|
def count_parameters(model):
    """Return the trainable-parameter count as a human-readable string.

    Counts only parameters with ``requires_grad`` and formats the total
    with a B/M/K suffix (two decimals), or as a bare integer below 1000.
    """
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if total >= threshold:
            return f"{total / threshold:.2f}{suffix}"
    return str(total)
| |
|
print("Nombre de paramètres du modèle :", count_parameters(model))


# One-cycle LR schedule: warms up to learning_rate then anneals, over a
# budget of max_iters optimizer steps.
# NOTE(review): total_steps=max_iters may not equal num_epochs *
# len(train_loader); OneCycleLR raises if stepped past total_steps — verify
# the budgets line up.
scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    total_steps=max_iters,
)
| |
|
| |
|
| | |
# --- Training loop ---
num_epochs = 10
global_step = start_iter
best_loss = float("inf")  # was a magic 10000; inf is the conventional sentinel

# Ensure the checkpoint directory exists before the first save (a fresh run
# previously crashed inside torch.save).
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")

    for xb, yb in train_loader:
        start_time_total = time.time()
        xb = xb.to(device)
        yb = yb.to(device)
        model.train()

        # NOTE(review): on CUDA these per-phase timings are approximate —
        # kernels launch asynchronously; torch.cuda.synchronize() would be
        # needed for exact numbers.
        start_time = time.time()
        logits = model(xb)
        forward_time = time.time() - start_time

        start_time = time.time()
        B, T, C = logits.shape
        # Flatten (B, T, C) -> (B*T, C); padded positions are excluded via ignore_index.
        loss = F.cross_entropy(logits.view(B * T, C), yb.view(B * T), ignore_index=pad_token_id)
        loss_time = time.time() - start_time

        start_time = time.time()
        optimizer.zero_grad()
        loss.backward()
        backward_time = time.time() - start_time

        start_time = time.time()
        optimizer.step()
        # OneCycleLR raises ValueError once stepped past total_steps; guard so
        # training can exceed the schedule budget (e.g. after a resume)
        # without crashing.
        if scheduler.last_epoch < max_iters:
            scheduler.step()
        step_time = time.time() - start_time

        total_time = time.time() - start_time_total
        print(f"[Step {global_step}] Perte = {loss.item():.4f} | total: {total_time:.3f}s | forward: {forward_time:.3f}s | loss: {loss_time:.3f}s | backward: {backward_time:.3f}s | step: {step_time:.3f}s")

        if global_step % eval_interval == 0:
            print(f"[Epoch {epoch+1} | Step {global_step}] Perte = {loss.item():.4f}")
            model.eval()
            # no_grad: sampling needs no autograd graph (saves memory/time).
            with torch.no_grad():
                context = torch.zeros((1, 1), dtype=torch.long, device=device)
                generated = model.generate(context, max_new_tokens=500)[0].tolist()
            print("\n--- Généré ---")
            print(decode(generated))
            print("--------------\n")
        else:
            print(f"[Epoch {epoch+1} | Step {global_step}] Perte = {loss.item():.4f}")

        # Save whenever the per-batch training loss improves.
        # NOTE(review): this tracks the noisy train-batch loss, not validation
        # loss — val_loader is built but never evaluated; consider a periodic
        # validation pass to drive "best" checkpointing.
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save({
                'step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'vocab': {'stoi': stoi, 'itos': itos}
            }, "checkpoints/model_step_best.pt")

        global_step += 1
| |
|