ACloudCenter committed
Commit b4f488c · 1 Parent(s): bf79fdd

Fix: Mismatch between qa fn and call

Files changed (1): app.py (+15 -18)
app.py CHANGED
@@ -51,7 +51,7 @@ def transcribe_audio(audio_filepath):
            prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]],
            audios=audio.to(device),
            audio_lens=torch.tensor([audio_lens]).to(device),
-            max_new_tokens=256,
+            max_new_tokens=2048,
        )

    # Convert output IDs to text
@@ -60,7 +60,7 @@ def transcribe_audio(audio_filepath):
    return transcript, transcript, initial_message


-# Simple Q&A function - adapted from working version
+# Simple Q&A function
@spaces.GPU
def transcript_qa(transcript, question, history):
    if not transcript:
@@ -69,33 +69,27 @@ def transcript_qa(transcript, question, history):
    if not question:
        return history, ""

-    # Add user message to history first
+    # Add user message to history
    history = history + [{"role": "user", "content": question}]

    with torch.inference_mode(), model.llm.disable_adapter():
        output_ids = model.generate(
            prompts=[[{"role": "user", "content": f"{question}\n\n{transcript}"}]],
-            max_new_tokens=256,
+            max_new_tokens=2048,
        )

-    # Convert output IDs to text and extract answer
-    answer = model.tokenizer.ids_to_text(output_ids[0].cpu())
-    answer = answer.split("<|im_start|>assistant")[-1]
+    ans = model.tokenizer.ids_to_text(output_ids[0].cpu())
+    ans = ans.split("<|im_start|>assistant")[-1]  # get rid of the prompt

-    # Remove thinking tags if present
-    if "<think>" in answer:
-        if "</think>" in answer:
-            parts = answer.split("</think>")
-            if len(parts) > 1:
-                answer = parts[-1]  # Get text after thinking
-        else:
-            # If no closing tag, try to get text after opening tag
-            answer = answer.split("<think>")[0]  # Get text before thinking
+    if "<think>" in ans:
+        if "</think>" in ans:
+            ans = ans.split("<think>")[-1]
+            _, ans = ans.split("</think>")  # get rid of the thinking

-    answer = answer.strip()
+    ans = ans.strip()

    # Add assistant response to history
-    history = history + [{"role": "assistant", "content": answer}]
+    history = history + [{"role": "assistant", "content": ans}]

    return history, ""  # Return updated history and clear input

@@ -140,6 +134,9 @@ with gr.Blocks(theme=theme) as demo:
        bubble_full_width=False
    )

+    def user(user_message, history: list):
+        return "", history + [{"role": "user", "content": user_message}]
+
    with gr.Row():
        question_input = gr.Textbox(
            label="",