| | from typing import Optional, Tuple, List |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | from .configuration_veronica import VeronicaConfig |
| | from .modeling_components import PolymorphicMLP, router_aux_loss, Fp32LayerNorm, apply_rotary_pos_emb |
| |
|
| |
|
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE) and an
    optional key/value cache for incremental decoding.

    Attention weights are computed in float32 for numerical stability and cast
    back to the value dtype before the weighted sum.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0, max_position_embeddings: int = 4096, rope_theta: float = 10000.0):
        super().__init__()
        # FIX: message previously referenced "n_head", which is not this
        # constructor's parameter name.
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        # Fused projection producing Q, K and V in one matmul.
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        # Lazily-built RoPE cos/sin cache (plain attributes, not buffers, so
        # they are excluded from the state dict).
        self._rope_cached_seq_len = 0
        self._rope_cos_cached = None
        self._rope_sin_cached = None

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, C) -> (B, num_heads, T, head_dim)."""
        B, T, C = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, num_heads, T, head_dim) -> (B, T, C). Inverse of _split_heads."""
        B, nh, T, hd = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, nh * hd)

    def _get_rope_cos_sin(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build or fetch cached RoPE cos/sin tables.

        Returns two tensors of shape (1, 1, seq_len, head_dim) in `dtype`.
        The cache is built in float32 and covers at least
        max_position_embeddings positions so typical sequences always hit it.
        """
        if self._rope_cos_cached is None or seq_len > self._rope_cached_seq_len:
            # Cache miss: (re)build the full table on the requested device.
            self._rope_cached_seq_len = max(seq_len, self.max_position_embeddings)
            dim = self.head_dim
            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
            t = torch.arange(self._rope_cached_seq_len, dtype=torch.float32, device=device)
            freqs = torch.outer(t, inv_freq)
            # Duplicate the frequency block so the table spans the full head_dim.
            emb = torch.cat([freqs, freqs], dim=-1)
            self._rope_cos_cached = emb.cos().unsqueeze(0).unsqueeze(0)
            self._rope_sin_cached = emb.sin().unsqueeze(0).unsqueeze(0)
        elif self._rope_cos_cached.device != device:
            # FIX: migrate the cache once when the module moves devices instead
            # of paying a host/device copy on every forward (the old code did
            # `.to(device=...)` on each cache hit without storing the result).
            self._rope_cos_cached = self._rope_cos_cached.to(device)
            self._rope_sin_cached = self._rope_sin_cached.to(device)

        return (
            self._rope_cos_cached[:, :, :seq_len, :].to(dtype=dtype),
            self._rope_sin_cached[:, :, :seq_len, :].to(dtype=dtype),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention over `x`.

        Args:
            x: (B, T, C) input hidden states.
            attn_mask: additive bias broadcastable to (B, num_heads, T, T+P),
                where P is the cached prefix length; `-inf` entries mask keys.
            past_key_value: previously cached (k, v), each (B, nh, P, hd).
            use_cache: when True, also return the updated (k, v) pair.
            position_offset: absolute position of the first token in `x`
                (equals P during incremental decoding), used to index RoPE.

        Returns:
            (output, present) where output is (B, T, C) and present is the
            updated cache or None.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(C, dim=-1)
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Rotate only the freshly computed positions; cached keys were rotated
        # when they were first produced.
        cos, sin = self._get_rope_cos_sin(position_offset + T, q.device, q.dtype)
        cos = cos[:, :, position_offset:position_offset + T, :]
        sin = sin[:, :, position_offset:position_offset + T, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        present = None
        if past_key_value is not None:
            pk, pv = past_key_value
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        if use_cache:
            present = (k, v)

        # Scores in float32; the additive mask is applied before softmax.
        att = (q @ k.transpose(-2, -1)) * self.scale
        att = att.float()
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        att = att.to(v.dtype)
        y = att @ v
        y = self._merge_heads(y)
        y = self.out_proj(y)
        y = self.resid_drop(y)
        return y, present
| |
|
| |
|
class VeronicaBlock(nn.Module):
    """One pre-norm transformer layer: self-attention sublayer followed by a
    polymorphic MLP sublayer, each wrapped in a residual connection."""

    def __init__(self, config: VeronicaConfig):
        super().__init__()
        rope_theta = getattr(config, 'rope_theta', 10000.0)
        self.ln_1 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadSelfAttention(
            config.n_embd,
            config.n_head,
            dropout=config.dropout,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=rope_theta,
        )
        self.ln_2 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = PolymorphicMLP(
            hidden_size=config.n_embd,
            mlp_mult=config.mlp_mult,
            num_funcs=config.num_funcs,
            router_dim=config.router_dim,
            dropout=config.dropout,
            use_channel_attention=config.use_channel_attention,
            router_tau=config.router_tau,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Return (hidden_states, router_alpha, present_kv) for this layer."""
        attn_out, present = self.attn(
            self.ln_1(x),
            attn_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            position_offset=position_offset,
        )
        residual = x + attn_out
        mlp_out, alpha = self.mlp(self.ln_2(residual))
        return residual + mlp_out, alpha, present
| |
|
| |
|
class VeronicaModel(PreTrainedModel):
    """Decoder-only transformer backbone: token embeddings, a stack of
    VeronicaBlock layers, and a final float32 LayerNorm.

    forward() returns (hidden_states, past_key_values); language-model loss
    is handled by VeronicaForCausalLM.
    """

    config_class = VeronicaConfig

    def __init__(self, config: VeronicaConfig):
        super().__init__(config)
        self.embed_dim = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([VeronicaBlock(config) for _ in range(config.n_layer)])
        self.ln_f = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Boolean lower-triangular mask (non-persistent). NOTE(review): forward
        # builds its additive mask via _build_attn_mask and does not read this
        # buffer; kept for external compatibility — confirm before removing.
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_position_embeddings,
                    config.max_position_embeddings,
                    dtype=torch.uint8,
                )
            ).view(1, 1, config.max_position_embeddings, config.max_position_embeddings),
            persistent=False,
        )

        # Router diagnostics, refreshed on each forward when
        # output_router_stats=True.
        self.router_alpha_entropy: Optional[torch.Tensor] = None
        self.router_alpha_mean: Optional[torch.Tensor] = None

        self._use_gradient_checkpointing: bool = getattr(config, "gradient_checkpointing", False)

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # The kwargs argument is accepted for HF API compatibility but unused.
        self._use_gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self._use_gradient_checkpointing = False

    def _build_attn_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        seq_len: int,
        past_kv_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Build an additive attention bias of shape (B or 1, 1, T, T+P).

        Combines a causal bias (queries may not attend past their own
        absolute position) with an optional padding mask (1 = keep, 0 = mask).
        NOTE(review): a row that is entirely padded would become all -inf and
        softmax to NaN — callers are assumed never to produce such rows.
        """
        T, P = seq_len, past_kv_len
        causal = torch.full((T, T + P), float("-inf"), device=device, dtype=dtype)
        causal = torch.triu(causal, diagonal=1 + P)

        if attention_mask is None:
            return causal.unsqueeze(0).unsqueeze(1)

        # Convert the 0/1 padding mask to a large-negative additive bias.
        attn_full = attention_mask.to(dtype)
        pad_add = (1.0 - attn_full) * torch.finfo(dtype).min
        pad_add = pad_add.unsqueeze(1).unsqueeze(2)
        causal = causal.unsqueeze(0).unsqueeze(1)
        return causal + pad_add

    def _normalize_past(
        self,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
    ) -> List[Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Return one cache entry per layer, replacing malformed/empty entries
        with None. (Extracted: this list was previously built identically in
        both the checkpointing and non-checkpointing branches of forward.)"""
        if past_key_values is None:
            return [None] * len(self.blocks)
        return [
            pkv
            if (pkv is not None and isinstance(pkv, (tuple, list)) and pkv[0] is not None and pkv[1] is not None)
            else None
            for pkv in past_key_values
        ]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_router_stats: bool = True,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the backbone.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: padding mask of length T or T+P (1 = keep).
            labels: accepted for interface compatibility; unused here (loss is
                computed by VeronicaForCausalLM).
            output_router_stats: when True, store router alpha statistics on
                self.router_alpha_mean / self.router_alpha_entropy.
            past_key_values: per-layer cached (k, v) pairs, or None.
            use_cache: defaults to False in training, True in eval.

        Returns:
            (hidden_states, past_key_values) — the annotation previously
            declared a 3-tuple, which did not match the actual return.
        """
        device = input_ids.device
        B, T = input_ids.shape

        if use_cache is None:
            use_cache = not self.training

        pkv_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # P = length of the cached key/value prefix (0 without a usable cache).
        P = 0
        if (
            past_key_values is not None
            and len(past_key_values) > 0
            and past_key_values[0] is not None
            and isinstance(past_key_values[0], (tuple, list))
            and past_key_values[0][0] is not None
        ):
            P = past_key_values[0][0].size(-2)

        x = self.drop(self.wte(input_ids))

        # Extend a T-length padding mask over the cached prefix so it covers
        # all T + P key positions; unexpected lengths fall back to causal-only.
        attn_full = None
        if attention_mask is not None:
            if attention_mask.size(-1) == T + P:
                attn_full = attention_mask
            elif attention_mask.size(-1) == T:
                if P > 0:
                    ones = torch.ones((B, P), dtype=attention_mask.dtype, device=attention_mask.device)
                    attn_full = torch.cat([ones, attention_mask], dim=-1)
                else:
                    attn_full = attention_mask

        attn_bias = self._build_attn_mask(attn_full, T, P, device, torch.float32)

        alpha_list: List[torch.Tensor] = []
        if self.training:
            # Accumulators for the per-layer router auxiliary loss.
            self._acc_aux_sum = 0.0
            self._acc_aux_count = 0

        curr_past = self._normalize_past(past_key_values)

        if getattr(self, "_use_gradient_checkpointing", False) and self.training:
            # Checkpointed path: KV caching is disabled inside the
            # recomputed segments, so no `present` is collected here.
            def create_custom_forward(module, pkv):
                def custom_forward(x):
                    out_x, out_alpha, _ = module(x, attn_bias, past_key_value=pkv, use_cache=False, position_offset=P)
                    return out_x, out_alpha

                return custom_forward

            for layer_idx, block in enumerate(self.blocks):
                x, alpha = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block, curr_past[layer_idx]), x, use_reentrant=False
                )
                alpha_list.append(alpha)
                if self.training and getattr(block.mlp, "last_aux", None) is not None:
                    self._acc_aux_sum = self._acc_aux_sum + block.mlp.last_aux
                    self._acc_aux_count += 1
        else:
            for layer_idx, block in enumerate(self.blocks):
                x, alpha, present = block(x, attn_bias, past_key_value=curr_past[layer_idx], use_cache=use_cache, position_offset=P)
                alpha_list.append(alpha)
                if self.training and getattr(block.mlp, "last_aux", None) is not None:
                    self._acc_aux_sum = self._acc_aux_sum + block.mlp.last_aux
                    self._acc_aux_count += 1
                if use_cache and pkv_list is not None:
                    pkv_list.append(present)

        x = self.ln_f(x)

        # Aggregate router statistics across layers / batch / positions.
        if output_router_stats and len(alpha_list) > 0:
            alpha_stack = torch.stack(alpha_list, dim=0)
            self.router_alpha_mean = alpha_stack.mean(dim=(0, 1, 2)).detach()
            self.router_alpha_entropy = router_aux_loss(alpha_stack.mean(dim=0))

        # Fold the accumulated aux loss into _last_router_aux (training only;
        # the accumulators exist only when self.training was True above).
        if hasattr(self, "_acc_aux_sum"):
            if self._acc_aux_count > 0:
                self._last_router_aux = self._acc_aux_sum / self._acc_aux_count
            else:
                self._last_router_aux = None
            delattr(self, "_acc_aux_sum")
            delattr(self, "_acc_aux_count")

        return x, pkv_list
| |
|
| |
|
class VeronicaForCausalLM(VeronicaModel, GenerationMixin):
    """Causal language model: VeronicaModel backbone plus an untied-bias
    linear head, with HF GenerationMixin support for incremental decoding."""

    def __init__(self, config: VeronicaConfig):
        super().__init__(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self):
        # Share (or clone, depending on config) lm_head weights with the
        # input embedding table.
        self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings())

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """With a non-empty cache, feed only the last token; the cached KV
        pairs cover the prefix."""
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": True,
        }

    def _reorder_cache(self, past_key_values, beam_idx: torch.LongTensor):
        """Reorder every layer's cached (k, v) along the batch dim for beam
        search."""
        if past_key_values is None:
            return past_key_values
        reordered = []
        for (k, v) in past_key_values:
            reordered.append((k.index_select(0, beam_idx), v.index_select(0, beam_idx)))
        return reordered

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute logits and, when labels are given, the shifted
        cross-entropy loss plus a weighted router auxiliary loss."""
        hidden_states, present = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal shift: token t predicts token t+1; -100 ignored.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        # FIX: only add the router aux term when a base loss exists — the old
        # code could compute `None + aux` (TypeError) when the backbone had
        # set _last_router_aux during training but no labels were passed.
        aux = getattr(self, "_last_router_aux", None)
        if loss is not None and aux is not None and getattr(self.config, "router_aux_weight", 0.0) > 0:
            if not torch.is_tensor(aux):
                aux = torch.as_tensor(aux, device=logits.device, dtype=logits.dtype)
            else:
                aux = aux.to(device=logits.device, dtype=logits.dtype)
            aux = aux.clamp_min(0.0)
            loss = loss + float(self.config.router_aux_weight) * aux

        # FIX: return `present` directly. The backbone already yields None
        # when caching is off, and it defaults use_cache to True in eval —
        # the old `present if use_cache else None` silently dropped the cache
        # whenever callers left use_cache as None.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present,
            hidden_states=None,
            attentions=None,
        )
| |
|