Reduce per file resources arrow uses (#77)
bytelatent/args.py
CHANGED
@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
     file_format: str = "arrow"
 
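The new comment is the key sizing constraint: each rank keeps up to arrow_batch_size rows buffered per data source, so memory grows multiplicatively with ranks and sources. A back-of-envelope sketch, where the rank, source, and row-size numbers are illustrative assumptions rather than values from the repository:

    # Rough estimate of rows buffered by Arrow batching alone (per node).
    # All inputs below except arrow_batch_size are assumed for illustration.
    arrow_batch_size = 20      # default set by this change
    num_ranks = 8              # assumed GPUs per node
    num_data_sources = 10      # assumed data sources mixed per rank
    avg_row_bytes = 50_000     # assumed average serialized row size

    buffered_rows = arrow_batch_size * num_ranks * num_data_sources
    approx_bytes = buffered_rows * avg_row_bytes
    print(f"~{buffered_rows} rows, ~{approx_bytes / 1e6:.0f} MB buffered per node")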
bytelatent/data/iterators/arrow_iterator.py
CHANGED
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
-        self.batch_iterator = self.dataset.to_batches(
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
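For context on the two readahead arguments: pyarrow's Dataset.to_batches accepts fragment_readahead and batch_readahead, which bound how many files and record batches are prefetched ahead of the consumer; setting both to 1 trades some throughput for a much smaller resident buffer. A minimal standalone sketch, assuming a local directory of Arrow shards (the path, column handling, and batch size here are placeholders, not repository values):

    import pyarrow.dataset as ds

    # Hypothetical shard directory; the repository builds its dataset from
    # configured shard files rather than a single path like this.
    dataset = ds.dataset("data/shards", format="arrow")

    # Keep at most one fragment (file) and one record batch in flight.
    for batch in dataset.to_batches(
        batch_size=20,
        fragment_readahead=1,
        batch_readahead=1,
    ):
        rows = batch.to_pydict()
        # ... hand rows to the downstream iterator pipeline ...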
bytelatent/data/iterators/sequence_iterator.py
CHANGED
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
 )
 
 
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
+
+
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info(
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
 
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
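get_datafile simply peels the wrapper iterators (Limit, Looping, Preprocess) until it reaches the ArrowFileIterator that knows which shard is being read, so the two log calls can name the concrete data source whose first buffer is slow to fill. A minimal sketch of the same recursive-dispatch pattern with stand-in classes (the class names below are illustrative, not from the repository):

    from dataclasses import dataclass


    @dataclass
    class FileSource:
        file_path: str


    @dataclass
    class Wrapper:
        inner: object


    def describe(it):
        # Recurse through wrapper layers until the object holding the file path is reached.
        if isinstance(it, FileSource):
            return f"file={it.file_path}"
        if isinstance(it, Wrapper):
            return describe(it.inner)
        raise NotImplementedError()


    print(describe(Wrapper(Wrapper(FileSource("shard_00.arrow")))))
    # -> file=shard_00.arrow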
bytelatent/iterate_data.py
ADDED
@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+    packing_iterator_state = dl_state.base_iterator_state
+    print("building")
+    packing_iterator = packing_iterator_state.build()
+    print("iter")
+    batch_iter = packing_iterator.create_iter()
+    batch = None
+    print("looping")
+    for i in track(range(1_000)):
+        batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
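This new script replays the dataloader from a saved train state without launching training, which makes it easy to profile iterator memory and throughput in isolation. Since typer.run(main) exposes state_file as a single positional argument, an invocation from the repository root would presumably look like `python -m bytelatent.iterate_data path/to/train_state.json`, where the path is a placeholder for a checkpoint's train state JSON containing a data_loader_state entry.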
bytelatent/train.py
CHANGED
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
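pyarrow keeps two process-wide thread pools (CPU and I/O) whose defaults scale with the host's core count, so capping them at the top of train() stops every rank's Arrow readers from each spawning a large pool. A minimal sketch, independent of the training code, showing the getters and setters involved (the value 4 mirrors this change; any other cap would work the same way):

    import pyarrow

    # Inspect the defaults, which typically scale with the number of host cores.
    print("cpu threads:", pyarrow.cpu_count())
    print("io threads:", pyarrow.io_thread_count())

    # Cap both global pools; this should run before heavy Arrow work starts
    # and affects every Arrow reader in the process.
    pyarrow.set_cpu_count(4)
    pyarrow.set_io_thread_count(4)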