from __future__ import annotations
# Standard library imports
import logging
from typing import Dict, List, Optional
# Third-party imports
import soundfile as sf
import torch
import uroman
# fairseq2 imports
from inference.align_utils import get_uroman_tokens
from inference.audio_chunker import AudioChunker
from inference.audio_reading_tools import wav_to_bytes
# Import AudioAlignment and its config classes
from inference.audio_sentence_alignment import AudioAlignment
from inference.mms_model_pipeline import MMSModel
from inference.text_normalization import text_normalize
from transcription_status import transcription_status
from env_vars import USE_CHUNKING
# Constants
SAMPLE_RATE = 16000
logger = logging.getLogger(__name__)
def transcribe_single_chunk(audio_tensor: torch.Tensor, sample_rate: int = 16000, language_with_script: Optional[str] = None):
"""
Basic transcription pipeline for a single audio chunk using MMS model pipeline.
This is the lowest-level transcription function that handles individual audio segments.
Args:
audio_tensor (torch.Tensor): Audio tensor (1D waveform)
sample_rate (int): Sample rate of the audio tensor
language_with_script (Optional[str]): Language code with script for transcription (3-letter ISO codes like "eng", "spa", plus script)
Returns:
str: Transcribed text
"""
logger.info("Starting complete audio transcription pipeline...")
try:
logger.info("Using pipeline transcription...")
# Use the singleton model instance
model = MMSModel.get_instance()
# Transcribe using pipeline - convert tensor to list format
lang_list = [language_with_script] if language_with_script else None
results = model.transcribe_audio(audio_tensor, batch_size=1, language_with_scripts=lang_list)
result = results[0] if results else {}
# Convert pipeline result to expected format
if isinstance(result, dict) and 'text' in result:
transcription_text = result['text']
elif isinstance(result, str):
transcription_text = result
else:
transcription_text = str(result)
if not transcription_text.strip():
logger.warning("Pipeline returned empty transcription")
return ""
logger.info(f"✓ Pipeline transcription successful: '{transcription_text}'")
# Return the transcription text
return transcription_text
except Exception as e:
logger.error(f"Error in transcription pipeline: {str(e)}", exc_info=True)
raise
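
# Illustrative usage sketch for transcribe_single_chunk (kept as comments so it
# is not executed on import). The file name and the "eng" language code are
# placeholders; the accepted language_with_script values depend on how
# MMSModel.get_instance() is configured.
#
#     waveform, sr = sf.read("example.wav", dtype="float32")
#     audio = torch.from_numpy(waveform).float()
#     text = transcribe_single_chunk(audio, sample_rate=sr, language_with_script="eng")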
def perform_forced_alignment(
audio_tensor: torch.Tensor,
transcription_tokens: List[str],
device,
sample_rate: int = 16000,
) -> List[Dict]:
"""
Perform forced alignment using the AudioAlignment class from audio_sentence_alignment.py.
Uses the provided audio tensor directly.
Args:
audio_tensor (torch.Tensor): Audio tensor (1D waveform)
transcription_tokens (List[str]): List of tokens from transcription
device: Device for computation
sample_rate (int): Audio sample rate
Returns:
List[Dict]: List of segments with timestamps and text
"""
try:
logger.info(f"Starting forced alignment with audio tensor")
logger.info(f"Audio shape: {audio_tensor.shape}, sample_rate: {sample_rate}")
logger.info(f"Tokens to align: {transcription_tokens}")
# Use the provided audio tensor directly
# Convert to the format expected by AudioAlignment.get_one_row_alignments
if hasattr(audio_tensor, "cpu"):
# If it's a torch tensor, use it directly
alignment_tensor = audio_tensor.float()
else:
# If it's numpy, convert to tensor
alignment_tensor = torch.from_numpy(audio_tensor).float()
# Ensure it's 1D (flatten if needed)
if len(alignment_tensor.shape) > 1:
alignment_tensor = alignment_tensor.flatten()
# Convert audio tensor to bytes format expected by AudioAlignment
# Use wav_to_bytes to create proper audio bytes
# Move tensor to CPU first to avoid CUDA tensor to numpy conversion error
audio_tensor_cpu = alignment_tensor.cpu() if alignment_tensor.is_cuda else alignment_tensor
audio_arr = wav_to_bytes(
audio_tensor_cpu, sample_rate=sample_rate, format="wav"
)
logger.info(
f"Converted audio to bytes: shape={audio_arr.shape}, dtype={audio_arr.dtype}"
)
# Preprocess tokens for MMS alignment model using the same approach as TextRomanizer
# The MMS alignment model expects romanized tokens in the same format as text_sentences_tokens
try:
# Join tokens back to text for uroman processing
transcription_text = " ".join(transcription_tokens)
# Create uroman instance and process the text the same way as TextRomanizer
uroman_instance = uroman.Uroman()
# Step 1: Normalize the text first using text_normalize function (same as TextRomanizer)
normalized_text = text_normalize(transcription_text.strip(), "en")
# Step 2: Get uroman tokens using the same function as TextRomanizer
# This creates character-level tokens with spaces between characters
uroman_tokens_str = get_uroman_tokens(
[normalized_text], uroman_instance, "en"
)[0]
# Step 3: Split by spaces to get individual character tokens (same as real MMS pipeline)
alignment_tokens = uroman_tokens_str.split()
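# Hedged illustration of the expected token shape: for text like "hi there",
# get_uroman_tokens returns a space-separated character string (roughly
# "h i t h e r e"), so alignment_tokens becomes ["h", "i", "t", "h", "e", "r", "e"].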
logger.info(f"Original tokens: {transcription_tokens}")
logger.info(f"Original text: '{transcription_text}'")
logger.info(f"Normalized text: '{normalized_text}'")
logger.info(f"Uroman tokens string: '{uroman_tokens_str}'")
logger.info(
f"Alignment tokens (count={len(alignment_tokens)}): {alignment_tokens[:20]}..."
)
# Additional debugging - check for any unusual characters
for i, token in enumerate(alignment_tokens[:10]): # Check first 10 tokens
logger.info(
f"Token {i}: '{token}' (length={len(token)}, chars={[c for c in token]})"
)
except Exception as e:
logger.warning(
f"Failed to preprocess tokens with TextRomanizer approach: {e}"
)
logger.exception("Full error traceback:")
# Fallback: use simple character-level tokenization
transcription_text = " ".join(transcription_tokens).lower()
# Simple character-level tokenization as fallback
alignment_tokens = []
for char in transcription_text:
if char == " ":
alignment_tokens.append(" ")
else:
alignment_tokens.append(char)
logger.info(f"Using fallback character tokens: {alignment_tokens[:20]}...")
logger.info(
f"Using {len(alignment_tokens)} alignment tokens for forced alignment"
)
# Create AudioAlignment instance
logger.info("Creating AudioAlignment instance...")
alignment = AudioAlignment()
# Perform alignment using get_one_row_alignments
logger.info("Performing alignment...")
logger.info(f"About to call get_one_row_alignments with:")
logger.info(f" audio_arr type: {type(audio_arr)}, shape: {audio_arr.shape}")
logger.info(
f" alignment_tokens type: {type(alignment_tokens)}, length: {len(alignment_tokens)}"
)
logger.info(
f" First 10 tokens: {alignment_tokens[:10] if len(alignment_tokens) >= 10 else alignment_tokens}"
)
# Check for any problematic characters in tokens
for i, token in enumerate(alignment_tokens[:5]):
token_chars = [ord(c) for c in str(token)]
logger.info(f" Token {i} '{token}' char codes: {token_chars}")
# Check if tokens contain any RTL characters that might cause the LTR assertion
rtl_chars = []
for i, token in enumerate(alignment_tokens):
for char in str(token):
# Check for Arabic, Hebrew, and other RTL characters
if (
"\u0590" <= char <= "\u08ff"
or "\ufb1d" <= char <= "\ufdff"
or "\ufe70" <= char <= "\ufeff"
):
rtl_chars.append((i, token, char, ord(char)))
if rtl_chars:
logger.warning(f"Found RTL characters in tokens: {rtl_chars[:10]}...")
try:
audio_segments = alignment.get_one_row_alignments(
audio_arr, sample_rate, alignment_tokens
)
except Exception as alignment_error:
logger.error(f"Alignment failed with error: {alignment_error}")
logger.error(f"Error type: {type(alignment_error)}")
# Try to provide more context about the error
if "ltr" in str(alignment_error).lower():
logger.error("LTR assertion error detected. This might be due to:")
logger.error("1. RTL characters in the input tokens")
logger.error(
"2. Incorrect token format - tokens should be individual characters"
)
logger.error("3. Unicode normalization issues")
# Try a simple ASCII-only fallback
logger.info("Attempting ASCII-only fallback...")
ascii_tokens = []
for token in alignment_tokens:
# Keep only ASCII characters
ascii_token = "".join(c for c in str(token) if ord(c) < 128)
if ascii_token:
ascii_tokens.append(ascii_token)
logger.info(
f"ASCII tokens (count={len(ascii_tokens)}): {ascii_tokens[:20]}..."
)
try:
audio_segments = alignment.get_one_row_alignments(
audio_arr, sample_rate, ascii_tokens
)
alignment_tokens = ascii_tokens # Update for later use
logger.info("ASCII fallback successful!")
except Exception as ascii_error:
logger.error(f"ASCII fallback also failed: {ascii_error}")
raise alignment_error
else:
raise
logger.info(
f"Alignment completed, got {len(audio_segments)} character segments"
)
# Debug: Log the actual structure of audio_segments
if audio_segments:
logger.info("=== Audio Segments Debug Info ===")
logger.info(f"Total segments: {len(audio_segments)}")
# Print ALL audio segments for complete debugging
logger.info("=== ALL AUDIO SEGMENTS ===")
for i, segment in enumerate(audio_segments):
logger.info(f"Segment {i}: {segment}")
if i > 0 and i % 20 == 0: # Print progress every 20 segments
logger.info(
f"... printed {i+1}/{len(audio_segments)} segments so far..."
)
logger.info("=== End All Audio Segments ===")
logger.info("=== End Audio Segments Debug ===")
# Convert character-level segments back to word-level segments
# Use the actual alignment timings to preserve silence and natural timing
aligned_segments = []
logger.info(
f"Converting {len(audio_segments)} character segments to word segments"
)
logger.info(f"Original tokens: {transcription_tokens}")
logger.info(f"Alignment tokens: {alignment_tokens[:20]}...")
# Validate that we have segments and tokens
if not audio_segments or not transcription_tokens:
logger.warning("No audio segments or transcription tokens available")
return []
# Get actual timing from character segments
# (audio_segments is guaranteed non-empty by the early return above)
# Use the known segment keys from audio_sentence_alignment
start_key, duration_key = "segment_start_sec", "segment_duration"
last_segment = audio_segments[-1]
total_audio_duration = last_segment.get(start_key, 0) + last_segment.get(
duration_key, 0
)
logger.info(
f"Total audio duration from segments: {total_audio_duration:.3f}s"
)
# Strategy: Group character segments by words using the actual alignment timing
# This preserves the natural timing including silences from the forced alignment
# First, reconstruct the alignment character sequence
alignment_char_sequence = "".join(alignment_tokens)
transcription_text = "".join(
transcription_tokens
) # Remove spaces for character matching
logger.info(f"Alignment sequence length: {len(alignment_char_sequence)}")
logger.info(f"Transcription length: {len(transcription_text)}")
# Create word boundaries based on romanized alignment tokens
# We need to map each original word to its position in the romanized sequence
word_boundaries = []
alignment_pos = 0
# Process each word individually to get its romanized representation
for word in transcription_tokens:
try:
# Get romanized version of this individual word
normalized_word = text_normalize(word.strip(), "en")
uroman_word_str = get_uroman_tokens([normalized_word], uroman_instance, "en")[0]
romanized_word_tokens = uroman_word_str.split()
word_start = alignment_pos
word_end = alignment_pos + len(romanized_word_tokens)
word_boundaries.append((word_start, word_end))
alignment_pos = word_end
logger.info(f"Word '{word}' -> romanized tokens {romanized_word_tokens} -> positions {word_start}-{word_end}")
except Exception as e:
logger.warning(f"Failed to romanize word '{word}': {e}")
# Fallback: estimate based on character length ratio
estimated_length = max(1, int(len(word) * len(alignment_tokens) / len(transcription_text)))
word_start = alignment_pos
word_end = min(alignment_pos + estimated_length, len(alignment_tokens))
word_boundaries.append((word_start, word_end))
alignment_pos = word_end
logger.info(f"Word '{word}' (fallback) -> estimated positions {word_start}-{word_end}")
logger.info(f"Word boundaries (romanized): {word_boundaries[:5]}...")
logger.info(f"Total alignment tokens used: {alignment_pos}/{len(alignment_tokens)}")
# Map each word to its character segments using the boundaries
for word_idx, (word, (word_start, word_end)) in enumerate(
zip(transcription_tokens, word_boundaries)
):
# Find character segments that belong to this word
word_segments = []
# Map word character range to alignment token indices
# Since alignment_tokens might be slightly different due to normalization,
# we'll be flexible and use a range around the expected positions
start_idx = max(0, min(word_start, len(audio_segments) - 1))
end_idx = min(word_end, len(audio_segments))
# Ensure we don't go beyond available segments
for seg_idx in range(start_idx, end_idx):
if seg_idx < len(audio_segments):
word_segments.append(audio_segments[seg_idx])
if word_segments:
# Use actual timing from the character segments for this word
start_times = [seg.get(start_key, 0) for seg in word_segments]
end_times = [
seg.get(start_key, 0) + seg.get(duration_key, 0)
for seg in word_segments
]
start_time = min(start_times) if start_times else 0
end_time = max(end_times) if end_times else start_time + 0.1
duration = end_time - start_time
# Ensure minimum duration
if duration < 0.05: # Minimum 50ms
duration = 0.05
end_time = start_time + duration
logger.debug(
f"Word '{word}' (segments {start_idx}-{end_idx}, {len(word_segments)} segs): {start_time:.3f}s - {end_time:.3f}s ({duration:.3f}s)"
)
else:
logger.warning(
f"No segments found for word '{word}' at position {word_start}-{word_end}"
)
# Fallback: use proportional timing if no segments found
# word_start/word_end index into alignment_tokens, so use its length for the proportion
if total_audio_duration > 0 and len(alignment_tokens) > 0:
start_proportion = word_start / len(alignment_tokens)
end_proportion = word_end / len(alignment_tokens)
start_time = start_proportion * total_audio_duration
end_time = end_proportion * total_audio_duration
duration = end_time - start_time
else:
# Ultimate fallback
word_duration = 0.5
start_time = word_idx * word_duration
end_time = start_time + word_duration
duration = word_duration
logger.debug(
f"Word '{word}' (fallback): {start_time:.3f}s - {end_time:.3f}s"
)
aligned_segments.append(
{
"text": word,
"start": start_time,
"end": end_time,
"duration": duration,
}
)
# Validate segments don't overlap but preserve natural gaps/silences
for i in range(1, len(aligned_segments)):
prev_end = aligned_segments[i - 1]["end"]
current_start = aligned_segments[i]["start"]
if current_start < prev_end:
# Only fix actual overlaps, don't force adjacency
gap = prev_end - current_start
logger.debug(
f"Overlap detected: segment {i-1} ends at {prev_end:.3f}s, segment {i} starts at {current_start:.3f}s (overlap: {gap:.3f}s)"
)
# Fix overlap by adjusting current segment start to previous end
aligned_segments[i]["start"] = prev_end
aligned_segments[i]["duration"] = (
aligned_segments[i]["end"] - aligned_segments[i]["start"]
)
logger.debug(
f"Fixed overlap for segment {i}: adjusted start to {prev_end:.3f}s"
)
else:
# Log natural gaps (this is normal and expected)
gap = current_start - prev_end
if gap > 0.1: # Log gaps > 100ms
logger.debug(
f"Natural gap preserved: {gap:.3f}s between segments {i-1} and {i}"
)
logger.info(f"Forced alignment completed: {len(aligned_segments)} segments")
return aligned_segments
except Exception as e:
logger.error(f"Error in forced alignment: {str(e)}", exc_info=True)
# Fallback: create uniform timestamps based on audio tensor length
logger.info("Using fallback uniform timestamps")
try:
# Calculate duration from the audio tensor
total_duration = (
len(audio_tensor) / sample_rate
if len(audio_tensor) > 0
else len(transcription_tokens) * 0.5
)
except Exception:
total_duration = len(transcription_tokens) * 0.5 # Fallback
segment_duration = (
total_duration / len(transcription_tokens) if transcription_tokens else 1.0
)
fallback_segments = []
for i, token in enumerate(transcription_tokens):
start_time = i * segment_duration
end_time = (i + 1) * segment_duration
fallback_segments.append(
{
"text": token,
"start": start_time,
"end": end_time,
"duration": segment_duration,
}
)
logger.info(
f"Using fallback uniform timestamps: {len(fallback_segments)} segments"
)
return fallback_segments
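
# Sketch of the return shape of perform_forced_alignment (hedged: timings below
# are made up for illustration; real values come from AudioAlignment):
#
#     segments = perform_forced_alignment(audio, ["hello", "world"], torch.device("cpu"))
#     # segments ~ [
#     #     {"text": "hello", "start": 0.12, "end": 0.48, "duration": 0.36},
#     #     {"text": "world", "start": 0.55, "end": 0.93, "duration": 0.38},
#     # ]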
def transcribe_with_word_alignment(audio_tensor: torch.Tensor, sample_rate: int = 16000, language_with_script: Optional[str] = None) -> Dict:
"""
Transcription pipeline that includes word-level timing through forced alignment.
Adds precise word-level timestamps to the basic transcription capability.
Args:
audio_tensor (torch.Tensor): Audio tensor (1D waveform)
sample_rate (int): Sample rate of the audio tensor
language_with_script (Optional[str]): Language code with script for transcription (3-letter ISO codes like "eng", "spa", plus script)
Returns:
Dict: Transcription results with alignment information including word-level timestamps
"""
try:
# Get model and device first
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Get the transcription results
transcription_text = transcribe_single_chunk(audio_tensor, sample_rate=sample_rate, language_with_script=language_with_script)
if not transcription_text:
return {
"transcription": "",
"tokens": [],
"aligned_segments": [],
"total_duration": 0.0,
}
# Tokenize the transcription for alignment
tokens = transcription_text.split()
# Perform forced alignment using the original audio tensor
logger.info("Performing forced alignment with original audio tensor...")
aligned_segments = perform_forced_alignment(audio_tensor, tokens, device, sample_rate)
# Calculate total duration
total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0
result = {
"transcription": transcription_text,
"tokens": tokens,
"aligned_segments": aligned_segments,
"total_duration": total_duration,
"num_segments": len(aligned_segments),
}
logger.info(
f"Transcription with alignment completed: {len(aligned_segments)} segments, {total_duration:.2f}s total"
)
return result
except Exception as e:
logger.error(f"Error in transcription with alignment: {str(e)}", exc_info=True)
# Return basic transcription without alignment
try:
transcription_text = transcribe_single_chunk(audio_tensor, sample_rate=sample_rate, language_with_script=language_with_script)
tokens = transcription_text.split() if transcription_text else []
return {
"transcription": transcription_text,
"tokens": tokens,
"aligned_segments": [],
"total_duration": 0.0,
"alignment_error": str(e),
}
except Exception as e2:
logger.error(f"Error in fallback transcription: {str(e2)}", exc_info=True)
return {
"transcription": "",
"tokens": [],
"aligned_segments": [],
"total_duration": 0.0,
"error": str(e2),
}
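
# Sketch of the dict returned by transcribe_with_word_alignment on success
# (keys taken from the code above; values are illustrative only):
#
#     result = transcribe_with_word_alignment(audio, sample_rate=16000, language_with_script="eng")
#     # result["transcription"]    -> full text
#     # result["tokens"]           -> whitespace-split tokens
#     # result["aligned_segments"] -> per-word {"text", "start", "end", "duration"} dicts
#     # result["total_duration"]   -> end time of the last segment, in seconds
#     # result["num_segments"]     -> len(result["aligned_segments"])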
def _validate_and_adjust_segments(
aligned_segments: List[Dict],
chunk_start_time: float,
chunk_audio_tensor: torch.Tensor,
chunk_sample_rate: int,
chunk_duration: float,
chunk_index: int
) -> List[Dict]:
"""
Private helper function to validate and adjust segment timestamps to global timeline.
Args:
aligned_segments: Raw segments from forced alignment (local chunk timeline)
chunk_start_time: Start time of this chunk in global timeline
chunk_audio_tensor: Audio tensor for this chunk (to get actual duration)
chunk_sample_rate: Sample rate of the chunk
chunk_duration: Reported duration of the chunk
chunk_index: Index of this chunk for debugging
Returns:
List of validated segments with global timeline timestamps
"""
adjusted_segments = []
# Get the actual audio duration from the chunk tensor instead of the potentially incorrect chunk duration
actual_chunk_duration = len(chunk_audio_tensor) / chunk_sample_rate if len(chunk_audio_tensor) > 0 else chunk_duration
for segment in aligned_segments:
original_start = segment["start"]
original_end = segment["end"]
# Validate that segment timestamps are within chunk boundaries
if original_start < 0:
logger.warning(
f"Segment '{segment['text']}' has negative start time {original_start:.3f}s, clipping to 0"
)
original_start = 0
if original_end > actual_chunk_duration + 1.0: # Allow 1s buffer for alignment errors
logger.warning(
f"Segment '{segment['text']}' end time {original_end:.3f}s exceeds actual chunk duration {actual_chunk_duration:.3f}s, clipping"
)
original_end = actual_chunk_duration
if original_start >= original_end:
logger.warning(
f"Segment '{segment['text']}' has invalid timing {original_start:.3f}s-{original_end:.3f}s, using fallback"
)
# Use proportional timing based on segment position using actual chunk duration
segment_index = len(adjusted_segments)
total_segments = len(aligned_segments)
if total_segments > 0:
segment_proportion = segment_index / total_segments
next_proportion = (segment_index + 1) / total_segments
original_start = segment_proportion * actual_chunk_duration
original_end = next_proportion * actual_chunk_duration
else:
original_start = 0
original_end = 0.5
# Create segment with absolute timeline
adjusted_segment = {
"text": segment["text"],
"start": original_start + chunk_start_time, # Global timeline
"end": original_end + chunk_start_time, # Global timeline
"duration": original_end - original_start,
"chunk_index": chunk_index,
"original_start": original_start, # Local chunk time
"original_end": original_end, # Local chunk time
}
adjusted_segments.append(adjusted_segment)
logger.debug(
f"Segment '{segment['text']}': {original_start:.3f}-{original_end:.3f} -> {adjusted_segment['start']:.3f}-{adjusted_segment['end']:.3f}"
)
logger.info(
f"Adjusted {len(adjusted_segments)} segments to absolute timeline (chunk starts at {chunk_start_time:.2f}s)"
)
return adjusted_segments
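
# Worked example of the local-to-global shift performed above (numbers are
# illustrative): a segment aligned at 1.20-1.55s inside a chunk whose
# chunk_start_time is 30.00s comes out as
#
#     {"text": "word", "start": 31.20, "end": 31.55, "duration": 0.35,
#      "chunk_index": 3, "original_start": 1.20, "original_end": 1.55}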
def transcribe_full_audio_with_chunking(
audio_tensor: torch.Tensor, sample_rate: int = 16000, chunk_duration: float = 30.0, language_with_script: Optional[str] = None, progress_callback=None
) -> Dict:
"""
Complete audio transcription pipeline that handles any length audio with intelligent chunking.
This is the full-featured transcription function that can process both short and long audio files.
Chunking mode is controlled by USE_CHUNKING environment variable:
- USE_CHUNKING=false: No chunking (single chunk mode)
- USE_CHUNKING=true (default): VAD-based intelligent chunking
Args:
audio_tensor: Audio tensor (1D waveform)
sample_rate: Sample rate of the audio tensor
chunk_duration: Target chunk duration in seconds (for static chunking)
language_with_script: {Language code}_{script} for transcription
progress_callback: Optional callback for progress updates
Returns:
Dict with full transcription and segment information including word-level timestamps
"""
try:
logger.info(f"Starting long-form transcription: tensor shape {audio_tensor.shape} at {sample_rate}Hz")
logger.info(f"USE_CHUNKING = {USE_CHUNKING}")
# Initialize chunker
chunker = AudioChunker()
# Determine chunking mode based on USE_CHUNKING setting
chunking_mode = "vad" if USE_CHUNKING else "none"
# Chunk the audio using the new unified interface
# Ensure tensor is 1D before chunking (squeeze any extra dimensions)
if len(audio_tensor.shape) > 1:
logger.info(f"Squeezing audio tensor from {audio_tensor.shape} to 1D")
audio_tensor_1d = audio_tensor.squeeze()
else:
audio_tensor_1d = audio_tensor
chunks = chunker.chunk_audio(audio_tensor_1d, sample_rate=sample_rate, mode=chunking_mode, chunk_duration=chunk_duration)
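# Hedged sketch of the chunk dicts consumed below (keys inferred from their use
# in this function): each chunk is expected to look roughly like
# {"audio_data": <1D tensor>, "sample_rate": 16000,
#  "start_time": 0.0, "end_time": 28.7, "duration": 28.7}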
if not chunks:
logger.warning("No audio chunks created")
return {
"transcription": "",
"chunks": [],
"total_duration": 0.0,
"error": "No audio content detected",
}
logger.info(f"Processing {len(chunks)} audio chunks (mode: {chunking_mode})")
# Validate chunk continuity
for i, chunk in enumerate(chunks):
logger.info(
f"Chunk {i+1}: {chunk['start_time']:.2f}s - {chunk['end_time']:.2f}s ({chunk['duration']:.2f}s)"
)
if i > 0:
prev_end = chunks[i - 1]["end_time"]
current_start = chunk["start_time"]
gap = current_start - prev_end
if abs(gap) > 0.1: # More than 100ms gap/overlap
logger.warning(
f"Gap/overlap between chunks {i} and {i+1}: {gap:.3f}s"
)
# Process each chunk - now all chunks have uniform format!
all_segments = []
full_transcription_parts = []
total_duration = 0.0
chunk_details = []
for i, chunk in enumerate(chunks):
logger.info(
f"Processing chunk {i+1}/{len(chunks)} ({chunk['duration']:.1f}s, {chunk['start_time']:.1f}s-{chunk['end_time']:.1f}s)"
)
try:
# Process this chunk using tensor-based transcription pipeline
# Use the chunk's audio_data tensor directly - no more file operations!
chunk_audio_tensor = chunk["audio_data"]
chunk_sample_rate = chunk["sample_rate"]
chunk_result = transcribe_with_word_alignment(
audio_tensor=chunk_audio_tensor,
sample_rate=chunk_sample_rate,
language_with_script=language_with_script
)
# Process alignment results - uniform handling for all chunk types
chunk_segments = []
chunk_start_time = chunk["start_time"]
chunk_duration = chunk["duration"]
if chunk_result.get("aligned_segments"):
logger.info(
f"Chunk {i+1} has {len(chunk_result['aligned_segments'])} segments"
)
chunk_segments = _validate_and_adjust_segments(
aligned_segments=chunk_result["aligned_segments"],
chunk_start_time=chunk_start_time,
chunk_audio_tensor=chunk_audio_tensor,
chunk_sample_rate=chunk_sample_rate,
chunk_duration=chunk_duration,
chunk_index=i
)
all_segments.extend(chunk_segments)
logger.info(f"Chunk {i+1} processed {len(chunk_segments)} valid segments")
# Add to full transcription
chunk_transcription = ""
if chunk_result.get("transcription"):
chunk_transcription = chunk_result["transcription"]
full_transcription_parts.append(chunk_transcription)
# Store detailed chunk information
chunk_detail = {
"chunk_index": i,
"start_time": chunk["start_time"],
"end_time": chunk["end_time"],
"duration": chunk["duration"],
"transcription": chunk_transcription,
"num_segments": len(chunk_segments),
"segments": chunk_segments,
}
chunk_details.append(chunk_detail)
total_duration = max(total_duration, chunk["end_time"])
# Update progress linearly from 0.1 to 0.9 based on chunk processing
progress = 0.1 + (0.8 * (i + 1) / len(chunks))
transcription_status.update_progress(progress)
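# Example of the progress formula: with 4 chunks, progress after chunks 1..4 is
# 0.3, 0.5, 0.7 and 0.9 respectively (0.1 + 0.8 * (i + 1) / 4).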
logger.info(
f"Chunk {i+1} processed: '{chunk_transcription}' ({len(chunk_segments)} segments)"
)
except Exception as chunk_error:
logger.error(f"Error processing chunk {i+1}: {chunk_error}")
# Continue with next chunk
# Combine results
full_transcription = " ".join(full_transcription_parts)
# Validate segment continuity
logger.info("Validating segment continuity...")
for i in range(1, len(all_segments)):
prev_end = all_segments[i - 1]["end"]
current_start = all_segments[i]["start"]
gap = current_start - prev_end
if abs(gap) > 1.0: # More than 1 second gap
logger.warning(f"Large gap between segments {i-1} and {i}: {gap:.3f}s")
result = {
"transcription": full_transcription,
"aligned_segments": all_segments,
"chunks": [
{
"chunk_index": chunk_detail["chunk_index"],
"start_time": chunk_detail["start_time"],
"end_time": chunk_detail["end_time"],
"duration": chunk_detail["duration"],
"transcription": chunk_detail["transcription"],
"num_segments": chunk_detail["num_segments"],
}
for chunk_detail in chunk_details
],
"chunk_details": chunk_details, # Full details including segments per chunk
"total_duration": total_duration,
"num_chunks": len(chunks),
"num_segments": len(all_segments),
"status": "success",
}
logger.info(
f"Long-form transcription completed: {len(chunks)} chunks, {total_duration:.1f}s total"
)
logger.info(f"Total segments: {len(all_segments)}")
# Log chunk timing summary
for chunk_detail in chunk_details:
logger.info(
f"Chunk {chunk_detail['chunk_index']}: {chunk_detail['start_time']:.2f}-{chunk_detail['end_time']:.2f}s, {chunk_detail['num_segments']} segments"
)
return result
except Exception as e:
logger.error(f"Error in long-form transcription: {str(e)}", exc_info=True)
return {
"transcription": "",
"chunks": [],
"total_duration": 0.0,
"error": str(e),
}
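
# Minimal end-to-end sketch (assumptions: "example.wav" is a placeholder path to
# a mono 16 kHz recording, and "eng" stands in for whatever language_with_script
# value the configured MMS model accepts).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Load the audio into a 1D float tensor.
    waveform, sr = sf.read("example.wav", dtype="float32")
    audio = torch.from_numpy(waveform).float()
    result = transcribe_full_audio_with_chunking(
        audio, sample_rate=sr, chunk_duration=30.0, language_with_script="eng"
    )
    print(result["transcription"])
    for seg in result.get("aligned_segments", []):
        print(f"{seg['start']:.2f}-{seg['end']:.2f}s: {seg['text']}")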