# Copied/Adapted from https://huggingface.co/spaces/akhaliq/MobileLLM-Pro
import spaces
import logging
import os
import re
import threading
from typing import List, Tuple, Dict
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
MODEL_ID = "openai/gpt-oss-20b"
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
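# gpt-oss emits Harmony-style output; with special tokens skipped, the decoded text
# looks roughly like "analysis<reasoning>assistantfinal<final answer>". This regex
# captures everything up to "assistantfinal" so the reasoning can be split from the reply.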
ANALYSIS_PATTERN = re.compile(r'^(.*)assistantfinal', flags=re.DOTALL)
# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
try:
login(token=HF_TOKEN)
except Exception:
pass # stay silent
# Globals so we only load once
_tokenizer = None
_model = None
_device = None
def _ensure_loaded():
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return
    LOG.info("Loading model and tokenizer")
_tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID, trust_remote_code=True
)
_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
trust_remote_code=True,
# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True,
device_map="auto" if torch.cuda.is_available() else None,
)
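    # Some checkpoints ship without a pad token; fall back to EOS so generate() can pad.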
if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
_tokenizer.pad_token = _tokenizer.eos_token
_model.eval()
_device = next(_model.parameters()).device
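# Load eagerly at import time so the weights are ready before the first request
# (and before the @spaces.GPU-wrapped handler below runs).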
_ensure_loaded()
LOG.info("DEVICE %s", _device)
def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
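    """Convert Gradio's (user, assistant) tuple history into chat-template message dicts."""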
msgs: List[Dict[str, str]] = []
for user_msg, bot_msg in history:
if user_msg:
msgs.append({"role": "user", "content": user_msg})
if bot_msg:
msgs.append({"role": "assistant", "content": bot_msg})
return msgs
@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
"""
Minimal streaming chat function for gr.ChatInterface.
Uses instruct chat template. No token UI. No extra controls.
"""
    # FIXME: check the memory footprint of doing so. We should rather do this before the spaces wrapper...
# _ensure_loaded()
messages = _history_to_messages(history) + [{"role": "user", "content": message}]
inputs = _tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
)
input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
input_ids = input_ids.to(_device)
# IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
streamer = TextIteratorStreamer(
_tokenizer,
skip_special_tokens=True,
skip_prompt=True, # <-- key fix
)
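    # Sampling settings for generate(); do_sample falls back to greedy decoding if TEMPERATURE is 0.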
gen_kwargs = dict(
input_ids=input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=TEMPERATURE > 0.0,
temperature=float(TEMPERATURE),
top_p=float(TOP_P),
pad_token_id=_tokenizer.pad_token_id,
eos_token_id=_tokenizer.eos_token_id,
streamer=streamer,
)
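    # Run generation in a background thread; the streamer yields decoded text as it is produced.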
thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
thread.start()
analysis = ""
output = ""
for new_text in streamer:
output += new_text
if not analysis:
m = ANALYSIS_PATTERN.match(output)
if m:
analysis = re.sub(r'^analysis\s*', '', m.group(1))
output = ""
LOG.info("NEW TEXT: %s, OUTPUT: %s", new_text, output.encode())
if not analysis:
answer = f"Analysis:\n{output}"
else:
answer = f"Analysis:\n{analysis}\nAnswer:\n{output}"
yield answer
with gr.Blocks(title="OpenAI GPT-OSS 20B Chat") as demo:
gr.Markdown(
"""
# Chat
Streaming chat with openai/gpt-oss-20b (instruct)
""")
gr.ChatInterface(
fn=generate_stream,
chatbot=gr.Chatbot(height=420, label="OpenAI"),
title=None, # header handled by Markdown above
description=None,
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))