Spaces: Running on Zero
NielsRogge
committed on
Improve HF integration (#98)
* Add mixin
* Update license
- bytelatent/model/blt.py +7 -3
bytelatent/model/blt.py
CHANGED
@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+                            repo_url="https://github.com/facebookresearch/blt",
+                            pipeline_tag="text-generation",
+                            license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
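
The substantive change is the PyTorchModelHubMixin base class: it equips ByteLatentTransformer with the save_pretrained, push_to_hub, and from_pretrained helpers, and the repo_url, pipeline_tag, and license class arguments are metadata that huggingface_hub writes into the auto-generated model card on upload. A minimal sketch of the resulting workflow, assuming an already-constructed model instance, a placeholder repo id, and that the model's constructor arguments serialize into the mixin's config.json (none of which is shown in this commit):

from bytelatent.model.blt import ByteLatentTransformer

def share_on_hub(model: ByteLatentTransformer, repo_id: str) -> ByteLatentTransformer:
    """Round-trip a BLT model through the Hub via the mixin-provided helpers."""
    # Write config.json plus the weights to a local directory.
    model.save_pretrained("blt-checkpoint")
    # Upload to the Hub; the class-level repo_url, pipeline_tag, and license
    # are reflected in the generated model card.
    model.push_to_hub(repo_id)
    # Rebuild the model and load the weights straight from the Hub.
    return ByteLatentTransformer.from_pretrained(repo_id)

The repo_id argument is hypothetical (e.g. "your-username/blt-demo"); the three methods themselves come from huggingface_hub's PyTorchModelHubMixin, not from the BLT codebase.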
|