File size: 11,652 Bytes
962d22d
 
 
 
 
c801df7
6b8b7fc
b2c5cb8
962d22d
6b8b7fc
b2c5cb8
 
6b8b7fc
b2c5cb8
6b8b7fc
b2c5cb8
 
6b8b7fc
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
962d22d
 
 
6b8b7fc
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
6b8b7fc
962d22d
 
b2c5cb8
6b8b7fc
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
6b8b7fc
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
6b8b7fc
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
acb34dc
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
161e0b2
b2c5cb8
 
 
6b8b7fc
b2c5cb8
 
 
6b8b7fc
962d22d
 
 
6b8b7fc
 
 
 
 
 
 
 
 
b2c5cb8
6b8b7fc
 
 
962d22d
 
 
6b8b7fc
e71abcc
962d22d
 
 
b2c5cb8
962d22d
6b8b7fc
962d22d
 
 
6b8b7fc
962d22d
6b8b7fc
 
962d22d
6b8b7fc
 
962d22d
b2c5cb8
962d22d
 
 
b2c5cb8
962d22d
6b8b7fc
962d22d
 
 
6b8b7fc
962d22d
b2c5cb8
 
 
 
6b8b7fc
 
 
 
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
962d22d
 
 
 
 
 
 
b2c5cb8
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
 
6b8b7fc
 
 
b2c5cb8
962d22d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b8b7fc
 
 
 
962d22d
 
6b8b7fc
b2c5cb8
6b8b7fc
 
 
962d22d
 
b2c5cb8
6b8b7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c5cb8
6b8b7fc
 
 
e71abcc
962d22d
 
 
 
 
 
 
 
 
6b8b7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""
OCR Application with Multiple Models including DeepSeek OCR
Fixed version with @spaces.GPU decorator for Hugging Face Spaces
"""

import os
import time
import torch
import spaces
from threading import Thread
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer
)
from qwen_vl_utils import process_vision_info

# Try importing Qwen3VL if available
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    Qwen3VLForConditionalGeneration = None

# Generation limits; the "Max New Tokens" slider in the UI uses these as its
# upper bound and default value.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
# Overridable via the environment.
# NOTE(review): not referenced by the code visible here — confirm it is used.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Device detected at import time; generate_image re-resolves its own device
# on every call because CUDA may only appear inside the @spaces.GPU context.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print("Initial Device:", device)
print("CUDA Available:", torch.cuda.is_available())

# Load Chandra-OCR. It needs Qwen3VLForConditionalGeneration, which the
# import guard above sets to None on transformers builds that lack it.
MODEL_ID_V = "datalab-to/chandra"
if Qwen3VLForConditionalGeneration is None:
    # Without the model class the processor is useless, so skip downloading it
    # and keep processor/model consistently unset.
    model_v = None
    processor_v = None
    print("✗ Chandra-OCR: Qwen3VL not available")
else:
    try:
        processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
        model_v = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID_V,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).eval()
        print("✓ Chandra-OCR loaded")
    except Exception as e:
        # Best-effort load: a failure only disables this model in the UI.
        model_v = None
        processor_v = None
        print(f"✗ Chandra-OCR: Failed to load - {e}")

# Load Nanonets-OCR2-3B through the Qwen2.5-VL model class. Any failure is
# reported and leaves both the processor and model set to None.
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
try:
    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
    model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_X,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).eval()
except Exception as load_err:
    processor_x = model_x = None
    print(f"✗ Nanonets-OCR2-3B: Failed to load - {load_err}")
else:
    print("✓ Nanonets-OCR2-3B loaded")

# Load Dots.OCR (bfloat16, remote code). generate_image moves it onto the
# active device when it is selected.
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
try:
    processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
    model_d = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH_D,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        # Presumably needs the flash_attn package at load time; if loading
        # raises for any reason, Dots.OCR is reported as unavailable below.
        attn_implementation="flash_attention_2",
    ).eval()
except Exception as load_err:
    processor_d = model_d = None
    print(f"✗ Dots.OCR: Failed to load - {load_err}")
else:
    print("✓ Dots.OCR loaded")

# Load olmOCR-2-7B-1025 through the Qwen2.5-VL model class (float16).
# On failure both handles are cleared so the model is hidden from the UI.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
try:
    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
    model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_M,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).eval()
except Exception as load_err:
    processor_m = model_m = None
    print(f"✗ olmOCR-2-7B-1025: Failed to load - {load_err}")
else:
    print("✓ olmOCR-2-7B-1025 loaded")

# Load DeepSeek-OCR. Any exception leaves both model_ds and processor_ds as
# None, which hides the model from the UI and makes generate_image report it
# as unavailable.
# NOTE(review): this checkpoint is loaded with Qwen2_5_VLForConditionalGeneration,
# the same class used for the Qwen2.5-VL-based models above. DeepSeek-OCR's
# published usage appears to be AutoModel/AutoTokenizer with
# trust_remote_code plus a custom `infer()` method, not the
# processor/chat-template + generate() path used in generate_image.
# Confirm the architecture: if it is not Qwen2.5-VL, this load either fails
# (caught below) or produces a mis-initialized model. A proper fix likely
# needs a dedicated code path in generate_image as well.
try:
    MODEL_ID_DS = "deepseek-ai/deepseek-ocr"
    processor_ds = AutoProcessor.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
    model_ds = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_DS,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).eval()
    print("✓ DeepSeek-OCR loaded")
except Exception as e:
    model_ds = None
    processor_ds = None
    print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float):
    """
    Stream an OCR response for an image using the selected model.

    Decorated with @spaces.GPU so that, on Hugging Face ZeroGPU Spaces, a GPU
    is attached for the duration of the call; otherwise it runs on CPU.

    Args:
        model_name: Display name of the OCR model (e.g. "olmOCR-2-7B-1025").
        text: Prompt text sent alongside the image.
        image: PIL image to process; None yields an upload reminder.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.
        top_k: Top-k sampling parameter.
        repetition_penalty: Penalty for repeating tokens.

    Yields:
        tuple[str, str]: (raw_text, markdown_text). Both carry the same
        accumulated text, one for the Textbox and one for the Markdown view.
        Status and error messages are yielded in the same shape.
    """
    # Resolve the device per call: under @spaces.GPU, CUDA only becomes
    # available while the decorated function is running.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Built at call time so it reflects the module-level load results.
    registry = {
        "olmOCR-2-7B-1025": (processor_m, model_m),
        "Nanonets-OCR2-3B": (processor_x, model_x),
        "Chandra-OCR": (processor_v, model_v),
        "Dots.OCR": (processor_d, model_d),
        "DeepSeek-OCR": (processor_ds, model_ds),
    }
    if model_name not in registry:
        yield "Invalid model selected.", "Invalid model selected."
        return

    processor, model = registry[model_name]
    if model is None:
        unavailable_msg = f"{model_name} is not available."
        yield unavailable_msg, unavailable_msg
        return

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    try:
        # Inside the try so device/OOM failures surface as an error message.
        model = model.to(device)

        # Single-turn chat: one image placeholder followed by the prompt text.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ]
        }]

        prompt_full = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True
        ).to(device)

        streamer = TextIteratorStreamer(
            processor,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # UI sliders may deliver floats; generate() requires an int top_k and
        # max_new_tokens, so normalize the numeric types explicitly.
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
        }

        generation_errors = []

        def _run_generation():
            # If generate() raises in this worker thread, the streamer never
            # receives its stop signal and the consumer loop below would block
            # forever. Record the error and close the stream explicitly.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()

        # Stream partial results as they arrive.
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            # Strip a chat end-of-turn marker that may survive decoding.
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer, buffer

        thread.join()

        # Re-raise the worker's failure so it is reported by the handler below.
        if generation_errors:
            raise generation_errors[0]

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        yield error_msg, error_msg


# Example usage for Gradio interface
if __name__ == "__main__":
    import sys

    import gradio as gr

    # Offer only the models whose weights loaded successfully at import time.
    _candidate_models = [
        ("olmOCR-2-7B-1025", model_m),
        ("Nanonets-OCR2-3B", model_x),
        ("Chandra-OCR", model_v),
        ("Dots.OCR", model_d),
        ("DeepSeek-OCR", model_ds),
    ]
    available_models = [
        name for name, loaded in _candidate_models if loaded is not None
    ]

    if not available_models:
        print("ERROR: No models were loaded successfully!")
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive interpreter and may be absent (e.g. python -S).
        sys.exit(1)

    print(f"\n✓ Available models: {', '.join(available_models)}")

    with gr.Blocks(title="Multi-Model OCR") as demo:
        gr.Markdown("# 🔍 Multi-Model OCR Application")
        gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")

        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=available_models,
                    # Non-empty: the script exits above when nothing loaded.
                    value=available_models[0],
                    label="Select OCR Model"
                )
                image_input = gr.Image(type="pil", label="Upload Image")
                text_input = gr.Textbox(
                    value="Extract all text from this image.",
                    label="Prompt",
                    lines=2
                )

                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=1,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty"
                    )

                submit_btn = gr.Button("Extract Text", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(label="Extracted Text", lines=20)
                output_markdown = gr.Markdown(label="Formatted Output")

        gr.Markdown("""
        ### Available Models:
        - **olmOCR-2-7B-1025**: Allen AI's OCR model
        - **Nanonets-OCR2-3B**: Nanonets OCR model
        - **Chandra-OCR**: Datalab OCR model
        - **Dots.OCR**: Stranger Vision OCR model
        - **DeepSeek-OCR**: DeepSeek AI's OCR model
        """)

        # generate_image is a generator, so both outputs update as tokens stream.
        submit_btn.click(
            fn=generate_image,
            inputs=[
                model_selector,
                text_input,
                image_input,
                max_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty
            ],
            outputs=[output_text, output_markdown]
        )

    demo.launch()