Spaces:
Running
on
Zero
Running
on
Zero
Get evals working again. (#46)
Browse files- PPL/validation: Works now and uses multi-gpu. For some reason 1 GPU differs from multi-GPU, can debug in a followup PR
- Generation evals likely work, but are very slow, so disabled for now
Test Plan:
```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
- bytelatent/args.py +4 -0
- bytelatent/distributed.py +42 -0
- bytelatent/eval.py +218 -41
- bytelatent/generate.py +3 -3
- bytelatent/metrics.py +1 -1
- bytelatent/train.py +11 -59
bytelatent/args.py
CHANGED
|
@@ -270,6 +270,10 @@ class EvalArgs(BaseModel):
|
|
| 270 |
dump_dir: str | None = None
|
| 271 |
ckpt_dir: str | None = None
|
| 272 |
metric_log_dir: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
generator: PackedCausalTransformerGeneratorArgs = (
|
| 274 |
PackedCausalTransformerGeneratorArgs()
|
| 275 |
)
|
|
|
|
| 270 |
dump_dir: str | None = None
|
| 271 |
ckpt_dir: str | None = None
|
| 272 |
metric_log_dir: str | None = None
|
| 273 |
+
|
| 274 |
+
run_ppl: bool = True
|
| 275 |
+
run_tasks: bool = False
|
| 276 |
+
|
| 277 |
generator: PackedCausalTransformerGeneratorArgs = (
|
| 278 |
PackedCausalTransformerGeneratorArgs()
|
| 279 |
)
|
bytelatent/distributed.py
CHANGED
|
@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
|
|
| 15 |
from itertools import chain
|
| 16 |
from typing import List, Optional, Tuple, Union
|
| 17 |
|
|
|
|
| 18 |
import torch
|
| 19 |
|
| 20 |
# for no recompute ops
|
|
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
|
|
| 78 |
|
| 79 |
spawn_method: str = "forkserver"
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
class EnvironmentArgs(BaseModel):
|
| 83 |
model_config = ConfigDict(extra="forbid")
|
|
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
|
|
| 151 |
return r
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
@lru_cache()
|
| 155 |
def get_is_torch_run() -> bool:
|
| 156 |
return os.environ.get("LOCAL_RANK") is not None
|
|
|
|
| 15 |
from itertools import chain
|
| 16 |
from typing import List, Optional, Tuple, Union
|
| 17 |
|
| 18 |
+
import numpy as np
|
| 19 |
import torch
|
| 20 |
|
| 21 |
# for no recompute ops
|
|
|
|
| 79 |
|
| 80 |
spawn_method: str = "forkserver"
|
| 81 |
|
| 82 |
+
def configure_world(self):
|
| 83 |
+
pass
|
| 84 |
+
if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
|
| 85 |
+
logging.info("Modifying TrainArgs distributed config")
|
| 86 |
+
assert get_world_size() % self.dp_shard == 0
|
| 87 |
+
logging.info("World size: %s", get_world_size())
|
| 88 |
+
logging.info(
|
| 89 |
+
"Existing setting: train_args.distributed.dp_shard=%s",
|
| 90 |
+
self.dp_shard,
|
| 91 |
+
)
|
| 92 |
+
logging.info(
|
| 93 |
+
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
|
| 94 |
+
get_world_size() // self.dp_shard,
|
| 95 |
+
self.dp_replicate,
|
| 96 |
+
)
|
| 97 |
+
self.dp_replicate = get_world_size() // self.dp_shard
|
| 98 |
+
|
| 99 |
+
logging.info(
|
| 100 |
+
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
|
| 101 |
+
self.dp_replicate,
|
| 102 |
+
self.dp_replicate // self.tp_size,
|
| 103 |
+
self.tp_size,
|
| 104 |
+
)
|
| 105 |
+
assert self.dp_replicate % self.tp_size == 0
|
| 106 |
+
self.dp_replicate = self.dp_replicate // self.tp_size
|
| 107 |
+
|
| 108 |
+
logger.warning(
|
| 109 |
+
f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
|
| 110 |
+
)
|
| 111 |
+
assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
|
| 112 |
+
|
| 113 |
+
if self.fsdp_type == "no_shard":
|
| 114 |
+
assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
|
| 115 |
+
|
| 116 |
|
| 117 |
class EnvironmentArgs(BaseModel):
|
| 118 |
model_config = ConfigDict(extra="forbid")
|
|
|
|
| 186 |
return r
|
| 187 |
|
| 188 |
|
| 189 |
+
def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
|
| 190 |
+
if isinstance(num, (torch.Tensor, np.ndarray)):
|
| 191 |
+
return num.item()
|
| 192 |
+
else:
|
| 193 |
+
return num
|
| 194 |
+
|
| 195 |
+
|
| 196 |
@lru_cache()
|
| 197 |
def get_is_torch_run() -> bool:
|
| 198 |
return os.environ.get("LOCAL_RANK") is not None
|
bytelatent/eval.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
| 5 |
import os
|
| 6 |
from collections import defaultdict
|
| 7 |
from datetime import datetime
|
|
@@ -10,22 +11,48 @@ import torch
|
|
| 10 |
from lm_eval import simple_evaluate
|
| 11 |
from lm_eval.api.instance import Instance
|
| 12 |
from lm_eval.api.model import LM
|
| 13 |
-
|
| 14 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
|
| 16 |
from bytelatent.config_parser import parse_args_to_pydantic_model
|
| 17 |
from bytelatent.data.file_util import get_fs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from bytelatent.distributed import (
|
| 19 |
DistributedArgs,
|
| 20 |
dist_mean_dict,
|
|
|
|
|
|
|
| 21 |
get_global_rank,
|
| 22 |
get_world_size,
|
| 23 |
setup_torch_distributed,
|
|
|
|
| 24 |
)
|
| 25 |
from bytelatent.generate import (
|
| 26 |
PackedCausalTransformerGenerator,
|
| 27 |
load_consolidated_model_and_tokenizer,
|
| 28 |
)
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
EVAL_FOLDER_NAME = "{:010d}"
|
| 31 |
|
|
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
|
|
| 113 |
return results
|
| 114 |
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
for src in val_args.sources:
|
| 119 |
path = os.path.join(val_args.root_dir, src)
|
| 120 |
-
srcs
|
|
|
|
| 121 |
for src in train_cfg.data.sources:
|
| 122 |
path = os.path.join(train_cfg.data.root_dir, src)
|
| 123 |
-
srcs
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
max_gen_len = generator.max_gen_len
|
| 131 |
# We temporarily lower max gen len
|
|
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
|
|
| 133 |
|
| 134 |
all_val_metrics = {}
|
| 135 |
for src in path_to_iter:
|
| 136 |
-
|
| 137 |
texts = []
|
| 138 |
logger.info(f"Running validation on {src}...")
|
| 139 |
-
for step,
|
| 140 |
-
|
| 141 |
-
val_args.max_steps is not None and step >= val_args.max_steps
|
| 142 |
-
):
|
| 143 |
-
break
|
| 144 |
-
content_key = "text" if ("text" in content) else "content"
|
| 145 |
-
texts.append(content[content_key])
|
| 146 |
|
| 147 |
_, loglikelihood, _ = generator.generate(texts)
|
| 148 |
|
|
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
|
|
| 174 |
|
| 175 |
|
| 176 |
def launch_eval(eval_args: EvalArgs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
if not torch.distributed.is_initialized():
|
| 178 |
-
setup_torch_distributed(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
|
| 181 |
if (
|
|
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
|
|
| 187 |
else:
|
| 188 |
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
|
| 189 |
if not fs.exists(consolidate_path) and get_global_rank() == 0:
|
| 190 |
-
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
|
| 191 |
|
| 192 |
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
|
| 193 |
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
|
|
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
|
|
| 200 |
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
| 201 |
consolidate_path,
|
| 202 |
)
|
| 203 |
-
logger.info("Model loaded")
|
| 204 |
model.eval()
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
if get_global_rank() == 0:
|
| 214 |
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
|
| 215 |
f.write(json.dumps(results))
|
| 216 |
-
logger.info(f"All evaluation results: {results
|
| 217 |
-
if
|
| 218 |
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
|
| 219 |
-
f.write(json.dumps(
|
| 220 |
-
logger.info(f"All validation results: {
|
|
|
|
| 221 |
if eval_args.metric_log_dir and get_global_rank() == 0:
|
| 222 |
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
|
| 223 |
|
| 224 |
logger.info(f"Writing metric logs to {metric_log_path}")
|
| 225 |
-
timestamp = {
|
| 226 |
"created_at": datetime.utcnow().isoformat(),
|
| 227 |
}
|
| 228 |
if eval_args.global_step is not None:
|
| 229 |
timestamp["global_step"] = eval_args.global_step
|
| 230 |
print(
|
| 231 |
-
json.dumps(timestamp | results
|
| 232 |
file=fs.open(metric_log_path, mode="a"),
|
| 233 |
flush=True,
|
| 234 |
)
|
|
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
|
|
| 236 |
val_log_path = os.path.join(
|
| 237 |
eval_args.metric_log_dir, "metrics.validation.jsonl"
|
| 238 |
)
|
| 239 |
-
if
|
| 240 |
print(
|
| 241 |
-
json.dumps(timestamp |
|
| 242 |
file=fs.open(val_log_path, mode="a"),
|
| 243 |
flush=True,
|
| 244 |
)
|
| 245 |
|
| 246 |
-
del generator
|
| 247 |
-
|
| 248 |
|
| 249 |
def main():
|
| 250 |
-
eval_args =
|
| 251 |
launch_eval(eval_args)
|
| 252 |
|
| 253 |
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
+
import math
|
| 6 |
import os
|
| 7 |
from collections import defaultdict
|
| 8 |
from datetime import datetime
|
|
|
|
| 11 |
from lm_eval import simple_evaluate
|
| 12 |
from lm_eval.api.instance import Instance
|
| 13 |
from lm_eval.api.model import LM
|
| 14 |
+
from rich.progress import track
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
|
| 17 |
+
from bytelatent.args import (
|
| 18 |
+
EvalArgs,
|
| 19 |
+
TrainArgs,
|
| 20 |
+
ValidationArgs,
|
| 21 |
+
find_and_sanitize_chunks,
|
| 22 |
+
)
|
| 23 |
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
|
| 24 |
from bytelatent.config_parser import parse_args_to_pydantic_model
|
| 25 |
from bytelatent.data.file_util import get_fs
|
| 26 |
+
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
| 27 |
+
from bytelatent.data.iterators.limit_iterator import LimitIterator
|
| 28 |
+
from bytelatent.data.iterators.packing_iterator import (
|
| 29 |
+
PackingArgs,
|
| 30 |
+
PackingIterator,
|
| 31 |
+
PackingMode,
|
| 32 |
+
)
|
| 33 |
+
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
| 34 |
+
from bytelatent.data.iterators.sequence_iterator import (
|
| 35 |
+
SequenceIterator,
|
| 36 |
+
SequencePackingArgs,
|
| 37 |
+
)
|
| 38 |
+
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
|
| 39 |
from bytelatent.distributed import (
|
| 40 |
DistributedArgs,
|
| 41 |
dist_mean_dict,
|
| 42 |
+
dist_sum,
|
| 43 |
+
get_device_mesh,
|
| 44 |
get_global_rank,
|
| 45 |
get_world_size,
|
| 46 |
setup_torch_distributed,
|
| 47 |
+
to_py_num,
|
| 48 |
)
|
| 49 |
from bytelatent.generate import (
|
| 50 |
PackedCausalTransformerGenerator,
|
| 51 |
load_consolidated_model_and_tokenizer,
|
| 52 |
)
|
| 53 |
+
from bytelatent.model.blt import ByteLatentTransformer
|
| 54 |
+
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
| 55 |
+
from bytelatent.transformer import LMTransformer
|
| 56 |
|
| 57 |
EVAL_FOLDER_NAME = "{:010d}"
|
| 58 |
|
|
|
|
| 140 |
return results
|
| 141 |
|
| 142 |
|
| 143 |
+
@torch.no_grad()
|
| 144 |
+
def eval_ppl_on_path(
|
| 145 |
+
*,
|
| 146 |
+
world_rank: int,
|
| 147 |
+
world_size: int,
|
| 148 |
+
model: LMTransformer | ByteLatentTransformer,
|
| 149 |
+
tokenizer_args: TokenizerArgs,
|
| 150 |
+
patcher_args: PatcherArgs,
|
| 151 |
+
add_patches: bool,
|
| 152 |
+
path: str,
|
| 153 |
+
batch_size: int,
|
| 154 |
+
arrow_batch_size: int,
|
| 155 |
+
max_n_docs: int | None,
|
| 156 |
+
s3_profile: str | None = None,
|
| 157 |
+
):
|
| 158 |
+
model.eval()
|
| 159 |
+
tokenizer = tokenizer_args.build()
|
| 160 |
+
seq_len = model.get_output_seq_len()
|
| 161 |
+
chunks = find_and_sanitize_chunks(
|
| 162 |
+
path,
|
| 163 |
+
world_size=1,
|
| 164 |
+
file_pattern="*.val.jsonl",
|
| 165 |
+
s3_profile=s3_profile,
|
| 166 |
+
)
|
| 167 |
+
assert (
|
| 168 |
+
len(chunks) == 1
|
| 169 |
+
), f"There should be only 1 chunk per validation file, but found: {chunks}"
|
| 170 |
+
chunk = chunks[0]
|
| 171 |
+
arrow_iterator = ArrowFileIterator(
|
| 172 |
+
file_path=chunk,
|
| 173 |
+
preprocess_dir=None,
|
| 174 |
+
entropy_model_name=None,
|
| 175 |
+
worker_id=world_rank,
|
| 176 |
+
num_workers=world_size,
|
| 177 |
+
arrow_batch_size=arrow_batch_size,
|
| 178 |
+
s3_profile=s3_profile,
|
| 179 |
+
file_format="json",
|
| 180 |
+
)
|
| 181 |
+
if max_n_docs is not None:
|
| 182 |
+
arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
|
| 183 |
+
preprocess_iterator = PreprocessIterator(
|
| 184 |
+
arrow_iterator,
|
| 185 |
+
patcher_args=patcher_args,
|
| 186 |
+
tokenizer_args=tokenizer_args,
|
| 187 |
+
add_patches=add_patches,
|
| 188 |
+
)
|
| 189 |
+
sequence_iterator = SequenceIterator(
|
| 190 |
+
preprocess_iterator,
|
| 191 |
+
sequence_packing_args=SequencePackingArgs(
|
| 192 |
+
output_seq_len=seq_len,
|
| 193 |
+
# Effectively disables shuffles
|
| 194 |
+
buffer_size=1,
|
| 195 |
+
),
|
| 196 |
+
rng_state=None,
|
| 197 |
+
)
|
| 198 |
+
packing_args = PackingArgs(
|
| 199 |
+
batch_size=batch_size,
|
| 200 |
+
seq_len=seq_len,
|
| 201 |
+
# TODO: make these seq lens worth with blt
|
| 202 |
+
max_length=seq_len,
|
| 203 |
+
pad_to_max_length=True,
|
| 204 |
+
enable_byte_ngrams=False,
|
| 205 |
+
pad_id=tokenizer.boe_id,
|
| 206 |
+
packing_mode=PackingMode.BYTES,
|
| 207 |
+
)
|
| 208 |
+
packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
|
| 209 |
+
total_loss = 0.0
|
| 210 |
+
n_bytes = 0
|
| 211 |
+
batch_iterator = packing_iterator.create_iter()
|
| 212 |
+
for batch in batch_iterator:
|
| 213 |
+
x = torch.from_numpy(batch.x).cuda()
|
| 214 |
+
y = torch.from_numpy(batch.y).cuda()
|
| 215 |
+
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
| 216 |
+
if tokenizer_args.name in ["bytes", "blt"]:
|
| 217 |
+
n_bytes += y.numel() if mask is None else mask.sum().item()
|
| 218 |
+
pred = model(x)
|
| 219 |
+
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
|
| 220 |
+
total_loss += loss.item()
|
| 221 |
+
else:
|
| 222 |
+
raise NotImplementedError()
|
| 223 |
+
all_n_bytes = to_py_num(dist_sum(n_bytes))
|
| 224 |
+
all_total_loss = to_py_num(dist_sum(total_loss))
|
| 225 |
+
return {
|
| 226 |
+
"n_bytes": all_n_bytes,
|
| 227 |
+
"n_bytes_gpu": n_bytes,
|
| 228 |
+
"loss_sum": all_total_loss,
|
| 229 |
+
"loss_sum_gpu": total_loss,
|
| 230 |
+
"loss_mean": all_total_loss / all_n_bytes,
|
| 231 |
+
"loss_mean_gpu": total_loss / n_bytes,
|
| 232 |
+
"ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
|
| 233 |
+
"bpb": all_total_loss / math.log(2) / all_n_bytes,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
|
| 238 |
+
srcs = []
|
| 239 |
for src in val_args.sources:
|
| 240 |
path = os.path.join(val_args.root_dir, src)
|
| 241 |
+
srcs.append(path)
|
| 242 |
+
|
| 243 |
for src in train_cfg.data.sources:
|
| 244 |
path = os.path.join(train_cfg.data.root_dir, src)
|
| 245 |
+
srcs.append(path)
|
| 246 |
+
|
| 247 |
+
path_to_iter = {}
|
| 248 |
+
for path in srcs:
|
| 249 |
+
chunks = find_and_sanitize_chunks(
|
| 250 |
+
path,
|
| 251 |
+
world_size=1,
|
| 252 |
+
file_pattern="*.val.jsonl",
|
| 253 |
+
s3_profile=train_cfg.data.s3_profile,
|
| 254 |
+
)
|
| 255 |
+
assert (
|
| 256 |
+
len(chunks) == 1
|
| 257 |
+
), f"There should be only 1 chunk per validation file, but found: {chunks}"
|
| 258 |
+
chunk = chunks[0]
|
| 259 |
+
iterator = ArrowFileIterator(
|
| 260 |
+
dataset_files=[chunk],
|
| 261 |
+
file_path=None,
|
| 262 |
+
preprocess_dir=None,
|
| 263 |
+
entropy_model_name=None,
|
| 264 |
+
worker_id=0,
|
| 265 |
+
num_workers=1,
|
| 266 |
+
arrow_batch_size=train_cfg.data.arrow_batch_size,
|
| 267 |
+
s3_profile=train_cfg.data.s3_profile,
|
| 268 |
+
file_format="json",
|
| 269 |
+
)
|
| 270 |
+
path_to_iter[path] = iterator
|
| 271 |
|
| 272 |
max_gen_len = generator.max_gen_len
|
| 273 |
# We temporarily lower max gen len
|
|
|
|
| 275 |
|
| 276 |
all_val_metrics = {}
|
| 277 |
for src in path_to_iter:
|
| 278 |
+
example_iterator = path_to_iter[src].create_iter()
|
| 279 |
texts = []
|
| 280 |
logger.info(f"Running validation on {src}...")
|
| 281 |
+
for step, example in enumerate(example_iterator):
|
| 282 |
+
texts.append(example.text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
_, loglikelihood, _ = generator.generate(texts)
|
| 285 |
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
def launch_eval(eval_args: EvalArgs):
|
| 314 |
+
assert eval_args.dump_dir is not None
|
| 315 |
+
assert eval_args.ckpt_dir is not None
|
| 316 |
+
distributed_args = DistributedArgs()
|
| 317 |
+
distributed_args.configure_world()
|
| 318 |
if not torch.distributed.is_initialized():
|
| 319 |
+
setup_torch_distributed(distributed_args)
|
| 320 |
+
|
| 321 |
+
world_mesh = get_device_mesh(distributed_args)
|
| 322 |
+
dp_mesh = world_mesh["dp_replicate"]
|
| 323 |
+
assert distributed_args.dp_shard == 1
|
| 324 |
+
world_size = dp_mesh.size()
|
| 325 |
+
world_rank = dp_mesh.get_local_rank()
|
| 326 |
|
| 327 |
fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
|
| 328 |
if (
|
|
|
|
| 334 |
else:
|
| 335 |
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
|
| 336 |
if not fs.exists(consolidate_path) and get_global_rank() == 0:
|
| 337 |
+
consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
|
| 338 |
|
| 339 |
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
|
| 340 |
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
|
|
|
|
| 347 |
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
| 348 |
consolidate_path,
|
| 349 |
)
|
|
|
|
| 350 |
model.eval()
|
| 351 |
+
logger.info("Model loaded")
|
| 352 |
+
|
| 353 |
+
ppl_results = None
|
| 354 |
+
if eval_args.run_ppl:
|
| 355 |
+
assert eval_args.validation is not None
|
| 356 |
+
if len(eval_args.validation.sources) > 0:
|
| 357 |
+
ppl_results = {}
|
| 358 |
+
logger.info("Starting PPL evaluation on validation sets")
|
| 359 |
+
for source in eval_args.validation.sources:
|
| 360 |
+
ppl_results[source] = eval_ppl_on_path(
|
| 361 |
+
world_rank=world_rank,
|
| 362 |
+
world_size=world_size,
|
| 363 |
+
model=model,
|
| 364 |
+
tokenizer_args=train_cfg.data.tokenizer_args,
|
| 365 |
+
# TODO: Don't hardcode, modify based on model
|
| 366 |
+
patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
|
| 367 |
+
add_patches=False,
|
| 368 |
+
path=os.path.join(eval_args.validation.root_dir, source),
|
| 369 |
+
max_n_docs=eval_args.validation.max_n_docs,
|
| 370 |
+
batch_size=8,
|
| 371 |
+
arrow_batch_size=100,
|
| 372 |
+
s3_profile="blt",
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
task_results = None
|
| 376 |
+
if eval_args.run_tasks:
|
| 377 |
+
assert eval_args.generator is not None
|
| 378 |
+
assert eval_args.harness is not None
|
| 379 |
+
generator = PackedCausalTransformerGenerator(
|
| 380 |
+
eval_args.generator, model, tokenizer
|
| 381 |
+
)
|
| 382 |
+
wrap = EvalHarnessLM(generator)
|
| 383 |
+
# TODO: This needs to be checked/sped up
|
| 384 |
+
task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
|
| 385 |
+
|
| 386 |
+
results = {"ppl": ppl_results, "tasks": task_results}
|
| 387 |
+
# TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
|
| 388 |
+
# leaving this log statement here to help with that.
|
| 389 |
+
# logging.info("Rank: %s Results: %s", world_rank, results)
|
| 390 |
+
|
| 391 |
if get_global_rank() == 0:
|
| 392 |
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
|
| 393 |
f.write(json.dumps(results))
|
| 394 |
+
logger.info(f"All evaluation results: {results}")
|
| 395 |
+
if ppl_results is not None:
|
| 396 |
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
|
| 397 |
+
f.write(json.dumps(ppl_results))
|
| 398 |
+
logger.info(f"All validation results: {ppl_results}")
|
| 399 |
+
|
| 400 |
if eval_args.metric_log_dir and get_global_rank() == 0:
|
| 401 |
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
|
| 402 |
|
| 403 |
logger.info(f"Writing metric logs to {metric_log_path}")
|
| 404 |
+
timestamp: dict[str, int | str] = {
|
| 405 |
"created_at": datetime.utcnow().isoformat(),
|
| 406 |
}
|
| 407 |
if eval_args.global_step is not None:
|
| 408 |
timestamp["global_step"] = eval_args.global_step
|
| 409 |
print(
|
| 410 |
+
json.dumps(timestamp | results),
|
| 411 |
file=fs.open(metric_log_path, mode="a"),
|
| 412 |
flush=True,
|
| 413 |
)
|
|
|
|
| 415 |
val_log_path = os.path.join(
|
| 416 |
eval_args.metric_log_dir, "metrics.validation.jsonl"
|
| 417 |
)
|
| 418 |
+
if ppl_results is not None:
|
| 419 |
print(
|
| 420 |
+
json.dumps(timestamp | ppl_results),
|
| 421 |
file=fs.open(val_log_path, mode="a"),
|
| 422 |
flush=True,
|
| 423 |
)
|
| 424 |
|
|
|
|
|
|
|
| 425 |
|
| 426 |
def main():
|
| 427 |
+
eval_args = parse_args_to_pydantic_model(EvalArgs)
|
| 428 |
launch_eval(eval_args)
|
| 429 |
|
| 430 |
|
bytelatent/generate.py
CHANGED
|
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
|
|
| 387 |
):
|
| 388 |
train_args_path = os.path.join(consolidated_path, "params.json")
|
| 389 |
fs = get_fs(train_args_path)
|
| 390 |
-
|
| 391 |
-
train_args = TrainArgs.model_validate_json(f.read())
|
| 392 |
|
| 393 |
if train_args.train_entropy_model:
|
| 394 |
model_args = train_args.entropy_model
|
|
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
|
|
| 401 |
train_args.distributed.model_dtype
|
| 402 |
]
|
| 403 |
tokenizer = train_args.data.tokenizer_args.build()
|
| 404 |
-
|
|
|
|
| 405 |
model.load_state_dict(st_dict["model"])
|
| 406 |
model = model.cuda().eval()
|
| 407 |
for param in model.parameters():
|
|
|
|
| 387 |
):
|
| 388 |
train_args_path = os.path.join(consolidated_path, "params.json")
|
| 389 |
fs = get_fs(train_args_path)
|
| 390 |
+
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
|
|
|
|
| 391 |
|
| 392 |
if train_args.train_entropy_model:
|
| 393 |
model_args = train_args.entropy_model
|
|
|
|
| 400 |
train_args.distributed.model_dtype
|
| 401 |
]
|
| 402 |
tokenizer = train_args.data.tokenizer_args.build()
|
| 403 |
+
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
|
| 404 |
+
st_dict = torch.load(f, weights_only=True)
|
| 405 |
model.load_state_dict(st_dict["model"])
|
| 406 |
model = model.cuda().eval()
|
| 407 |
for param in model.parameters():
|
bytelatent/metrics.py
CHANGED
|
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
|
|
| 55 |
class MetricLogger:
|
| 56 |
def __init__(
|
| 57 |
self,
|
| 58 |
-
outdir:
|
| 59 |
# args: TrainArgs
|
| 60 |
args: Any | None = None,
|
| 61 |
fs: fsspec.AbstractFileSystem | None = None,
|
|
|
|
| 55 |
class MetricLogger:
|
| 56 |
def __init__(
|
| 57 |
self,
|
| 58 |
+
outdir: str,
|
| 59 |
# args: TrainArgs
|
| 60 |
args: Any | None = None,
|
| 61 |
fs: fsspec.AbstractFileSystem | None = None,
|
bytelatent/train.py
CHANGED
|
@@ -48,6 +48,7 @@ from bytelatent.distributed import (
|
|
| 48 |
requeue_slurm_job,
|
| 49 |
setup_env,
|
| 50 |
setup_torch_distributed,
|
|
|
|
| 51 |
)
|
| 52 |
from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
|
| 53 |
from bytelatent.logger import init_logger
|
|
@@ -91,13 +92,6 @@ def get_iterator_state_name(iterator_state):
|
|
| 91 |
raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
|
| 92 |
|
| 93 |
|
| 94 |
-
def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
|
| 95 |
-
if isinstance(num, (torch.Tensor, np.ndarray)):
|
| 96 |
-
return num.item()
|
| 97 |
-
else:
|
| 98 |
-
return num
|
| 99 |
-
|
| 100 |
-
|
| 101 |
# TODO: Make this pydantic based instead of data class based
|
| 102 |
# TODO: Generalize this to any iterator state
|
| 103 |
@dataclass
|
|
@@ -154,57 +148,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
|
|
| 154 |
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
|
| 155 |
args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
if (
|
| 163 |
-
args.distributed.dp_replicate
|
| 164 |
-
* args.distributed.dp_shard
|
| 165 |
-
* args.distributed.tp_size
|
| 166 |
-
!= get_world_size()
|
| 167 |
-
):
|
| 168 |
-
logging.info("Modifying TrainArgs distributed config")
|
| 169 |
-
assert get_world_size() % args.distributed.dp_shard == 0
|
| 170 |
-
logging.info("World size: %s", get_world_size())
|
| 171 |
-
logging.info(
|
| 172 |
-
"Existing setting: train_args.distributed.dp_shard=%s",
|
| 173 |
-
args.distributed.dp_shard,
|
| 174 |
-
)
|
| 175 |
-
logging.info(
|
| 176 |
-
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
|
| 177 |
-
get_world_size() // args.distributed.dp_shard,
|
| 178 |
-
args.distributed.dp_replicate,
|
| 179 |
-
)
|
| 180 |
-
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
| 181 |
-
|
| 182 |
-
logging.info(
|
| 183 |
-
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
|
| 184 |
-
args.distributed.dp_replicate,
|
| 185 |
-
args.distributed.dp_replicate // args.distributed.tp_size,
|
| 186 |
-
args.distributed.tp_size,
|
| 187 |
-
)
|
| 188 |
-
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
| 189 |
-
args.distributed.dp_replicate = (
|
| 190 |
-
args.distributed.dp_replicate // args.distributed.tp_size
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
logger.warning(
|
| 194 |
-
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
|
| 195 |
-
)
|
| 196 |
-
assert (
|
| 197 |
-
args.distributed.dp_replicate
|
| 198 |
-
* args.distributed.dp_shard
|
| 199 |
-
* args.distributed.tp_size
|
| 200 |
-
== get_world_size()
|
| 201 |
-
)
|
| 202 |
|
| 203 |
-
|
| 204 |
-
assert (
|
| 205 |
-
args.distributed.dp_shard == 1
|
| 206 |
-
and args.distributed.dp_replicate == get_world_size()
|
| 207 |
-
)
|
| 208 |
|
| 209 |
if args.model is not None:
|
| 210 |
args.model.max_seqlen = args.data.seq_len
|
|
@@ -243,7 +193,9 @@ def set_preemption_flag(signum, frame):
|
|
| 243 |
preemption_flag["flag"] = True
|
| 244 |
|
| 245 |
|
| 246 |
-
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
|
|
|
|
|
|
|
| 247 |
test = train_state.step % freq == 0
|
| 248 |
if acc_step is not None:
|
| 249 |
test = test and (train_state.acc_step == acc_step)
|
|
@@ -272,7 +224,7 @@ def train(args: TrainArgs):
|
|
| 272 |
tokenizer = args.data.tokenizer_args.build()
|
| 273 |
validate_train_args(
|
| 274 |
args,
|
| 275 |
-
tokenizer.
|
| 276 |
)
|
| 277 |
dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
|
| 278 |
if get_is_master():
|
|
|
|
| 48 |
requeue_slurm_job,
|
| 49 |
setup_env,
|
| 50 |
setup_torch_distributed,
|
| 51 |
+
to_py_num,
|
| 52 |
)
|
| 53 |
from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
|
| 54 |
from bytelatent.logger import init_logger
|
|
|
|
| 92 |
raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
|
| 93 |
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
# TODO: Make this pydantic based instead of data class based
|
| 96 |
# TODO: Generalize this to any iterator state
|
| 97 |
@dataclass
|
|
|
|
| 148 |
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
|
| 149 |
args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
|
| 150 |
|
| 151 |
+
if args.data.root_dir is not None:
|
| 152 |
+
data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
|
| 153 |
+
for source in args.data.sources:
|
| 154 |
+
data_path = os.path.join(args.data.root_dir, source)
|
| 155 |
+
assert data_fs.exists(data_path), f"{data_path} doesn't exist"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
args.distributed.configure_world()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
if args.model is not None:
|
| 160 |
args.model.max_seqlen = args.data.seq_len
|
|
|
|
| 193 |
preemption_flag["flag"] = True
|
| 194 |
|
| 195 |
|
| 196 |
+
def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
|
| 197 |
+
if freq < 0:
|
| 198 |
+
return False
|
| 199 |
test = train_state.step % freq == 0
|
| 200 |
if acc_step is not None:
|
| 201 |
test = test and (train_state.acc_step == acc_step)
|
|
|
|
| 224 |
tokenizer = args.data.tokenizer_args.build()
|
| 225 |
validate_train_args(
|
| 226 |
args,
|
| 227 |
+
tokenizer.get_vocab_size(),
|
| 228 |
)
|
| 229 |
dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
|
| 230 |
if get_is_master():
|