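"""Gradio app that encodes text prompts with a Mistral text encoder and serves
the resulting embeddings as a downloadable PyTorch (.pt) file."""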
import gc
import os
import io
import time
import tempfile
import logging

import spaces
import torch
import gradio as gr
from transformers import Mistral3ForConditionalGeneration, AutoProcessor

from mistral_text_encoding_core import encode_prompt


logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("mistral-text-encoding-gradio")


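# Model and tokenizer checkpoints; both can be overridden via the
# TEXT_ENCODER_ID and TOKENIZER_ID environment variables.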
TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "mistralai/Mistral-Small-3.2-24B-Instruct-2506")
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503")
DTYPE = torch.bfloat16


# Load the text encoder and tokenizer once at startup so every request reuses
# the same weights instead of reloading them per call.
logger.info("Loading models...")

t0 = time.time()
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    TEXT_ENCODER_ID,
    dtype=DTYPE,
).to("cuda")
logger.info(
    "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
    time.time() - t0,
    text_encoder.dtype,
    text_encoder.device,
)

t1 = time.time()
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
logger.info("Loaded tokenizer in %.2fs", time.time() - t1)

# Inference only: gradients are never needed for prompt encoding.
torch.set_grad_enabled(False)


def get_vram_info():
    """Get current VRAM usage info."""
    if torch.cuda.is_available():
        return {
            "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
            "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
            "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
        }
    return {"vram": "CUDA not available"}


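# On ZeroGPU Spaces, @spaces.GPU() allocates a GPU for the duration of each call.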
@spaces.GPU()
def encode_text(prompt: str):
    """Encode text and return a downloadable PyTorch (.pt) file."""
    global text_encoder, tokenizer

    if text_encoder is None or tokenizer is None:
        return None, "Model not loaded"

    t0 = time.time()

    # One prompt per non-empty line.
    prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
    if not prompts:
        return None, "Please enter at least one prompt"

    num_prompts = len(prompts)
    prompt_input = prompts[0] if num_prompts == 1 else prompts

    logger.info("Encoding %d prompt(s)", num_prompts)

    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompt_input,
    )

    duration = (time.time() - t0) * 1000.0

    logger.info(
        "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
        duration,
        tuple(prompt_embeds.shape),
        tuple(text_ids.shape),
    )

    # Save the embeddings to a temporary .pt file that Gradio can offer for download.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
    torch.save(prompt_embeds.cpu(), temp_file.name)
    temp_file.close()

    # Release the GPU tensors before reporting VRAM usage.
    del prompt_embeds, text_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    vram = get_vram_info()
    status = (
        f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
        f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
        f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
    )

    return temp_file.name, status


with gr.Blocks(title="Mistral Text Encoder") as demo:
    gr.Markdown("# Mistral Text Encoder")
    gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt(s)",
                placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
                lines=5,
            )
            encode_btn = gr.Button("Encode", variant="primary")

        with gr.Column():
            output_file = gr.File(label="Download Embeddings (.pt)")
            status_output = gr.Textbox(label="Status", interactive=False)

    encode_btn.click(
        fn=encode_text,
        inputs=[prompt_input],
        outputs=[output_file, status_output],
    )


if __name__ == "__main__":
    # Models are already loaded at import time above, so no extra setup is needed here.
    demo.launch(server_name="0.0.0.0", server_port=7860)
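
# A minimal sketch of consuming the downloaded embeddings file (the filename
# below is just an example placeholder):
#
#     embeds = torch.load("embeddings.pt", map_location="cpu")
#     print(embeds.shape)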