Update ppl evals to work with blt model, in addition to entropy model (#82)
Summary: Make the perplexity (ppl) evals work with the BLT model in addition to the entropy model: the eval batch size moves into ValidationArgs, PackingIterator pads the final partial batch instead of dropping it, and eval.py builds PackingArgs once in launch_eval and dispatches on model type.
Test Plan:
Run
```
python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null
python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null
```
- bytelatent/args.py +1 -0
- bytelatent/data/iterators/packing_iterator.py +28 -1
- bytelatent/eval.py +32 -110
bytelatent/args.py

```diff
@@ -263,6 +263,7 @@ class ValidationArgs(BaseModel):
     use_val_from_train_src: bool = True  # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = []  # Other sources to eval on
+    batch_size: int = 8


 class EvalArgs(BaseModel):
```
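The new batch_size field lives on ValidationArgs, so the eval batch size now comes from the eval config instead of being hard-coded in eval.py. A minimal sketch of how the field behaves (assuming the pydantic BaseModel semantics shown in the hunk; fields outside the hunk, such as max_n_docs, are inferred from the eval.py diff and the test plan):

```python
from pydantic import BaseModel


class ValidationArgs(BaseModel):
    max_n_docs: int | None = None  # inferred field; None ("null" in YAML) means eval on all docs
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: list[str] = []  # Other sources to eval on
    batch_size: int = 8  # new in this change; overridable per eval run


# e.g. a CLI override like validation.batch_size=16 would yield:
args = ValidationArgs(sources=["my_val_set"], batch_size=16)  # "my_val_set" is a placeholder
assert args.batch_size == 16
```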
bytelatent/data/iterators/packing_iterator.py

```diff
@@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
         assert max_length is not None
+        final_leftover_batch = False
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
@@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                     break

             x_patch_lengths = np.array(patch_lengths)
+            assert (
+                x_patch_lengths.shape[1] == seq_len
+            ), f"{x_patch_lengths.shape[1]} vs {seq_len}"
             # pad batch to same length
             tok_seq_len = max([len(toks) for toks in tokens]) - 1
             x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
@@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
             # Adjust patch lengths to match x
             x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)

-
+            if x_patch_lengths.shape[0] < batch_size:
+                if final_leftover_batch:
+                    raise ValueError(
+                        "There should only be one partial batch, but found multiple"
+                    )
+                final_leftover_batch = True
+                assert len(masks) == len(x_patch_lengths)
+                n_missing = batch_size - x_patch_lengths.shape[0]
+                # Repeat the last patch length to validly pad it out, but
+                # update the mask to ignore the row
+                x_patch_lengths = np.vstack(
+                    [
+                        x_patch_lengths,
+                        np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
+                    ]
+                )
+                for _ in range(n_missing):
+                    masks.append([0] * tok_seq_len)
+                assert len(masks) == batch_size
+
+            assert x_patch_lengths.shape == (
+                batch_size,
+                seq_len,
+            ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"

             if enable_byte_ngrams:
                 raise NotImplementedError()
```
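Most of this diff handles the final, partially filled batch: rather than dropping leftover documents, the iterator pads x_patch_lengths up to batch_size by repeating its last row and appends all-zero mask rows so the padded rows contribute nothing downstream. A self-contained numpy sketch of that padding step (toy shapes and standalone variables, not the iterator's real state):

```python
import numpy as np

batch_size, seq_len, tok_seq_len = 4, 3, 6
# Only 2 of the 4 expected rows arrived in the final batch.
x_patch_lengths = np.array([[2, 2, 2], [3, 2, 1]])
masks = [[1] * tok_seq_len, [1] * tok_seq_len]

n_missing = batch_size - x_patch_lengths.shape[0]
# Repeat the last row of patch lengths so downstream shape checks hold...
x_patch_lengths = np.vstack(
    [x_patch_lengths, np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0)]
)
# ...and mask the padded rows out entirely.
for _ in range(n_missing):
    masks.append([0] * tok_seq_len)

assert x_patch_lengths.shape == (batch_size, seq_len)
assert len(masks) == batch_size
```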
bytelatent/eval.py

```diff
@@ -148,35 +148,25 @@ def eval_ppl_on_path(
     model: LMTransformer | ByteLatentTransformer,
     tokenizer_args: TokenizerArgs,
     patcher_args: PatcherArgs,
+    packing_args: PackingArgs,
     add_patches: bool,
     path: str,
-    batch_size: int,
     arrow_batch_size: int,
     max_n_docs: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
-    tokenizer = tokenizer_args.build()
     seq_len = model.get_output_seq_len()
-    chunks = find_and_sanitize_chunks(
-        path,
-        world_size=1,
-        file_pattern="*.val.jsonl",
-        s3_profile=s3_profile,
-    )
-    assert (
-        len(chunks) == 1
-    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-    chunk = chunks[0]
     arrow_iterator = ArrowFileIterator(
-        file_path=chunk,
-        preprocess_dir=None,
+        file_path=None,
+        dataset_files=[path],
         entropy_model_name=None,
         worker_id=world_rank,
         num_workers=world_size,
         arrow_batch_size=arrow_batch_size,
+        preprocess_dir=None,
         s3_profile=s3_profile,
-        file_format="json",
+        file_format="arrow" if path.endswith("arrow") else "json",
     )
     if max_n_docs is not None:
         arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
@@ -195,16 +185,6 @@ def eval_ppl_on_path(
         ),
         rng_state=None,
     )
-    packing_args = PackingArgs(
-        batch_size=batch_size,
-        seq_len=seq_len,
-        # TODO: make these seq lens worth with blt
-        max_length=seq_len,
-        pad_to_max_length=True,
-        enable_byte_ngrams=False,
-        pad_id=tokenizer.boe_id,
-        packing_mode=PackingMode.BYTES,
-    )
     packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
     total_loss = 0.0
     n_bytes = 0
@@ -213,9 +193,16 @@ def eval_ppl_on_path(
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        patch_lengths = batch.patch_lengths
+        if patch_lengths is not None:
+            patch_lengths = torch.from_numpy(patch_lengths).cuda()
+
         if tokenizer_args.name in ["bytes", "blt"]:
             n_bytes += y.numel() if mask is None else mask.sum().item()
-            pred = model(x)
+            if isinstance(model, ByteLatentTransformer):
+                pred = model(x, patch_lengths=patch_lengths)
+            else:
+                pred = model(x)
             loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
             total_loss += loss.item()
         else:
@@ -234,82 +221,6 @@ def eval_ppl_on_path(
     }


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
-    srcs = []
-    for src in val_args.sources:
-        path = os.path.join(val_args.root_dir, src)
-        srcs.append(path)
-
-    for src in train_cfg.data.sources:
-        path = os.path.join(train_cfg.data.root_dir, src)
-        srcs.append(path)
-
-    path_to_iter = {}
-    for path in srcs:
-        chunks = find_and_sanitize_chunks(
-            path,
-            world_size=1,
-            file_pattern="*.val.jsonl",
-            s3_profile=train_cfg.data.s3_profile,
-        )
-        assert (
-            len(chunks) == 1
-        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-        chunk = chunks[0]
-        iterator = ArrowFileIterator(
-            dataset_files=[chunk],
-            file_path=None,
-            preprocess_dir=None,
-            entropy_model_name=None,
-            worker_id=0,
-            num_workers=1,
-            arrow_batch_size=train_cfg.data.arrow_batch_size,
-            s3_profile=train_cfg.data.s3_profile,
-            file_format="json",
-        )
-        path_to_iter[path] = iterator
-
-    max_gen_len = generator.max_gen_len
-    # We temporarily lower max gen len
-    generator.max_gen_len = 1
-
-    all_val_metrics = {}
-    for src in path_to_iter:
-        example_iterator = path_to_iter[src].create_iter()
-        texts = []
-        logger.info(f"Running validation on {src}...")
-        for step, example in enumerate(example_iterator):
-            texts.append(example.text)
-
-        _, loglikelihood, _ = generator.generate(texts)
-
-        metrics = defaultdict(list)
-        for i, ll in enumerate(loglikelihood):
-            tmp = ll.sum().item()
-            metrics["nll"].append(tmp)
-            metrics["nll_per_token"].append(tmp / len(ll))
-            metrics["nll_per_char"].append(tmp / len(texts[i]))
-
-            metrics["avg_seqlen"].append(len(ll))
-
-        for m in metrics:
-            metrics[m] = sum(metrics[m]) / len(metrics[m])
-        metrics.update(dist_mean_dict(metrics))
-        logger.info(f"Validation on {src} done. Metrics: {metrics}")
-
-        name = os.path.basename(src)
-        if name in all_val_metrics:
-            logger.warning(
-                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
-            )
-            name = f"{name}_1"
-        all_val_metrics[name] = metrics
-
-    generator.max_gen_len = max_gen_len
-
-    return all_val_metrics
-
-
 def launch_eval(eval_args: EvalArgs):
     assert eval_args.dump_dir is not None
     assert eval_args.ckpt_dir is not None
@@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs):

     torch.distributed.barrier()
     logger.info("Loading model")
-    # TODO: Make this general so that it works with either
-    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
+    pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
     model.eval()
    logger.info("Model loaded")

     ppl_results = None
     if eval_args.run_ppl:
         assert eval_args.validation is not None
+        packing_args = PackingArgs(
+            batch_size=eval_args.validation.batch_size,
+            seq_len=train_cfg.data.seq_len,
+            max_length=train_cfg.data.max_encoder_seq_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+            pad_id=pad_id,
+            packing_mode=(
+                PackingMode.BYTES
+                if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
+        )
         if len(eval_args.validation.sources) > 0:
             ppl_results = {}
             logger.info("Starting PPL evaluation on validation sets")
@@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs):
                 world_size=world_size,
                 model=model,
                 tokenizer_args=train_cfg.data.tokenizer_args,
-                patcher_args=train_cfg.data.patcher_args,
-                batch_size=8,
-                add_patches=train_cfg.data.add_patches,
+                patcher_args=train_cfg.data.patcher_args,
+                packing_args=packing_args,
+                add_patches=train_cfg.data.add_patches,
                 path=os.path.join(eval_args.validation.root_dir, source),
                 max_n_docs=eval_args.validation.max_n_docs,
-                arrow_batch_size=20,
-                s3_profile="blt",
+                arrow_batch_size=20,
+                s3_profile=eval_args.s3_profile,
             )

         task_results = None
```
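With reduction="sum", total_loss accumulates the summed negative log-likelihood over every counted target byte, and n_bytes counts those bytes (y.numel(), or mask.sum() when a mask is present), so byte-level perplexity and bits-per-byte follow directly. A sketch of that arithmetic with illustrative numbers (the exact metric keys the eval reports are not shown in this diff):

```python
import math

total_loss = 123_456.0  # sum of F.cross_entropy(..., reduction="sum") over all batches
n_bytes = 80_000  # unmasked target bytes counted during the loop

nll_per_byte = total_loss / n_bytes  # nats per byte
ppl_per_byte = math.exp(nll_per_byte)  # byte-level perplexity
bpb = nll_per_byte / math.log(2)  # bits per byte
print(f"ppl={ppl_per_byte:.3f} bpb={bpb:.3f}")
```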
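The behavioral core of the eval-loop change is the type-based dispatch: a ByteLatentTransformer receives the patch_lengths produced by the packing iterator, while the entropy model (LMTransformer) is called on tokens alone. A self-contained sketch of that pattern with stand-in modules (the real classes live in bytelatent; these dummies only mirror the call shapes):

```python
import torch
import torch.nn as nn


class LMTransformer(nn.Module):  # stand-in for the entropy model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros(*x.shape, 256)  # fake byte logits


class ByteLatentTransformer(nn.Module):  # stand-in for BLT
    def forward(self, x: torch.Tensor, patch_lengths=None) -> torch.Tensor:
        # BLT groups bytes into patches, so it needs per-sequence patch lengths.
        return torch.zeros(*x.shape, 256)


def forward_for_eval(model: nn.Module, x: torch.Tensor, patch_lengths):
    # Mirrors the isinstance() dispatch added in eval_ppl_on_path.
    if isinstance(model, ByteLatentTransformer):
        return model(x, patch_lengths=patch_lengths)
    return model(x)


x = torch.zeros(2, 8, dtype=torch.long)
patch_lengths = torch.full((2, 4), 2, dtype=torch.long)  # 4 patches of 2 bytes each
print(forward_for_eval(ByteLatentTransformer(), x, patch_lengths).shape)  # (2, 8, 256)
print(forward_for_eval(LMTransformer(), x, None).shape)  # (2, 8, 256)
```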
|