# Hugging Face Space: receipt ("ticket") OCR demo using Qwen3-VL-2B-Instruct.
| import gradio as gr | |
| from transformers import Qwen3VLForConditionalGeneration, AutoProcessor | |
| import torch | |
| from PIL import Image | |
| import spaces | |
| import json | |
| import re | |
# Load the vision-language model and its matching processor once at startup.
MODEL_ID = "Qwen/Qwen3-VL-2B-Instruct"

model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# System prompt kept in Spanish on purpose — it is runtime model input, not a
# comment. It instructs the model to answer with ONLY valid JSON matching the
# receipt schema below (missing fields -> null, numbers as numbers, no fences).
SYSTEM_PROMPT = """Eres un asistente que recibe la imagen de un ticket de compra y responde SOLO con JSON válido.
Esquema requerido:
{
"merchant": string,
"date": string | null,
"time": string | null,
"currency": string | null,
"subtotal": number | null,
"tax": number | null,
"total": number | null,
"paymentMethod": string | null,
"category": string | null,
"items": [
{ "name": string, "quantity": number | null, "unitPrice": number | null, "total": number | null }
]
}
Reglas:
- No inventes valores: si falta un dato, usa null.
- Los números deben ser numéricos, no strings.
- La salida debe ser SOLO ese JSON, sin texto extra ni bloques de código."""
@spaces.GPU  # required on ZeroGPU Spaces; `import spaces` was previously unused
def analyze_ticket(image):
    """Analyze a receipt image and return its data as a structured dict.

    Args:
        image: PIL image of the receipt, or None when nothing was submitted.

    Returns:
        dict: the parsed receipt fields on success;
        {"error": ...} when no image was provided;
        {"raw_response": ..., "parse_error": True} when the model output
        could not be parsed as JSON.
    """
    if image is None:
        return {"error": "No se proporcionó imagen"}

    # Single user turn: the image plus the JSON-only instruction (in Spanish).
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": SYSTEM_PROMPT + "\n\nAnaliza este ticket."},
            ],
        }
    ]

    # Tokenize the chat template and image features in one pass.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Greedy decoding: deterministic output is preferable for data extraction.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(f"Respuesta del modelo: {output_text[:500]}")

    return _parse_json_response(output_text)


def _parse_json_response(output_text):
    """Strip optional Markdown code fences and parse model output as JSON."""
    cleaned = output_text.strip()
    cleaned = re.sub(r'^```(?:json)?\s*', '', cleaned)
    cleaned = re.sub(r'\s*```$', '', cleaned)
    # Robustness: if the model wrapped the JSON in prose despite instructions,
    # isolate the outermost {...} object before parsing.
    if not cleaned.startswith("{"):
        match = re.search(r'\{.*\}', cleaned, re.DOTALL)
        if match:
            cleaned = match.group(0)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as e:
        print(f"Error parseando JSON: {e}")
        return {"raw_response": output_text, "parse_error": True}
# Gradio UI: one image input, JSON output, exposed as the /predict endpoint.
_image_input = gr.Image(type="pil", label="Imagen del ticket")
_json_output = gr.JSON(label="Datos extraídos")

demo = gr.Interface(
    fn=analyze_ticket,
    inputs=_image_input,
    outputs=_json_output,
    title="🧾 Ticket OCR",
    description="Sube una imagen de un ticket para extraer la información estructurada.",
    api_name="predict",
)
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()