Rbpppp commited on
Commit
4ad7141
·
verified ·
1 Parent(s): ef66f3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -110
app.py CHANGED
@@ -1,110 +1,112 @@
1
- import gradio as gr
2
- import spaces
3
- import torch
4
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
- from qwen_vl_utils import process_vision_info
6
- import json
7
- import re
8
-
9
- # Cargar modelo y procesador (se hace una vez al iniciar)
10
- MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
11
-
12
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
13
- MODEL_ID,
14
- torch_dtype=torch.bfloat16,
15
- device_map="auto",
16
- )
17
- processor = AutoProcessor.from_pretrained(MODEL_ID)
18
-
19
- SYSTEM_PROMPT = """Eres un asistente que recibe la imagen de un ticket de compra y responde SOLO con JSON válido.
20
- Esquema requerido:
21
- {
22
- "merchant": string,
23
- "date": string | null,
24
- "time": string | null,
25
- "currency": string | null,
26
- "subtotal": number | null,
27
- "tax": number | null,
28
- "total": number | null,
29
- "paymentMethod": string | null,
30
- "category": string | null,
31
- "items": [
32
- { "name": string, "quantity": number | null, "unitPrice": number | null, "total": number | null }
33
- ]
34
- }
35
- Reglas:
36
- - No inventes valores: si falta un dato, usa null.
37
- - Los números deben ser numéricos, no strings.
38
- - La salida debe ser SOLO ese JSON, sin texto extra ni bloques de código."""
39
-
40
- @spaces.GPU
41
- def analyze_ticket(image):
42
- """Analiza una imagen de ticket y devuelve JSON estructurado."""
43
- if image is None:
44
- return {"error": "No se proporcionó imagen"}
45
-
46
- messages = [
47
- {
48
- "role": "system",
49
- "content": SYSTEM_PROMPT
50
- },
51
- {
52
- "role": "user",
53
- "content": [
54
- {"type": "image", "image": image},
55
- {"type": "text", "text": "Analiza este ticket y extrae la información en formato JSON."}
56
- ],
57
- }
58
- ]
59
-
60
- # Preparar inputs
61
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
62
- image_inputs, video_inputs = process_vision_info(messages)
63
- inputs = processor(
64
- text=[text],
65
- images=image_inputs,
66
- videos=video_inputs,
67
- padding=True,
68
- return_tensors="pt",
69
- ).to(model.device)
70
-
71
- # Generar respuesta
72
- generated_ids = model.generate(
73
- **inputs,
74
- max_new_tokens=1024,
75
- do_sample=False,
76
- )
77
-
78
- # Decodificar solo los tokens generados
79
- generated_ids_trimmed = [
80
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
81
- ]
82
- output_text = processor.batch_decode(
83
- generated_ids_trimmed,
84
- skip_special_tokens=True,
85
- clean_up_tokenization_spaces=False
86
- )[0]
87
-
88
- # Intentar parsear como JSON
89
- try:
90
- # Limpiar posibles bloques de código markdown
91
- cleaned = re.sub(r'^```(?:json)?\s*', '', output_text.strip())
92
- cleaned = re.sub(r'\s*```$', '', cleaned)
93
- result = json.loads(cleaned)
94
- return result
95
- except json.JSONDecodeError:
96
- # Si falla el parseo, devolver el texto crudo
97
- return {"raw_response": output_text, "parse_error": True}
98
-
99
- # Crear interfaz Gradio
100
- demo = gr.Interface(
101
- fn=analyze_ticket,
102
- inputs=gr.Image(type="pil", label="Imagen del ticket"),
103
- outputs=gr.JSON(label="Datos extraídos"),
104
- title="🧾 Ticket OCR",
105
- description="Sube una imagen de un ticket para extraer la información estructurada.",
106
- api_name="predict"
107
- )
108
-
109
- if __name__ == "__main__":
110
- demo.launch()
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
+ from PIL import Image
6
+ import json
7
+ import re
8
+
9
+ # Cargar modelo más pequeño (2B en lugar de 3B)
10
+ MODEL_ID = "Qwen/Qwen2.5-VL-2B-Instruct"
11
+
12
+ print("Cargando modelo...")
13
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
14
+ MODEL_ID,
15
+ torch_dtype=torch.float16,
16
+ device_map="auto",
17
+ trust_remote_code=True,
18
+ )
19
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
20
+ print("Modelo cargado!")
21
+
22
+ SYSTEM_PROMPT = """Eres un asistente que recibe la imagen de un ticket de compra y responde SOLO con JSON válido.
23
+ Esquema requerido:
24
+ {
25
+ "merchant": string,
26
+ "date": string | null,
27
+ "time": string | null,
28
+ "currency": string | null,
29
+ "subtotal": number | null,
30
+ "tax": number | null,
31
+ "total": number | null,
32
+ "paymentMethod": string | null,
33
+ "category": string | null,
34
+ "items": [
35
+ { "name": string, "quantity": number | null, "unitPrice": number | null, "total": number | null }
36
+ ]
37
+ }
38
+ Reglas:
39
+ - No inventes valores: si falta un dato, usa null.
40
+ - Los números deben ser numéricos, no strings.
41
+ - La salida debe ser SOLO ese JSON, sin texto extra ni bloques de código."""
42
+
43
+ @spaces.GPU
44
+ def analyze_ticket(image):
45
+ """Analiza una imagen de ticket y devuelve JSON estructurado."""
46
+ if image is None:
47
+ return {"error": "No se proporcionó imagen"}
48
+
49
+ # Construir mensajes para el modelo
50
+ messages = [
51
+ {
52
+ "role": "user",
53
+ "content": [
54
+ {"type": "image", "image": image},
55
+ {"type": "text", "text": SYSTEM_PROMPT + "\n\nAnaliza este ticket y extrae la información."}
56
+ ],
57
+ }
58
+ ]
59
+
60
+ # Preparar inputs
61
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
62
+
63
+ inputs = processor(
64
+ text=[text],
65
+ images=[image],
66
+ padding=True,
67
+ return_tensors="pt",
68
+ ).to(model.device)
69
+
70
+ # Generar respuesta
71
+ with torch.no_grad():
72
+ generated_ids = model.generate(
73
+ **inputs,
74
+ max_new_tokens=1024,
75
+ do_sample=False,
76
+ pad_token_id=processor.tokenizer.pad_token_id,
77
+ )
78
+
79
+ # Decodificar solo los tokens generados
80
+ generated_ids_trimmed = generated_ids[:, inputs.input_ids.shape[1]:]
81
+ output_text = processor.batch_decode(
82
+ generated_ids_trimmed,
83
+ skip_special_tokens=True,
84
+ clean_up_tokenization_spaces=False
85
+ )[0]
86
+
87
+ print(f"Respuesta del modelo: {output_text[:500]}")
88
+
89
+ # Intentar parsear como JSON
90
+ try:
91
+ # Limpiar posibles bloques de código markdown
92
+ cleaned = output_text.strip()
93
+ cleaned = re.sub(r'^```(?:json)?\s*', '', cleaned)
94
+ cleaned = re.sub(r'\s*```$', '', cleaned)
95
+ result = json.loads(cleaned)
96
+ return result
97
+ except json.JSONDecodeError as e:
98
+ print(f"Error parseando JSON: {e}")
99
+ return {"raw_response": output_text, "parse_error": True}
100
+
101
+ # Crear interfaz Gradio
102
+ demo = gr.Interface(
103
+ fn=analyze_ticket,
104
+ inputs=gr.Image(type="pil", label="Imagen del ticket"),
105
+ outputs=gr.JSON(label="Datos extraídos"),
106
+ title="🧾 Ticket OCR",
107
+ description="Sube una imagen de un ticket para extraer la información estructurada.",
108
+ api_name="predict"
109
+ )
110
+
111
+ if __name__ == "__main__":
112
+ demo.launch()