# Copied/Adapted from https://huggingface.co/spaces/akhaliq/MobileLLM-Pro

import spaces

import logging
import os
import re
import threading
from typing import List, Tuple, Dict

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login

MODEL_ID = "openai/gpt-oss-20b"

logging.basicConfig(level=logging.DEBUG)

LOG = logging.getLogger(__name__)

MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95

# gpt-oss emits an "analysis" (reasoning) channel before the final answer; with special tokens
# stripped, everything before the "assistantfinal" marker is the reasoning trace.
ANALYSIS_PATTERN = re.compile(r'^(.*)assistantfinal', flags=re.DOTALL)

# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass  # stay silent

# Globals so we only load once
_tokenizer = None
_model = None
_device = None


def _ensure_loaded():
    """Load the tokenizer and model once and cache them in module globals."""
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return
    LOG.info("Loading model and tokenizer")
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True
    )
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token = _tokenizer.eos_token
    _model.eval()
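    # With device_map="auto" the model may be sharded across devices; the first parameter's
    # device is a reasonable target for the input tensors.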
    _device = next(_model.parameters()).device


_ensure_loaded()
LOG.info("DEVICE %s", _device)


def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Convert Gradio (user, assistant) history tuples into chat-template message dicts."""
    msgs: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            msgs.append({"role": "assistant", "content": bot_msg})
    return msgs


@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """
    Minimal streaming chat function for gr.ChatInterface.
    Uses instruct chat template. No token UI. No extra controls.
    """

    # FIXME: check the memory footprint doing so. We should rather do this before the spaces wrapper...
    # _ensure_loaded()

    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    inputs = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
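    # Without return_dict=True, apply_chat_template returns the token-id tensor directly;
    # fall back to indexing in case a dict-like BatchEncoding comes back instead.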
    input_ids = inputs if torch.is_tensor(inputs) else inputs["input_ids"]
    input_ids = input_ids.to(_device)

    # IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
    streamer = TextIteratorStreamer(
        _tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,      # <-- key fix
    )

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
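        # Nucleus sampling with TOP_P unless TEMPERATURE is 0, which falls back to greedy decoding.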
        do_sample=TEMPERATURE > 0.0,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
    thread.start()

    analysis = ""
    output = ""
    for new_text in streamer:
        output += new_text
        if not analysis:
            m = ANALYSIS_PATTERN.match(output)
            if m:
                analysis = re.sub(r'^analysis\s*', '', m.group(1))
                output = ""

        LOG.info("NEW TEXT: %s, OUTPUT: %s", new_text, output.encode())
        if not analysis:
            answer = f"Analysis:\n{output}"
        else:
            answer = f"Analysis:\n{analysis}\nAnswer:\n{output}"
        yield answer


with gr.Blocks(title="OpenAI GPT-OSS 20B Chat") as demo:
    gr.Markdown(
        """
# Chat
Streaming chat with openai/gpt-oss-20b (instruct)
""")
    gr.ChatInterface(
        fn=generate_stream,
        chatbot=gr.Chatbot(height=420, label="OpenAI"),
        title=None,            # header handled by Markdown above
        description=None,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))