ACloudCenter committed
Commit 5bc92c5 · 1 Parent(s): e46b406

Fix pipeline error by using model.generate() directly

Files changed (1)
  1. app.py +76 -21
app.py CHANGED
@@ -1,30 +1,85 @@
  import gradio as gr
- from transformers import pipeline
  import torch
- from nemo.collections.speechlm2 import SALM
  import spaces
+ from lhotse import Recording
+ from nemo.collections.speechlm2 import SALM

- if torch.cuda.is_available():
-     device = torch.device("cuda")
- else:
-     device = torch.device("cpu")
-
- SAMPLE_RATE = 16000 # Hz - NVIDIA model sampling rate
- MAX_AUDIO_MINUTES = 120 # wont try to transcribe if longer than this
- CHUNK_SECONDS = 40.0 # max audio length seen by the model
- BATCH_SIZE = 192 # for parallel transcription of audio longer than CHUNK_SECONDS
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ SAMPLE_RATE = 16000

- # Initialize the ASR model which is based on the "nvidia/canary-qwen-2.5b" architecture and uses NVIDIA's NeMo framework
  model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)
- transcriber = pipeline("automatic-speech-recognition", model = model)

- # Transcribe audio file using NeMo's transcribe class and use spaces for GPU acceleration
  @spaces.GPU
- def transcribe_audio(audio_file):
-     transcript = transcriber([audio_file])[0].text
-     return transcript
+ def transcribe_audio(audio_filepath):
+     if audio_filepath is None:
+         return "Please upload an audio file", ""
+
+     rec = Recording.from_file(audio_filepath, recording_id="temp")
+     cut = rec.resample(SAMPLE_RATE).to_cut()
+     if cut.num_channels > 1:
+         cut = cut.to_mono(mono_downmix=True)
+
+     audio = cut.load_audio()  # float array of shape (1, num_samples)
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]],
+             audios=torch.as_tensor(audio).to(device),
+             audio_lens=torch.as_tensor([cut.num_samples]).to(device),
+             max_new_tokens=256,
+         )
+
+     transcript = model.tokenizer.ids_to_text(output_ids[0].cpu())
+     return transcript, transcript
+
+ @spaces.GPU
+ def answer_question(transcript, question):
+     if not transcript:
+         return "Please transcribe audio first"
+
+     with torch.inference_mode(), model.llm.disable_adapter():  # adapter off: text-only LLM mode
+         output_ids = model.generate(
+             prompts=[[{"role": "user", "content": f"{question}\n\n{transcript}"}]],
+             max_new_tokens=512,
+         )
+
+     answer = model.tokenizer.ids_to_text(output_ids[0].cpu())
+     answer = answer.split("<|im_start|>assistant")[-1]
+     return answer.strip()
+
+ with gr.Blocks(title="Canary-Qwen Transcriber & Q&A") as demo:
+     gr.Markdown("# Canary-Qwen Transcriber with Q&A")
+     gr.Markdown("Upload audio to transcribe, then ask questions about it!")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
+             transcribe_btn = gr.Button("Transcribe", variant="primary")
+
+         with gr.Column():
+             transcript_output = gr.Textbox(label="Transcript", lines=8)
+
+     transcript_state = gr.State()
+
+     with gr.Row():
+         with gr.Column():
+             question_input = gr.Textbox(label="Ask a question about the transcript", placeholder="What is the main topic?")
+             ask_btn = gr.Button("Ask", variant="primary")
+
+         with gr.Column():
+             answer_output = gr.Textbox(label="Answer", lines=4)
+
+     transcribe_btn.click(
+         fn=transcribe_audio,
+         inputs=[audio_input],
+         outputs=[transcript_output, transcript_state]
+     )
+
+     ask_btn.click(
+         fn=answer_question,
+         inputs=[transcript_state, question_input],
+         outputs=[answer_output]
+     )

- demo = gr.Interface(
-     fn=transcribe_audio,
-     inputs=gr.Audio(source="upload", type="filepath"),
-     outputs=gr.Textbox())
+ demo.queue()
+ demo.launch()
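
For a quick local check of calling model.generate() directly outside the Space, the usage example on the nvidia/canary-qwen-2.5b model card passes the audio file path inside the prompt via an "audio" key, letting the model handle loading instead of the manual lhotse path used in transcribe_audio above. A minimal sketch, assuming that published example still matches the installed NeMo version; the path speech.wav is a placeholder, not part of this commit:

import torch
from nemo.collections.speechlm2 import SALM

# Same model setup as app.py: bfloat16 inference, GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)

with torch.inference_mode():
    # "speech.wav" is a placeholder; the "audio" key follows the model card's
    # documented prompt format and lets the model load the file itself.
    output_ids = model.generate(
        prompts=[[{
            "role": "user",
            "content": f"Transcribe the following: {model.audio_locator_tag}",
            "audio": ["speech.wav"],
        }]],
        max_new_tokens=256,
    )

print(model.tokenizer.ids_to_text(output_ids[0].cpu()))

The Space's transcribe_audio keeps the tensor-based audios/audio_lens path so it can resample and downmix whatever Gradio records or uploads before inference.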