# ticket-ocr / app.py
# Last change by Rbpppp: "Update app.py" (commit 332e828, verified)
import gradio as gr
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
import torch
from PIL import Image
import spaces
import json
import re
# Load the vision-language model and its processor once, at import time, so
# every request reuses the same weights instead of reloading them.
MODEL_ID = "Qwen/Qwen3-VL-2B-Instruct"

model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    # ``dtype`` replaces the deprecated ``torch_dtype`` keyword in recent
    # transformers releases (which are required for Qwen3-VL anyway).
    dtype=torch.bfloat16,
    # Let transformers decide device placement for the weights.
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Spanish instruction text sent with every image: analyze_ticket appends
# "\n\nAnaliza este ticket." to it and places it in the user turn next to the
# image. It asks the model to reply with nothing but a JSON object matching the
# schema below (null for missing fields, numbers as JSON numbers).
# NOTE(review): despite the name, this is sent inside the "user" message, not as
# a "system" role message. Also, "merchant" is typed as a non-null string while
# the rules say missing data should be null -- confirm which the consumer expects.
SYSTEM_PROMPT = """Eres un asistente que recibe la imagen de un ticket de compra y responde SOLO con JSON válido.
Esquema requerido:
{
"merchant": string,
"date": string | null,
"time": string | null,
"currency": string | null,
"subtotal": number | null,
"tax": number | null,
"total": number | null,
"paymentMethod": string | null,
"category": string | null,
"items": [
{ "name": string, "quantity": number | null, "unitPrice": number | null, "total": number | null }
]
}
Reglas:
- No inventes valores: si falta un dato, usa null.
- Los números deben ser numéricos, no strings.
- La salida debe ser SOLO ese JSON, sin texto extra ni bloques de código."""
# Markdown code fences the model may still wrap its JSON in despite the prompt.
# Compiled once at import time rather than on every request.
_FENCE_OPEN_RE = re.compile(r"^```(?:json)?\s*", re.IGNORECASE)
_FENCE_CLOSE_RE = re.compile(r"\s*```$")


def _build_messages(image):
    """Return a single-turn chat pairing *image* with the extraction prompt."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": SYSTEM_PROMPT + "\n\nAnaliza este ticket."},
            ],
        }
    ]


def _parse_model_json(text):
    """Parse the model's reply as JSON.

    Surrounding whitespace and an optional Markdown code fence (```` ``` ```` or
    ```` ```json ````, any case) are stripped first. If the remainder is still
    not valid JSON, the span from the first ``{`` to the last ``}`` is tried, so
    stray prose before or after the object does not discard the extraction.

    Raises:
        json.JSONDecodeError: if no parseable JSON can be recovered.
    """
    cleaned = _FENCE_OPEN_RE.sub("", text.strip())
    cleaned = _FENCE_CLOSE_RE.sub("", cleaned)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            # Nothing object-shaped to salvage; surface the original error.
            raise
    return json.loads(cleaned[start : end + 1])


@spaces.GPU(duration=120)
def analyze_ticket(image):
    """Extract structured purchase-ticket data from an image.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type="pil")`` input,
            or None when the user submitted nothing.

    Returns:
        The JSON value parsed from the model's reply, ``{"error": ...}`` when
        no image was given, or ``{"raw_response": <text>, "parse_error": True}``
        when the reply could not be parsed as JSON.
    """
    if image is None:
        return {"error": "No se proporcionó imagen"}

    inputs = processor.apply_chat_template(
        _build_messages(image),
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # do_sample=False disables random sampling so the model picks the most
    # likely token at each step.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # generate() returns prompt + completion; keep only the newly generated
    # tokens. The batch holds exactly one (unpadded) conversation.
    prompt_len = inputs["input_ids"].shape[1]
    output_text = processor.batch_decode(
        generated_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print(f"Respuesta del modelo: {output_text[:500]}")

    try:
        return _parse_model_json(output_text)
    except json.JSONDecodeError as e:
        print(f"Error parseando JSON: {e}")
        return {"raw_response": output_text, "parse_error": True}
def _build_demo():
    """Assemble the one-input Gradio UI around analyze_ticket.

    The interface's API route is explicitly named ``predict``.
    """
    ticket_image = gr.Image(type="pil", label="Imagen del ticket")
    extracted_data = gr.JSON(label="Datos extraídos")
    return gr.Interface(
        fn=analyze_ticket,
        inputs=ticket_image,
        outputs=extracted_data,
        title="🧾 Ticket OCR",
        description="Sube una imagen de un ticket para extraer la información estructurada.",
        api_name="predict",
    )


# Module-level app object; started only when this file is run as a script.
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()