X-iZhang committed · verified
Commit 5b58e8c · 1 Parent(s): ad638e7

Upload app.py

Browse files
Files changed (1)
  1. app.py +317 -0
app.py ADDED
@@ -0,0 +1,317 @@
+ import os
+ import torch
+ import gradio as gr
+ import time
+ from ccd import ccd_eval, run_eval
+ from libra.eval.run_libra import load_model
+ 
+ 
+ # =========================================
+ # Global Configuration
+ # =========================================
+ MODEL_CATALOGUE = {
+     "Libra-v1.0-7B": "X-iZhang/libra-v1.0-7b",
+     "Libra-v1.0-3B": "X-iZhang/libra-v1.0-3b",
+     "MAIRA-2": "X-iZhang/libra-maira-2",
+     "LLaVA-Med-v1.5": "X-iZhang/libra-llava-med-v1.5-mistral-7b",
+     "LLaVA-Rad": "X-iZhang/libra-llava-rad",
+     "Med-CXRGen-F": "X-iZhang/Med-CXRGen-F",
+     "Med-CXRGen-I": "X-iZhang/Med-CXRGen-I"
+ }
+ DEFAULT_MODEL_NAME = "MAIRA-2"
+ _loaded_models = {}
+ 
+ 
+ # =========================================
+ # Environment Setup
+ # =========================================
+ def setup_environment():
+     if torch.cuda.is_available():
+         print("🔹 Using GPU:", torch.cuda.get_device_name(0))
+     else:
+         print("🔹 Using CPU")
+     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+     os.environ['TRANSFORMERS_CACHE'] = './cache'
+     torch.set_num_threads(4)
+ 
+ 
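+ # Note: loaded models are cached in _loaded_models, keyed by HuggingFace path,
+ # so switching back to a previously used model in the UI does not reload weights.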
+ # =========================================
+ # Model Loader
+ # =========================================
+ def load_or_get_model(model_name: str):
+     """Load the model based on its display name."""
+     model_path = MODEL_CATALOGUE[model_name]
+     print(f"🔹 Model path resolved: {model_path}")
+     if model_path in _loaded_models:
+         print(f"🔹 Model already loaded: {model_name}")
+         return _loaded_models[model_path]
+ 
+     print(f"🔹 Loading model: {model_name} ({model_path}) ...")
+     try:
+         with torch.no_grad():
+             model = load_model(model_path)
+         _loaded_models[model_path] = model
+         print(f"✅ Loaded successfully: {model_name}")
+         return model
+     except Exception as e:
+         print(f"❌ Error loading model {model_name}: {e}")
+         raise
+ 
+ 
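+ # CCD (Clinical Contrastive Decoding) steers the base model's generation with
+ # signals from a radiology expert model (MedSigLip or DenseNet); alpha, beta,
+ # and gamma weight that guidance. See the linked paper for the exact formulation.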
+ # =========================================
+ # CCD Logic
+ # =========================================
+ def generate_ccd_description(
+     selected_model_name,
+     current_img,
+     prompt,
+     expert_model,
+     alpha,
+     beta,
+     gamma,
+     use_run_eval,
+     max_new_tokens
+ ):
+     """Generate findings using CCD evaluation."""
+     if not current_img:
+         return "⚠️ Please upload or select an example image first."
+ 
+     try:
+         print(f"🔹 Generating description with model: {selected_model_name}")
+         print(f"🔹 Parameters: alpha={alpha}, beta={beta}, gamma={gamma}")
+         print(f"🔹 Image path: {current_img}")
+ 
+         model = load_or_get_model(selected_model_name)
+         print(f"🔹 Running CCD with {selected_model_name} and expert model {expert_model}...")
+         ccd_output = ccd_eval(
+             libra_model=model,
+             image=current_img,
+             question=prompt,
+             max_new_tokens=max_new_tokens,
+             expert_model=expert_model,
+             alpha=alpha,
+             beta=beta,
+             gamma=gamma
+         )
+ 
+         if use_run_eval:
+             baseline_output = run_eval(
+                 libra_model=model,
+                 image=current_img,
+                 question=prompt,
+                 max_new_tokens=max_new_tokens,
+                 num_beams=1
+             )
+             return (
+                 f"### 🩺 CCD Result ({expert_model})\n{ccd_output}\n\n"
+                 f"---\n### ⚖️ Baseline (run_eval)\n{baseline_output[0]}"
+             )
+ 
+         return f"### 🩺 CCD Result ({expert_model})\n{ccd_output}"
+ 
+     except Exception:
+         import traceback, sys
+         error_msg = traceback.format_exc()
+         print("========== CCD ERROR LOG ==========", file=sys.stderr)
+         print(error_msg, file=sys.stderr)
+         print("===================================", file=sys.stderr)
+         return f"❌ Exception Trace:\n```\n{error_msg}\n```"
+ 
+ 
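+ # Defensive wrapper used as the actual Gradio callback: it logs every invocation
+ # and, on failure, returns the traceback to the UI instead of letting the
+ # exception escape.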
+ def safe_generate_ccd_description(
+     selected_model_name,
+     current_img,
+     prompt,
+     expert_model,
+     alpha,
+     beta,
+     gamma,
+     use_run_eval,
+     max_new_tokens
+ ):
+     """Wrapper around generate_ccd_description that logs inputs and prints the full traceback on error."""
+     import traceback, sys, time
+     print("\n=== Gradio callback invoked ===")
+     print(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+     print(f"selected_model_name={selected_model_name}")
+     print(f"current_img={current_img}")
+     print(f"prompt={prompt}")
+     print(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}")
+ 
+     try:
+         return generate_ccd_description(
+             selected_model_name,
+             current_img,
+             prompt,
+             expert_model,
+             alpha,
+             beta,
+             gamma,
+             use_run_eval,
+             max_new_tokens
+         )
+     except Exception:
+         err = traceback.format_exc()
+         print("========== GRADIO CALLBACK ERROR ==========", file=sys.stderr)
+         print(err, file=sys.stderr)
+         print("==========================================", file=sys.stderr)
+         # Also write the error and inputs to a persistent log file for easier inspection
+         try:
+             with open('/workspace/CCD/callback.log', 'a', encoding='utf-8') as f:
+                 f.write('\n=== CALLBACK LOG ENTRY ===\n')
+                 f.write(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
+                 f.write(f"selected_model_name={selected_model_name}\n")
+                 f.write(f"current_img={current_img}\n")
+                 f.write(f"prompt={prompt}\n")
+                 f.write(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}\n")
+                 f.write('TRACEBACK:\n')
+                 f.write(err + '\n')
+                 f.write('=== END ENTRY ===\n')
+         except Exception as fe:
+             print(f"Failed to write callback.log: {fe}", file=sys.stderr)
+         # Return a user-friendly error message to the UI, including the traceback
+         return f"❌ An internal error occurred. See server logs for details.\n\nTraceback:\n```\n{err}\n```"
+ 
+ 
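+ # Entry point: builds the Gradio Blocks UI, wires the controls to the safe
+ # callback above, and launches the server.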
+ # =========================================
+ # Main Application
+ # =========================================
+ def main():
+     setup_environment()
+ 
+     # Example image path
+     cur_dir = os.path.abspath(os.path.dirname(__file__))
+     example_path = os.path.abspath(os.path.join(cur_dir, "..", "assets", "example.jpg"))
+     example_exists = os.path.exists(example_path)
+ 
+     # Model reference table
+     model_table = """
+ | **Model Name** | **HuggingFace Link** |
+ |----------------|----------------------|
+ | **Libra-v1.0-7B** | [X-iZhang/libra-v1.0-7b](https://huggingface.co/X-iZhang/libra-v1.0-7b) |
+ | **Libra-v1.0-3B** | [X-iZhang/libra-v1.0-3b](https://huggingface.co/X-iZhang/libra-v1.0-3b) |
+ | **MAIRA-2** | [X-iZhang/libra-maira-2](https://huggingface.co/X-iZhang/libra-maira-2) |
+ | **LLaVA-Med-v1.5** | [X-iZhang/libra-llava-med-v1.5-mistral-7b](https://huggingface.co/X-iZhang/libra-llava-med-v1.5-mistral-7b) |
+ | **LLaVA-Rad** | [X-iZhang/libra-llava-rad](https://huggingface.co/X-iZhang/libra-llava-rad) |
+ | **Med-CXRGen-F** | [X-iZhang/Med-CXRGen-F](https://huggingface.co/X-iZhang/Med-CXRGen-F) |
+ | **Med-CXRGen-I** | [X-iZhang/Med-CXRGen-I](https://huggingface.co/X-iZhang/Med-CXRGen-I) |
+ """
+ 
+     with gr.Blocks(title="📷 Clinical Contrastive Decoding", theme="soft") as demo:
+         gr.Markdown("""
+         # 📷 CCD: Mitigating Hallucinations in Radiology MLLMs via Clinical Contrastive Decoding
+         ### [Project Page](https://x-izhang.github.io/CCD/) | [Paper](https://arxiv.org/abs/2509.23379) | [Code](https://github.com/X-iZhang/CCD) | [Models](https://huggingface.co/collections/X-iZhang/libra-6772bfccc6079298a0fa5f8d)
+         """)
+ 
+         with gr.Tab("✨ CCD Demo"):
+             with gr.Row():
+                 # -------- Left Column: Image --------
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Radiology Image (e.g., chest X-ray)")
+                     current_img = gr.Image(label="Radiology Image", type="filepath", interactive=True)
+                     if example_exists:
+                         gr.Examples(
+                             examples=[[example_path]],
+                             inputs=[current_img],
+                             label="Example Image"
+                         )
+                     else:
+                         gr.Markdown(f"⚠️ Example image not found at `{example_path}`")
+ 
+                 # -------- Right Column: Controls --------
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Model Selection & Prompt")
+                     selected_model_name = gr.Dropdown(
+                         label="Base Radiology MLLM",
+                         choices=list(MODEL_CATALOGUE.keys()),
+                         value=DEFAULT_MODEL_NAME
+                     )
+                     prompt = gr.Textbox(
+                         label="Question / Prompt",
+                         value="What are the findings in this chest X-ray?",
+                         lines=1
+                     )
+ 
+                     gr.Markdown("### CCD Parameters")
+                     expert_model = gr.Radio(
+                         label="Expert Model",
+                         choices=["MedSigLip", "DenseNet"],
+                         value="DenseNet"
+                     )
+ 
+                     # Notice for MedSigLip access requirements (hidden by default)
+                     medsiglip_message = (
+                         "**Note: the MedSigLip model requires authorization to access.**\n\n"
+                         "To use MedSigLip, deploy the Gradio web interface locally and complete the authentication steps.\n"
+                         "Deployment instructions: "
+                         "[Gradio Web Interface](https://github.com/X-iZhang/CCD#gradio-web-interface)"
+                     )
+                     medsiglip_notice = gr.Markdown(value="", visible=False)
+ 
+                     def _toggle_medsiglip_notice(choice):
+                         if choice == "MedSigLip":
+                             return gr.update(visible=True, value=medsiglip_message)
+                         else:
+                             return gr.update(visible=False, value="")
+ 
+                     # Show the notice whenever MedSigLip is selected
+                     expert_model.change(fn=_toggle_medsiglip_notice, inputs=[expert_model], outputs=[medsiglip_notice])
+ 
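+                     # CCD weighting sliders; the demo starts from alpha=0.5, beta=0.5, gamma=10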
+                     with gr.Row():
+                         alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Alpha")
+                         beta = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Beta")
+                         gamma = gr.Slider(0, 20, value=10, step=1, label="Gamma")
+ 
+                     with gr.Accordion("Advanced Options", open=False):
+                         max_new_tokens = gr.Slider(10, 256, value=128, step=1, label="Max New Tokens")
+                         use_run_eval = gr.Checkbox(label="Compare with baseline (run_eval)", value=False)
+ 
+                     generate_btn = gr.Button("🚀 Generate", variant="primary")
+ 
+             # -------- Output --------
+             # output = gr.Markdown(label="Output", value="### 📷 Results will appear here. 👇")
+             output = gr.Markdown(
+                 value='<h3 style="color:#007BFF;">📷 Results will appear here. 👇</h3>',
+                 label="Output"
+             )
+             # Route the click through the safe wrapper so failures are logged
+             generate_btn.click(
+                 fn=safe_generate_ccd_description,
+                 inputs=[
+                     selected_model_name, current_img, prompt,
+                     expert_model, alpha, beta, gamma,
+                     use_run_eval, max_new_tokens
+                 ],
+                 outputs=output
+             )
+ 
+         # -------- Model Table --------
+         # gr.Markdown("### 🧠 Supported Models")
+         # gr.Markdown(model_table)
+ 
+         gr.Markdown("""
+         ### Terms of Use
+         The service is a research preview intended for non-commercial use only, subject to the LLaMA model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md).
+ 
+         By accessing or using this demo, you acknowledge and agree to the following:
+         - **Research & Non-Commercial Purposes**: This demo is provided solely for research and demonstration. It must not be used for commercial activities or profit-driven endeavors.
+         - **Not Medical Advice**: All generated content is experimental and must not replace professional medical judgment.
+         - **Content Moderation**: While we apply basic safety checks, the system may still produce inaccurate or offensive outputs.
+         - **Responsible Use**: Do not use this demo for any illegal, harmful, hateful, violent, or sexual purposes.
+ 
+         By continuing to use this service, you confirm your acceptance of these terms. If you do not agree, please discontinue use immediately.
+         """)
+ 
+ 
+     # Log that Gradio is starting (helpful when stdout/stderr are captured)
+     try:
+         with open('/workspace/CCD/callback.log', 'a', encoding='utf-8') as f:
+             f.write(f"\n=== GRADIO START ===\nstarted_at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
+     except Exception:
+         pass
+ 
+     # Launch the Gradio app; share=True requests a public share link
+     demo.launch(share=True)
+ 
+ 
+ if __name__ == "__main__":
+     main()
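
For quick sanity checks outside the UI, the same pipeline can be driven headlessly. A minimal sketch, assuming the `ccd` and `libra` packages from this repo are installed; the image path is hypothetical:

```python
# Headless sketch of the pipeline app.py wraps (assumes the CCD repo's
# `ccd` and `libra` packages are installed; the image path is hypothetical).
from ccd import ccd_eval
from libra.eval.run_libra import load_model

model = load_model("X-iZhang/libra-maira-2")  # the demo's default model

report = ccd_eval(
    libra_model=model,
    image="assets/example.jpg",               # hypothetical local CXR image
    question="What are the findings in this chest X-ray?",
    max_new_tokens=128,
    expert_model="DenseNet",
    alpha=0.5,
    beta=0.5,
    gamma=10,
)
print(report)
```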