Spaces:
Running on Zero
make sure max_encoder_seq_length matches (#55)
Browse files

* make sure max_encoder_seq_length matches
* black and assert comment
---------
Co-authored-by: Srini Iyer <[email protected]>
- bytelatent/train.py +4 -1
bytelatent/train.py
CHANGED
|
@@ -130,6 +130,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
|
|
| 130 |
if args.model is not None:
|
| 131 |
logger.info(f"Setting model output size to {args.model.vocab_size}")
|
| 132 |
args.model.vocab_size = output_size
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
if args.entropy_model is not None:
|
| 135 |
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
|
|
@@ -610,7 +613,7 @@ def train(args: TrainArgs):
|
|
| 610 |
interval_total_tok_loss_across_gpus = dist_sum(
|
| 611 |
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
|
| 612 |
).item()
|
| 613 |
-
interval_total_n_bytes_per_gpu = n_bytes
|
| 614 |
interval_total_n_bytes_across_gpus = dist_sum(
|
| 615 |
n_bytes, reduce_dtype=torch.bfloat16
|
| 616 |
).item()
|
|
|
|
| 130 |
if args.model is not None:
|
| 131 |
logger.info(f"Setting model output size to {args.model.vocab_size}")
|
| 132 |
args.model.vocab_size = output_size
|
| 133 |
+
assert (
|
| 134 |
+
args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
|
| 135 |
+
), "max_encoder_seq_length for model and data should match"
|
| 136 |
|
| 137 |
if args.entropy_model is not None:
|
| 138 |
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
|
|
|
|
| 613 |
interval_total_tok_loss_across_gpus = dist_sum(
|
| 614 |
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
|
| 615 |
).item()
|
| 616 |
+
interval_total_n_bytes_per_gpu = n_bytes.item()
|
| 617 |
interval_total_n_bytes_across_gpus = dist_sum(
|
| 618 |
n_bytes, reduce_dtype=torch.bfloat16
|
| 619 |
).item()
|