from __future__ import annotations

# Standard library imports
import logging
from typing import Dict, List, Optional

# Third-party imports
import torch
import uroman

# Project imports
from inference.align_utils import get_uroman_tokens
from inference.audio_chunker import AudioChunker
from inference.audio_reading_tools import wav_to_bytes
from inference.audio_sentence_alignment import AudioAlignment
from inference.mms_model_pipeline import MMSModel
from inference.text_normalization import text_normalize
from transcription_status import transcription_status
from env_vars import USE_CHUNKING

# Constants
SAMPLE_RATE = 16000

logger = logging.getLogger(__name__)


def transcribe_single_chunk(
    audio_tensor: torch.Tensor,
    sample_rate: int = 16000,
    language_with_script: Optional[str] = None,
) -> str:
    """
    Basic transcription pipeline for a single audio chunk using MMS model pipeline.
    This is the lowest-level transcription function that handles individual audio segments.

    Args:
        audio_tensor (torch.Tensor): Audio tensor (1D waveform)
        sample_rate (int): Sample rate of the audio tensor
        language_with_script (Optional[str]): Language for transcription, given as a 3-letter ISO code (e.g. "eng", "spa") combined with a script tag

    Returns:
        str: Transcribed text
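
    Example (hypothetical usage; assumes `waveform` is a loaded 1-D 16 kHz tensor):
        >>> text = transcribe_single_chunk(waveform, sample_rate=16000)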
    """

    logger.info("Starting complete audio transcription pipeline...")

    try:
        logger.info("Using pipeline transcription...")
        # Use the singleton model instance
        model = MMSModel.get_instance()

        # Transcribe using pipeline - convert tensor to list format
        lang_list = [language_with_script] if language_with_script else None
        results = model.transcribe_audio(audio_tensor, batch_size=1, language_with_scripts=lang_list)
        result = results[0] if results else {}

        # Convert pipeline result to expected format
        if isinstance(result, dict) and 'text' in result:
            transcription_text = result['text']
        elif isinstance(result, str):
            transcription_text = result
        else:
            transcription_text = str(result)

        if not transcription_text.strip():
            logger.warning("Pipeline returned empty transcription")
            return ""

        logger.info(f"✓ Pipeline transcription successful: '{transcription_text}'")

        # Return the transcription text
        return transcription_text

    except Exception as e:
        logger.error(f"Error in transcription pipeline: {str(e)}", exc_info=True)
        raise


def perform_forced_alignment(
    audio_tensor: torch.Tensor,
    transcription_tokens: List[str],
    device,
    sample_rate: int = 16000,
) -> List[Dict]:
    """
    Perform forced alignment using the AudioAlignment class from audio_sentence_alignment.py.
    Uses the provided audio tensor directly.

    Args:
        audio_tensor (torch.Tensor): Audio tensor (1D waveform)
        transcription_tokens (List[str]): List of tokens from transcription
        device: Device for computation
        sample_rate (int): Audio sample rate

    Returns:
        List[Dict]: List of segments with timestamps and text
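            Each element is a dict of the form (illustrative values):
            {"text": "hello", "start": 0.12, "end": 0.58, "duration": 0.46}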
    """

    try:
        logger.info(f"Starting forced alignment with audio tensor")
        logger.info(f"Audio shape: {audio_tensor.shape}, sample_rate: {sample_rate}")
        logger.info(f"Tokens to align: {transcription_tokens}")

        # Use the provided audio tensor directly
        # Convert to the format expected by AudioAlignment.get_one_row_alignments
        if hasattr(audio_tensor, "cpu"):
            # If it's a torch tensor, use it directly
            alignment_tensor = audio_tensor.float()
        else:
            # If it's numpy, convert to tensor
            alignment_tensor = torch.from_numpy(audio_tensor).float()

        # Ensure it's 1D (flatten if needed)
        if len(alignment_tensor.shape) > 1:
            alignment_tensor = alignment_tensor.flatten()

        # Convert audio tensor to bytes format expected by AudioAlignment
        # Use wav_to_bytes to create proper audio bytes
        # Move tensor to CPU first to avoid CUDA tensor to numpy conversion error
        audio_tensor_cpu = alignment_tensor.cpu() if alignment_tensor.is_cuda else alignment_tensor

        audio_arr = wav_to_bytes(
            audio_tensor_cpu, sample_rate=sample_rate, format="wav"
        )

        logger.info(
            f"Converted audio to bytes: shape={audio_arr.shape}, dtype={audio_arr.dtype}"
        )

        # Preprocess tokens for MMS alignment model using the same approach as TextRomanizer
        # The MMS alignment model expects romanized tokens in the same format as text_sentences_tokens
        try:
            # Join tokens back to text for uroman processing
            transcription_text = " ".join(transcription_tokens)

            # Create uroman instance and process the text the same way as TextRomanizer
            uroman_instance = uroman.Uroman()

            # Step 1: Normalize the text first using text_normalize function (same as TextRomanizer)
            normalized_text = text_normalize(transcription_text.strip(), "en")

            # Step 2: Get uroman tokens using the same function as TextRomanizer
            # This creates character-level tokens with spaces between characters
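            # (illustrative: "hi there" would yield a token string like "h i t h e r e",
            #  which the split below turns into per-character tokens)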
            uroman_tokens_str = get_uroman_tokens(
                [normalized_text], uroman_instance, "en"
            )[0]

            # Step 3: Split by spaces to get individual character tokens (same as real MMS pipeline)
            alignment_tokens = uroman_tokens_str.split()

            logger.info(f"Original tokens: {transcription_tokens}")
            logger.info(f"Original text: '{transcription_text}'")
            logger.info(f"Normalized text: '{normalized_text}'")
            logger.info(f"Uroman tokens string: '{uroman_tokens_str}'")
            logger.info(
                f"Alignment tokens (count={len(alignment_tokens)}): {alignment_tokens[:20]}..."
            )

            # Additional debugging - check for any unusual characters
            for i, token in enumerate(alignment_tokens[:10]):  # Check first 10 tokens
                logger.info(
                    f"Token {i}: '{token}' (length={len(token)}, chars={[c for c in token]})"
                )

        except Exception as e:
            logger.warning(
                f"Failed to preprocess tokens with TextRomanizer approach: {e}"
            )
            logger.exception("Full error traceback:")
            # Fallback: use simple character-level tokenization
            transcription_text = " ".join(transcription_tokens).lower()
            # Simple character-level tokenization as fallback
            alignment_tokens = []
            for char in transcription_text:
                if char == " ":
                    alignment_tokens.append(" ")
                else:
                    alignment_tokens.append(char)
            logger.info(f"Using fallback character tokens: {alignment_tokens[:20]}...")

        logger.info(
            f"Using {len(alignment_tokens)} alignment tokens for forced alignment"
        )

        # Create AudioAlignment instance
        logger.info("Creating AudioAlignment instance...")
        alignment = AudioAlignment()

        # Perform alignment using get_one_row_alignments
        logger.info("Performing alignment...")
        logger.info(f"About to call get_one_row_alignments with:")
        logger.info(f"  audio_arr type: {type(audio_arr)}, shape: {audio_arr.shape}")
        logger.info(
            f"  alignment_tokens type: {type(alignment_tokens)}, length: {len(alignment_tokens)}"
        )
        logger.info(
            f"  First 10 tokens: {alignment_tokens[:10] if len(alignment_tokens) >= 10 else alignment_tokens}"
        )

        # Check for any problematic characters in tokens
        for i, token in enumerate(alignment_tokens[:5]):
            token_chars = [ord(c) for c in str(token)]
            logger.info(f"  Token {i} '{token}' char codes: {token_chars}")

        # Check if tokens contain any RTL characters that might cause the LTR assertion
        rtl_chars = []
        for i, token in enumerate(alignment_tokens):
            for char in str(token):
                # Check for Arabic, Hebrew, and other RTL characters
                if (
                    "\u0590" <= char <= "\u08ff"
                    or "\ufb1d" <= char <= "\ufdff"
                    or "\ufe70" <= char <= "\ufeff"
                ):
                    rtl_chars.append((i, token, char, ord(char)))

        if rtl_chars:
            logger.warning(f"Found RTL characters in tokens: {rtl_chars[:10]}...")

        try:
            audio_segments = alignment.get_one_row_alignments(
                audio_arr, sample_rate, alignment_tokens
            )

        except Exception as alignment_error:
            logger.error(f"Alignment failed with error: {alignment_error}")
            logger.error(f"Error type: {type(alignment_error)}")

            # Try to provide more context about the error
            if "ltr" in str(alignment_error).lower():
                logger.error("LTR assertion error detected. This might be due to:")
                logger.error("1. RTL characters in the input tokens")
                logger.error(
                    "2. Incorrect token format - tokens should be individual characters"
                )
                logger.error("3. Unicode normalization issues")

                # Try a simple ASCII-only fallback
                logger.info("Attempting ASCII-only fallback...")
                ascii_tokens = []
                for token in alignment_tokens:
                    # Keep only ASCII characters
                    ascii_token = "".join(c for c in str(token) if ord(c) < 128)
                    if ascii_token:
                        ascii_tokens.append(ascii_token)

                logger.info(
                    f"ASCII tokens (count={len(ascii_tokens)}): {ascii_tokens[:20]}..."
                )

                try:
                    audio_segments = alignment.get_one_row_alignments(
                        audio_arr, sample_rate, ascii_tokens
                    )
                    alignment_tokens = ascii_tokens  # Update for later use
                    logger.info("ASCII fallback successful!")
                except Exception as ascii_error:
                    logger.error(f"ASCII fallback also failed: {ascii_error}")
                    raise alignment_error
            else:
                raise

        logger.info(
            f"Alignment completed, got {len(audio_segments)} character segments"
        )

        # Debug: Log the actual structure of audio_segments
        if audio_segments:
            logger.info("=== Audio Segments Debug Info ===")
            logger.info(f"Total segments: {len(audio_segments)}")

            # Print ALL audio segments for complete debugging
            logger.info("=== ALL AUDIO SEGMENTS ===")
            for i, segment in enumerate(audio_segments):
                logger.info(f"Segment {i}: {segment}")
                if i > 0 and i % 20 == 0:  # Print progress every 20 segments
                    logger.info(
                        f"... printed {i+1}/{len(audio_segments)} segments so far..."
                    )
            logger.info("=== End All Audio Segments ===")
            logger.info("=== End Audio Segments Debug ===")

        # Convert character-level segments back to word-level segments
        # Use the actual alignment timings to preserve silence and natural timing
        aligned_segments = []

        logger.info(
            f"Converting {len(audio_segments)} character segments to word segments"
        )
        logger.info(f"Original tokens: {transcription_tokens}")
        logger.info(f"Alignment tokens: {alignment_tokens[:20]}...")

        # Validate that we have segments and tokens
        if not audio_segments or not transcription_tokens:
            logger.warning("No audio segments or transcription tokens available")
            return []

        # Get actual timing from the character segments (audio_segments is
        # guaranteed non-empty after the validation above)
        # Use the known segment keys from audio_sentence_alignment
        start_key, duration_key = "segment_start_sec", "segment_duration"

        last_segment = audio_segments[-1]
        total_audio_duration = last_segment.get(start_key, 0) + last_segment.get(
            duration_key, 0
        )
        logger.info(
            f"Total audio duration from segments: {total_audio_duration:.3f}s"
        )

        # Strategy: Group character segments by words using the actual alignment timing
        # This preserves the natural timing including silences from the forced alignment

        # First, reconstruct the alignment character sequence
        alignment_char_sequence = "".join(alignment_tokens)
        transcription_text = "".join(
            transcription_tokens
        )  # Remove spaces for character matching

        logger.info(f"Alignment sequence length: {len(alignment_char_sequence)}")
        logger.info(f"Transcription length: {len(transcription_text)}")

        # Create word boundaries based on romanized alignment tokens
        # We need to map each original word to its position in the romanized sequence
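        # (e.g. if "hello" romanizes to 5 character tokens, it occupies alignment
        #  positions 0-5 and the next word's tokens start at position 5)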
        word_boundaries = []
        alignment_pos = 0

        # Process each word individually to get its romanized representation
        for word in transcription_tokens:
            try:
                # Get romanized version of this individual word
                normalized_word = text_normalize(word.strip(), "en")
                uroman_word_str = get_uroman_tokens([normalized_word], uroman_instance, "en")[0]
                romanized_word_tokens = uroman_word_str.split()

                word_start = alignment_pos
                word_end = alignment_pos + len(romanized_word_tokens)
                word_boundaries.append((word_start, word_end))
                alignment_pos = word_end

                logger.info(f"Word '{word}' -> romanized tokens {romanized_word_tokens} -> positions {word_start}-{word_end}")

            except Exception as e:
                logger.warning(f"Failed to romanize word '{word}': {e}")
                # Fallback: estimate based on character length ratio
                estimated_length = max(1, int(len(word) * len(alignment_tokens) / len(transcription_text)))
                word_start = alignment_pos
                word_end = min(alignment_pos + estimated_length, len(alignment_tokens))
                word_boundaries.append((word_start, word_end))
                alignment_pos = word_end

                logger.info(f"Word '{word}' (fallback) -> estimated positions {word_start}-{word_end}")

        logger.info(f"Word boundaries (romanized): {word_boundaries[:5]}...")
        logger.info(f"Total alignment tokens used: {alignment_pos}/{len(alignment_tokens)}")

        # Map each word to its character segments using the boundaries
        for word_idx, (word, (word_start, word_end)) in enumerate(
            zip(transcription_tokens, word_boundaries)
        ):
            # Find character segments that belong to this word
            word_segments = []

            # Map word character range to alignment token indices
            # Since alignment_tokens might be slightly different due to normalization,
            # we'll be flexible and use a range around the expected positions
            start_idx = max(0, min(word_start, len(audio_segments) - 1))
            end_idx = min(word_end, len(audio_segments))

            # Ensure we don't go beyond available segments
            for seg_idx in range(start_idx, end_idx):
                if seg_idx < len(audio_segments):
                    word_segments.append(audio_segments[seg_idx])

            if word_segments:
                # Use actual timing from the character segments for this word
                start_times = [seg.get(start_key, 0) for seg in word_segments]
                end_times = [
                    seg.get(start_key, 0) + seg.get(duration_key, 0)
                    for seg in word_segments
                ]

                start_time = min(start_times) if start_times else 0
                end_time = max(end_times) if end_times else start_time + 0.1
                duration = end_time - start_time

                # Ensure minimum duration
                if duration < 0.05:  # Minimum 50ms
                    duration = 0.05
                    end_time = start_time + duration

                logger.debug(
                    f"Word '{word}' (segments {start_idx}-{end_idx}, {len(word_segments)} segs): {start_time:.3f}s - {end_time:.3f}s ({duration:.3f}s)"
                )
            else:
                logger.warning(
                    f"No segments found for word '{word}' at position {word_start}-{word_end}"
                )
                # Fallback: use proportional timing if no segments found
                if total_audio_duration > 0 and len(transcription_text) > 0:
                    start_proportion = word_start / len(transcription_text)
                    end_proportion = word_end / len(transcription_text)
                    start_time = start_proportion * total_audio_duration
                    end_time = end_proportion * total_audio_duration
                    duration = end_time - start_time
                else:
                    # Ultimate fallback
                    word_duration = 0.5
                    start_time = word_idx * word_duration
                    end_time = start_time + word_duration
                    duration = word_duration

                logger.debug(
                    f"Word '{word}' (fallback): {start_time:.3f}s - {end_time:.3f}s"
                )

            aligned_segments.append(
                {
                    "text": word,
                    "start": start_time,
                    "end": end_time,
                    "duration": duration,
                }
            )

        # Validate segments don't overlap but preserve natural gaps/silences
        for i in range(1, len(aligned_segments)):
            prev_end = aligned_segments[i - 1]["end"]
            current_start = aligned_segments[i]["start"]

            if current_start < prev_end:
                # Only fix actual overlaps, don't force adjacency
                gap = prev_end - current_start
                logger.debug(
                    f"Overlap detected: segment {i-1} ends at {prev_end:.3f}s, segment {i} starts at {current_start:.3f}s (overlap: {gap:.3f}s)"
                )

                # Fix overlap by adjusting current segment start to previous end
                aligned_segments[i]["start"] = prev_end
                aligned_segments[i]["duration"] = (
                    aligned_segments[i]["end"] - aligned_segments[i]["start"]
                )
                logger.debug(
                    f"Fixed overlap for segment {i}: adjusted start to {prev_end:.3f}s"
                )
            else:
                # Log natural gaps (this is normal and expected)
                gap = current_start - prev_end
                if gap > 0.1:  # Log gaps > 100ms
                    logger.debug(
                        f"Natural gap preserved: {gap:.3f}s between segments {i-1} and {i}"
                    )

        logger.info(f"Forced alignment completed: {len(aligned_segments)} segments")
        return aligned_segments

    except Exception as e:
        logger.error(f"Error in forced alignment: {str(e)}", exc_info=True)

        # Fallback: create uniform timestamps based on audio tensor length
        logger.info("Using fallback uniform timestamps")
        try:
            # Calculate duration from the audio tensor
            total_duration = (
                len(audio_tensor) / sample_rate
                if len(audio_tensor) > 0
                else len(transcription_tokens) * 0.5
            )
        except Exception:
            total_duration = len(transcription_tokens) * 0.5  # Fallback

        segment_duration = (
            total_duration / len(transcription_tokens) if transcription_tokens else 1.0
        )

        fallback_segments = []
        for i, token in enumerate(transcription_tokens):
            start_time = i * segment_duration
            end_time = (i + 1) * segment_duration

            fallback_segments.append(
                {
                    "text": token,
                    "start": start_time,
                    "end": end_time,
                    "duration": segment_duration,
                }
            )

        logger.info(
            f"Using fallback uniform timestamps: {len(fallback_segments)} segments"
        )
        return fallback_segments


def transcribe_with_word_alignment(
    audio_tensor: torch.Tensor,
    sample_rate: int = 16000,
    language_with_script: Optional[str] = None,
) -> Dict:
    """
    Transcription pipeline that includes word-level timing through forced alignment.
    Adds precise word-level timestamps to the basic transcription capability.

    Args:
        audio_tensor (torch.Tensor): Audio tensor (1D waveform)
        sample_rate (int): Sample rate of the audio tensor
        language_with_script (Optional[str]): Language code for transcription (3-letter ISO codes like "eng", "spa") combined with a script tag

    Returns:
        Dict: Transcription results with alignment information including word-level timestamps
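
    Example return shape (illustrative values):
        {"transcription": "hello world", "tokens": ["hello", "world"],
         "aligned_segments": [...], "total_duration": 1.2, "num_segments": 2}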
    """

    try:
        # Select the computation device used for forced alignment
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Get the transcription results
        transcription_text = transcribe_single_chunk(audio_tensor, sample_rate=sample_rate, language_with_script=language_with_script)

        if not transcription_text:
            return {
                "transcription": "",
                "tokens": [],
                "aligned_segments": [],
                "total_duration": 0.0,
            }

        # Tokenize the transcription for alignment
        tokens = transcription_text.split()

        # Perform forced alignment using the original audio tensor
        logger.info("Performing forced alignment with original audio tensor...")
        aligned_segments = perform_forced_alignment(audio_tensor, tokens, device, sample_rate)

        # Calculate total duration
        total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0

        result = {
            "transcription": transcription_text,
            "tokens": tokens,
            "aligned_segments": aligned_segments,
            "total_duration": total_duration,
            "num_segments": len(aligned_segments),
        }

        logger.info(
            f"Transcription with alignment completed: {len(aligned_segments)} segments, {total_duration:.2f}s total"
        )
        return result

    except Exception as e:
        logger.error(f"Error in transcription with alignment: {str(e)}", exc_info=True)
        # Return basic transcription without alignment
        try:
            transcription_text = transcribe_single_chunk(audio_tensor, sample_rate=sample_rate, language_with_script=language_with_script)
            tokens = transcription_text.split() if transcription_text else []

            return {
                "transcription": transcription_text,
                "tokens": tokens,
                "aligned_segments": [],
                "total_duration": 0.0,
                "alignment_error": str(e),
            }
        except Exception as e2:
            logger.error(f"Error in fallback transcription: {str(e2)}", exc_info=True)
            return {
                "transcription": "",
                "tokens": [],
                "aligned_segments": [],
                "total_duration": 0.0,
                "error": str(e2),
            }


def _validate_and_adjust_segments(
    aligned_segments: List[Dict],
    chunk_start_time: float,
    chunk_audio_tensor: torch.Tensor,
    chunk_sample_rate: int,
    chunk_duration: float,
    chunk_index: int
) -> List[Dict]:
    """
    Private helper function to validate and adjust segment timestamps to global timeline.

    Args:
        aligned_segments: Raw segments from forced alignment (local chunk timeline)
        chunk_start_time: Start time of this chunk in global timeline
        chunk_audio_tensor: Audio tensor for this chunk (to get actual duration)
        chunk_sample_rate: Sample rate of the chunk
        chunk_duration: Reported duration of the chunk
        chunk_index: Index of this chunk for debugging

    Returns:
        List of validated segments with global timeline timestamps
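
    Example (illustrative): a segment spanning 0.50-0.90s in a chunk that starts
    at 30.00s globally becomes start=30.50, end=30.90 on the global timeline.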
    """
    adjusted_segments = []

    # Get the actual audio duration from the chunk tensor instead of the potentially incorrect chunk duration
    actual_chunk_duration = len(chunk_audio_tensor) / chunk_sample_rate if len(chunk_audio_tensor) > 0 else chunk_duration

    for segment in aligned_segments:
        original_start = segment["start"]
        original_end = segment["end"]

        # Validate that segment timestamps are within chunk boundaries
        if original_start < 0:
            logger.warning(
                f"Segment '{segment['text']}' has negative start time {original_start:.3f}s, clipping to 0"
            )
            original_start = 0

        if original_end > actual_chunk_duration + 1.0:  # Allow 1s buffer for alignment errors
            logger.warning(
                f"Segment '{segment['text']}' end time {original_end:.3f}s exceeds actual chunk duration {actual_chunk_duration:.3f}s, clipping"
            )
            original_end = actual_chunk_duration

        if original_start >= original_end:
            logger.warning(
                f"Segment '{segment['text']}' has invalid timing {original_start:.3f}s-{original_end:.3f}s, using fallback"
            )
            # Use proportional timing based on segment position using actual chunk duration
            segment_index = len(adjusted_segments)
            total_segments = len(aligned_segments)
            if total_segments > 0:
                segment_proportion = segment_index / total_segments
                next_proportion = (segment_index + 1) / total_segments
                original_start = segment_proportion * actual_chunk_duration
                original_end = next_proportion * actual_chunk_duration
            else:
                original_start = 0
                original_end = 0.5

        # Create segment with absolute timeline
        adjusted_segment = {
            "text": segment["text"],
            "start": original_start + chunk_start_time,  # Global timeline
            "end": original_end + chunk_start_time,    # Global timeline
            "duration": original_end - original_start,
            "chunk_index": chunk_index,
            "original_start": original_start,  # Local chunk time
            "original_end": original_end,     # Local chunk time
        }

        adjusted_segments.append(adjusted_segment)

        logger.debug(
            f"Segment '{segment['text']}': {original_start:.3f}-{original_end:.3f} -> {adjusted_segment['start']:.3f}-{adjusted_segment['end']:.3f}"
        )

    logger.info(
        f"Adjusted {len(adjusted_segments)} segments to absolute timeline (chunk starts at {chunk_start_time:.2f}s)"
    )

    return adjusted_segments


def transcribe_full_audio_with_chunking(
    audio_tensor: torch.Tensor,
    sample_rate: int = 16000,
    chunk_duration: float = 30.0,
    language_with_script: Optional[str] = None,
    progress_callback=None,
) -> Dict:
    """
    Complete audio transcription pipeline that handles any length audio with intelligent chunking.
    This is the full-featured transcription function that can process both short and long audio files.

    Chunking mode is controlled by USE_CHUNKING environment variable:
    - USE_CHUNKING=false: No chunking (single chunk mode)
    - USE_CHUNKING=true (default): VAD-based intelligent chunking

    Args:
        audio_tensor: Audio tensor (1D waveform)
        sample_rate: Sample rate of the audio tensor
        chunk_duration: Target chunk duration in seconds (for static chunking)
        language_with_script: {Language code}_{script} for transcription
        progress_callback: Optional callback for progress updates (currently unused; progress is reported via transcription_status)

    Returns:
        Dict with full transcription and segment information including word-level timestamps
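
    Example (hypothetical usage; assumes `waveform` is a 1-D 16 kHz tensor):
        >>> result = transcribe_full_audio_with_chunking(waveform, sample_rate=16000)
        >>> print(result["transcription"], result["num_chunks"])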
    """

    try:
        logger.info(f"Starting long-form transcription: tensor shape {audio_tensor.shape} at {sample_rate}Hz")
        logger.info(f"USE_CHUNKING = {USE_CHUNKING}")

        # Initialize chunker
        chunker = AudioChunker()

        # Determine chunking mode based on USE_CHUNKING setting
        chunking_mode = "vad" if USE_CHUNKING else "none"

        # Chunk the audio using the unified chunker interface
        # Ensure tensor is 1D before chunking (squeeze any extra dimensions)
        if len(audio_tensor.shape) > 1:
            logger.info(f"Squeezing audio tensor from {audio_tensor.shape} to 1D")
            audio_tensor_1d = audio_tensor.squeeze()
        else:
            audio_tensor_1d = audio_tensor

        chunks = chunker.chunk_audio(audio_tensor_1d, sample_rate=sample_rate, mode=chunking_mode, chunk_duration=chunk_duration)

        if not chunks:
            logger.warning("No audio chunks created")
            return {
                "transcription": "",
                "chunks": [],
                "total_duration": 0.0,
                "error": "No audio content detected",
            }

        logger.info(f"Processing {len(chunks)} audio chunks (mode: {chunking_mode})")

        # Validate chunk continuity
        for i, chunk in enumerate(chunks):
            logger.info(
                f"Chunk {i+1}: {chunk['start_time']:.2f}s - {chunk['end_time']:.2f}s ({chunk['duration']:.2f}s)"
            )
            if i > 0:
                prev_end = chunks[i - 1]["end_time"]
                current_start = chunk["start_time"]
                gap = current_start - prev_end
                if abs(gap) > 0.1:  # More than 100ms gap/overlap
                    logger.warning(
                        f"Gap/overlap between chunks {i} and {i+1}: {gap:.3f}s"
                    )

        # Process each chunk; all chunks share a uniform format
        all_segments = []
        full_transcription_parts = []
        total_duration = 0.0
        chunk_details = []

        for i, chunk in enumerate(chunks):
            logger.info(
                f"Processing chunk {i+1}/{len(chunks)} ({chunk['duration']:.1f}s, {chunk['start_time']:.1f}s-{chunk['end_time']:.1f}s)"
            )

            try:
                # Process this chunk with the tensor-based transcription pipeline,
                # operating on the chunk's audio_data tensor directly (no file I/O)
                chunk_audio_tensor = chunk["audio_data"]
                chunk_sample_rate = chunk["sample_rate"]

                chunk_result = transcribe_with_word_alignment(
                    audio_tensor=chunk_audio_tensor,
                    sample_rate=chunk_sample_rate,
                    language_with_script=language_with_script
                )

                # Process alignment results - uniform handling for all chunk types
                chunk_segments = []
                chunk_start_time = chunk["start_time"]
                chunk_duration = chunk["duration"]

                if chunk_result.get("aligned_segments"):
                    logger.info(
                        f"Chunk {i+1} has {len(chunk_result['aligned_segments'])} segments"
                    )

                    chunk_segments = _validate_and_adjust_segments(
                        aligned_segments=chunk_result["aligned_segments"],
                        chunk_start_time=chunk_start_time,
                        chunk_audio_tensor=chunk_audio_tensor,
                        chunk_sample_rate=chunk_sample_rate,
                        chunk_duration=chunk_duration,
                        chunk_index=i
                    )

                all_segments.extend(chunk_segments)
                logger.info(f"Chunk {i+1} processed {len(chunk_segments)} valid segments")

                # Add to full transcription
                chunk_transcription = ""
                if chunk_result.get("transcription"):
                    chunk_transcription = chunk_result["transcription"]
                    full_transcription_parts.append(chunk_transcription)

                # Store detailed chunk information
                chunk_detail = {
                    "chunk_index": i,
                    "start_time": chunk["start_time"],
                    "end_time": chunk["end_time"],
                    "duration": chunk["duration"],
                    "transcription": chunk_transcription,
                    "num_segments": len(chunk_segments),
                    "segments": chunk_segments,
                }
                chunk_details.append(chunk_detail)

                total_duration = max(total_duration, chunk["end_time"])

                # Update progress linearly from 0.1 to 0.9 based on chunk processing
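                # (e.g. with 4 chunks, finishing chunk 2 gives 0.1 + 0.8 * 2/4 = 0.5)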
                progress = 0.1 + (0.8 * (i + 1) / len(chunks))
                transcription_status.update_progress(progress)

                logger.info(
                    f"Chunk {i+1} processed: '{chunk_transcription}' ({len(chunk_segments)} segments)"
                )

            except Exception as chunk_error:
                logger.error(f"Error processing chunk {i+1}: {chunk_error}")
                # Continue with next chunk

        # Combine results
        full_transcription = " ".join(full_transcription_parts)

        # Validate segment continuity
        logger.info("Validating segment continuity...")
        for i in range(1, len(all_segments)):
            prev_end = all_segments[i - 1]["end"]
            current_start = all_segments[i]["start"]
            gap = current_start - prev_end
            if abs(gap) > 1.0:  # More than 1 second gap
                logger.warning(f"Large gap between segments {i-1} and {i}: {gap:.3f}s")

        result = {
            "transcription": full_transcription,
            "aligned_segments": all_segments,
            "chunks": [
                {
                    "chunk_index": chunk_detail["chunk_index"],
                    "start_time": chunk_detail["start_time"],
                    "end_time": chunk_detail["end_time"],
                    "duration": chunk_detail["duration"],
                    "transcription": chunk_detail["transcription"],
                    "num_segments": chunk_detail["num_segments"],
                }
                for chunk_detail in chunk_details
            ],
            "chunk_details": chunk_details,  # Full details including segments per chunk
            "total_duration": total_duration,
            "num_chunks": len(chunks),
            "num_segments": len(all_segments),
            "status": "success",
        }

        logger.info(
            f"Long-form transcription completed: {len(chunks)} chunks, {total_duration:.1f}s total"
        )
        logger.info(f"Total segments: {len(all_segments)}")

        # Log chunk timing summary
        for chunk_detail in chunk_details:
            logger.info(
                f"Chunk {chunk_detail['chunk_index']}: {chunk_detail['start_time']:.2f}-{chunk_detail['end_time']:.2f}s, {chunk_detail['num_segments']} segments"
            )

        return result

    except Exception as e:
        logger.error(f"Error in long-form transcription: {str(e)}", exc_info=True)
        return {
            "transcription": "",
            "chunks": [],
            "total_duration": 0.0,
            "error": str(e),
        }
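

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: torchaudio is available, and
    # "example.wav" is a hypothetical 16 kHz mono input file).
    import torchaudio

    logging.basicConfig(level=logging.INFO)

    waveform, sr = torchaudio.load("example.wav")  # hypothetical input path
    result = transcribe_full_audio_with_chunking(waveform.squeeze(), sample_rate=sr)

    print(result.get("transcription", ""))
    for seg in result.get("aligned_segments", [])[:10]:
        print(f"{seg['start']:.2f}-{seg['end']:.2f}s  {seg['text']}")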