import React, { useState, useRef, useCallback } from "react";
import {
  AutoProcessor,
  AutoModelForImageTextToText,
  RawImage,
  TextStreamer,
} from "@huggingface/transformers";
import type {
  Tensor,
  PixtralProcessor,
  Ministral3ForCausalLM,
  ProgressInfo,
} from "@huggingface/transformers";
import { VLMContext } from "./VLMContext";

const MODEL_ID = "mistralai/Ministral-3-3B-Instruct-2512-ONNX";
const MAX_NEW_TOKENS = 512;

export const VLMProvider: React.FC<{ children: React.ReactNode }> = ({
  children,
}) => {
  const [isLoaded, setIsLoaded] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [imageSize, setImageSize] = useState(480);

  const processorRef = useRef<PixtralProcessor | null>(null);
  const modelRef = useRef<Ministral3ForCausalLM | null>(null);
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  const inferenceLock = useRef(false);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
  // Mirror of `imageSize` so async closures always see the latest value.
  const imageSizeRef = useRef(480);

  const updateImageSize = useCallback((size: number) => {
    setImageSize(size);
    imageSizeRef.current = size;
    if (processorRef.current?.image_processor) {
      processorRef.current.image_processor.size = { longest_edge: size };
    }
  }, []);

  const loadModel = useCallback(
    async (onProgress?: (msg: string, percentage: number) => void) => {
      if (isLoaded) {
        onProgress?.("Model already loaded!", 100);
        return;
      }
      // Deduplicate concurrent calls: reuse the in-flight load if one exists.
      if (loadPromiseRef.current) {
        return loadPromiseRef.current;
      }

      setIsLoading(true);
      setError(null);

      loadPromiseRef.current = (async () => {
        try {
          onProgress?.("Loading processor...", 0);
          processorRef.current = await AutoProcessor.from_pretrained(MODEL_ID);
          processorRef.current.image_processor!.size = {
            longest_edge: imageSizeRef.current,
          };

          onProgress?.("Processor loaded. Loading model...", 0);

          // Track per-file download progress and report an aggregate percentage.
          const progressMap = new Map<string, number>();
          const progressCallback = (info: ProgressInfo) => {
            if (
              info.status === "progress" &&
              info.file.endsWith(".onnx_data")
            ) {
              progressMap.set(info.file, info.loaded / info.total);
              const total = Array.from(progressMap.values()).reduce(
                (a, b) => a + b,
                0,
              );
              const percentage = (total / 3) * 100; // 3 model files to download
              onProgress?.("Downloading model...", percentage);
            }
          };

          modelRef.current = await AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            {
              dtype: {
                embed_tokens: "fp16",
                vision_encoder: "q4", // q4 is slightly faster than q4f16 (+ better quality)
                decoder_model_merged: "q4f16",
              },
              device: "webgpu",
              progress_callback: progressCallback,
            },
          );

          onProgress?.("Model loaded successfully!", 100);
          setIsLoaded(true);
        } catch (e) {
          const errorMessage = e instanceof Error ? e.message : String(e);
          setError(errorMessage);
          console.error("Error loading model:", e);
          throw e;
        } finally {
          setIsLoading(false);
          loadPromiseRef.current = null;
        }
      })();

      return loadPromiseRef.current;
    },
    [isLoaded],
  );

  const runInference = useCallback(
    async (
      video: HTMLVideoElement,
      instruction: string,
      onTextUpdate?: (text: string) => void,
      onStatsUpdate?: (stats: { tps?: number; ttft?: number }) => void,
    ): Promise<string> => {
      if (inferenceLock.current) {
        return ""; // Return empty string to signal a skip
      }
      inferenceLock.current = true;

      try {
        if (!processorRef.current || !modelRef.current) {
          throw new Error("Model/processor not loaded");
        }

        // Capture the current video frame into a reusable offscreen canvas.
        if (!canvasRef.current) {
          canvasRef.current = document.createElement("canvas");
        }
        const canvas = canvasRef.current;
        canvas.width = video.videoWidth;
        canvas.height = video.videoHeight;
        const ctx = canvas.getContext("2d", { willReadFrequently: true });
        if (!ctx) throw new Error("Could not get canvas context");
        ctx.drawImage(video, 0, 0);
        const frame = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const rawImg = new RawImage(frame.data, frame.width, frame.height, 4);

        const messages = [
          {
            role: "system",
            content: `You are a helpful visual AI assistant. Respond concisely and accurately to the user's query in one sentence.`,
          },
          { role: "user", content: `[IMG]${instruction}` },
        ];
        const prompt = processorRef.current.apply_chat_template(messages);
        const inputs = await processorRef.current(rawImg, prompt, {
          add_special_tokens: false,
        });

        let streamed = "";
        const start = performance.now();
        let decodeStart: number | undefined;
        let numTokens = 0;

        // Stream partial text to the UI and report time-to-first-token / tokens-per-second.
        const streamer = new TextStreamer(processorRef.current.tokenizer!, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: (t: string) => {
            if (streamed.length === 0) {
              const latency = performance.now() - start;
              onStatsUpdate?.({ ttft: latency });
            }
            streamed += t;
            onTextUpdate?.(streamed.trim());
          },
          token_callback_function: () => {
            decodeStart ??= performance.now();
            numTokens++;
            const elapsed = (performance.now() - decodeStart) / 1000;
            if (elapsed > 0) {
              onStatsUpdate?.({ tps: numTokens / elapsed });
            }
          },
        });

        const outputs = (await modelRef.current.generate({
          ...inputs,
          max_new_tokens: MAX_NEW_TOKENS,
          do_sample: false,
          streamer,
          repetition_penalty: 1.2,
        })) as Tensor;

        // Keep only the newly generated tokens (drop the prompt portion).
        const generated = outputs.slice(null, [
          inputs.input_ids.dims.at(-1),
          null,
        ]);

        const decodeEnd = performance.now();
        if (decodeStart) {
          const numGenerated = generated.dims[1];
          const tps = numGenerated / ((decodeEnd - decodeStart) / 1000);
          onStatsUpdate?.({ tps });
        }

        const decoded = processorRef.current.batch_decode(generated, {
          skip_special_tokens: true,
        });
        return decoded[0].trim();
      } finally {
        // Release the lock even if inference throws, so later frames aren't skipped forever.
        inferenceLock.current = false;
      }
    },
    [],
  );

  return (
    <VLMContext.Provider
      value={{
        isLoaded,
        isLoading,
        error,
        imageSize,
        updateImageSize,
        loadModel,
        runInference,
      }}
    >
      {children}
    </VLMContext.Provider>
  );
};

export default VLMProvider;
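
/*
 * Usage sketch (assumption: "./VLMContext" also exports a `useVLMContext`
 * hook that reads this provider's value; that hook name and the consumer
 * component below are illustrative, not defined in this file):
 *
 *   <VLMProvider>
 *     <CameraView />
 *   </VLMProvider>
 *
 *   // Inside CameraView:
 *   const { isLoaded, loadModel, runInference } = useVLMContext();
 *   if (!isLoaded) await loadModel((msg, pct) => console.log(msg, pct));
 *   const answer = await runInference(videoElement, "What do you see?");
 */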