File size: 20,060 Bytes
985e0f6
 
eabd7c3
 
 
 
 
 
 
 
 
 
 
 
 
3191ba3
 
 
985e0f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eabd7c3
 
 
 
 
 
3c74d1a
985e0f6
 
 
 
 
 
 
 
 
 
 
3c74d1a
 
 
985e0f6
 
 
3c74d1a
 
 
 
 
 
 
 
985e0f6
3c74d1a
 
985e0f6
3c74d1a
 
 
 
 
 
985e0f6
3c74d1a
 
985e0f6
3c74d1a
 
 
 
eabd7c3
 
 
 
 
 
 
3191ba3
eabd7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e983e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
from pathlib import Path
import requests
import streamlit as st
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import re
import jieba
import matplotlib
import matplotlib.font_manager as fm
from transformers import AutoTokenizer, AutoModel
import os
import warnings

st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")

@st.cache_resource
def ensure_local_cjk_font(font_filename="NotoSansCJKtc-Regular.otf", variant="tc"):
    """
    ดาวน์โหลดฟอนต์ Noto Sans CJK ครั้งแรกตอนรัน แล้วลงทะเบียนให้ matplotlib ใช้
    variant: "tc"=ตัวเต็ม (Taiwan/HK), "sc"=ตัวย่อ (CN)
    """
    here = Path(__file__).resolve().parent
    fonts_dir = here / "fonts"
    fonts_dir.mkdir(exist_ok=True)
    dest = fonts_dir / font_filename

    if not dest.exists():
        url = {
            "tc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/TraditionalChinese/NotoSansCJKtc-Regular.otf",
            "sc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf",
        }[variant]
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        dest.write_bytes(r.content)
        print("⬇️  Downloaded font to", dest)

    fm.fontManager.addfont(str(dest))
    prop = fm.FontProperties(fname=str(dest))
    family = prop.get_name()
    matplotlib.rcParams["font.sans-serif"] = [family, "DejaVu Sans"]
    matplotlib.rcParams["axes.unicode_minus"] = False
    print("✅ Using CJK font (local):", dest.name, "->", family)
    return prop, str(dest), family




# ===============================
# 中文字體設定(跨平台支持)
# ===============================
def setup_chinese_font():
    # 0) บังคับมีฟอนต์แบบ local ก่อน (จบในรอบเดียว ไม่ง้อระบบ)
    try:
        prop, path, family = ensure_local_cjk_font(
            font_filename="NotoSansCJKtc-Regular.otf",  # หรือ NotoSansCJKsc-Regular.otf
            variant="tc"                                 # หรือ "sc"
        )
        return prop
    except Exception as _:
        pass  # ถ้าดาวน์โหลดพลาด ค่อยตกไปลอง path ระบบด้านล่าง

    # 1) ถัดไป: ลอง path ระบบตามปกติ (เผื่อคุณไปติดตั้งผ่าน apt/Docker ไว้แล้ว)
    candidate_paths = [
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf",
        "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
    ]
    for p in candidate_paths:
        if os.path.exists(p):
            fm.fontManager.addfont(p)
            prop = fm.FontProperties(fname=p)
            matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
            matplotlib.rcParams["axes.unicode_minus"] = False
            print("✅ Using system CJK font:", p, "->", prop.get_name())
            return prop

    # 2) scan ทั้งระบบ (กันพลาดชื่อไฟล์ต่าง distro)
    for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
        if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
            fm.fontManager.addfont(p)
            prop = fm.FontProperties(fname=p)
            matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
            matplotlib.rcParams["axes.unicode_minus"] = False
            print("✅ Using scanned CJK font:", p, "->", prop.get_name())
            return prop

    # 3) สุดท้ายค่อยเตือน
    warnings.warn("ไม่พบฟอนต์จีน ใช้ DejaVu Sans ชั่วคราว")
    matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
    matplotlib.rcParams["axes.unicode_minus"] = False
    return fm.FontProperties()


zh_font = setup_chinese_font()

# ===============================
# 頁面設定
# ===============================

st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")

# ===============================
# 模型選擇與載入
# ===============================
model_options = {
    "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
    "BERT-base-Chinese": "bert-base-chinese",
    "Chinese MacBERT-base": "hfl/chinese-macbert-base"
}

selected_model = st.selectbox(
    "選擇模型",
    list(model_options.keys()),
    index=0
)

model_name = model_options[selected_model]


@st.cache_resource
def load_model(name):
    with st.spinner(f"載入模型 {name} 中..."):
        try:
            tokenizer = AutoTokenizer.from_pretrained(name)
            model = AutoModel.from_pretrained(name, output_attentions=True)
            return tokenizer, model, None
        except Exception as e:
            return None, None, str(e)


tokenizer, model, error = load_model(model_name)

if error:
    st.error(f"模型載入失敗: {error}")
    st.stop()

# ===============================
# 使用者輸入
# ===============================
text = st.text_area(
    "請輸入中文句子:",
    "我今年35歲,目前在科技業工作,作息略不規律。",
    help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
)


def normalize_text(s):
    """移除特殊符號與全形字"""
    s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
    s = s.replace("%", "%").replace("。", "。 ")
    return s.strip()


# ===============================
# 主流程
# ===============================
if st.button("開始分析", type="primary"):
    if not text.strip():
        st.warning("請輸入有效的中文句子")
        st.stop()

    # 文本清理與分詞
    text = normalize_text(text)
    words = list(jieba.cut(text))
    st.write("🔹 Jieba 分詞結果:", words)

    # 不使用空格連接,直接使用原始文本
    # 這樣可以避免空格導致的詞-token不匹配問題
    tokenized_result = tokenizer(text, return_tensors="pt")
    token_ids = tokenized_result["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
    # 創建更穩健的詞-token映射
    char_to_word = {}
    current_pos = 0

    # 為每個字符創建映射到對應詞的索引
    for word_idx, word in enumerate(words):
        for _ in range(len(word)):
            char_to_word[current_pos] = word_idx
            current_pos += 1

    # 創建token到字符位置的映射
    # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
    # 對於某些模型可能需要調整

    # 首先找出特殊標記
    special_tokens = []
    for i, token in enumerate(tokens):
        if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
            special_tokens.append(i)

    # 找出原始文本中每個token的起始位置
    chars = list(text)  # 將文本轉換為字符列表
    token_to_char_mapping = []
    token_to_word_mapping = []

    # 處理特殊標記
    char_pos = 0
    for i, token in enumerate(tokens):
        if i in special_tokens:
            token_to_char_mapping.append(-1)  # 特殊標記沒有對應的字符位置
            token_to_word_mapping.append("特殊標記")
        else:
            # 對於中文字符,大多數模型是一個字符一個token
            # 這個邏輯可能需要根據具體模型調整
            if token.startswith('##'):  # BERT風格的子詞
                actual_token = token[2:]
            elif token.startswith('▁') or token.startswith('Ġ'):  # 其他模型風格
                actual_token = token[1:]
            else:
                actual_token = token

            # 注意:中文BERT通常每個token就是一個字符
            # 所以這裡可以直接映射
            if char_pos < len(chars):
                token_to_char_mapping.append(char_pos)
                if char_pos in char_to_word:
                    word_idx = char_to_word[char_pos]
                    token_to_word_mapping.append(words[word_idx])
                else:
                    token_to_word_mapping.append("未知詞")
                char_pos += len(actual_token)
            else:
                token_to_char_mapping.append(-1)
                token_to_word_mapping.append("未知詞")

    # 創建詞到token的映射
    word_to_tokens = [[] for _ in range(len(words))]
    for i, word_idx in enumerate(char_to_word.values()):
        if i < len(chars):
            # 找出對應這個字符位置的token
            for j, char_pos in enumerate(token_to_char_mapping):
                if char_pos == i:
                    word_to_tokens[word_idx].append(j)
                    break

    # 創建token-word對照表
    token_word_df = pd.DataFrame({
        "Token": tokens,
        "Token_ID": token_ids,
        "Word": token_to_word_mapping
    })

    # 創建word-tokens對照表
    word_token_map = []
    for i, word in enumerate(words):
        token_indices = word_to_tokens[i]
        token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
        word_token_map.append({
            "Word": word,
            "Tokens": " ".join(token_list) if token_list else "無對應Token"
        })

    word_token_df = pd.DataFrame(word_token_map)

    # 模型前向運算
    with torch.no_grad():
        try:
            outputs = model(**tokenized_result)

            hidden_states = outputs.last_hidden_state.squeeze(0)
            attentions = outputs.attentions

            # Position & Token embeddings
            position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
            pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
            tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)

            # ===============================
            # 顯示 Token-Word 映射
            # ===============================
            st.subheader("🔤 Token與詞的對應關係")

            # 顯示詞-Token映射
            st.write("詞對應的Tokens:")
            st.dataframe(word_token_df, use_container_width=True)

            # 顯示Token-詞映射
            st.write("每個Token對應的詞:")
            st.dataframe(token_word_df, use_container_width=True)

            # ===============================
            # 顯示 Embedding(前10維)
            # ===============================
            st.subheader("🧩 Token Embedding(前10維)")
            tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
                                  columns=[f"dim_{i}" for i in range(10)])
            tok_df.insert(0, "Token", tokens)
            tok_df.insert(1, "Word", token_word_df["Word"])
            st.dataframe(tok_df, use_container_width=True)

            st.subheader("📍 Position Embedding(前10維)")
            pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
                                  columns=[f"dim_{i}" for i in range(10)])
            pos_df.insert(0, "Token", tokens)
            pos_df.insert(1, "Word", token_word_df["Word"])
            st.dataframe(pos_df, use_container_width=True)

            # ===============================
            # Attention 可視化
            # ===============================
            num_layers = len(attentions)
            num_heads = attentions[0].shape[1]

            col1, col2 = st.columns(2)
            with col1:
                layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
            with col2:
                head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)

            # 取得該層、該頭的注意力矩陣
            selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
            mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()

            # 添加標註信息
            token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
                            for t, w in zip(tokens, token_word_df["Word"])]

            # 單頭 Attention Heatmap
            st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
            fig, ax = plt.subplots(figsize=(12, 10))
            sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
                        cmap="YlGnBu", ax=ax)
            plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
            plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
            plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
            st.pyplot(fig, clear_figure=True, use_container_width=True)

            # 平均所有頭
            st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
            fig2, ax2 = plt.subplots(figsize=(12, 10))
            sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
                        cmap="rocket_r", ax=ax2)
            plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
            plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
            plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
            st.pyplot(fig2, clear_figure=True, use_container_width=True)

            # ===============================
            # 詞的平均注意力可視化
            # ===============================
            st.subheader("📊 詞級別注意力熱圖")

            # 創建詞彙列表(去除特殊標記和未知詞)
            unique_words = [w for w in words if w.strip()]

            if len(unique_words) > 1:  # 確保有足夠的詞進行可視化
                # 創建詞-詞注意力矩陣
                word_attention = np.zeros((len(unique_words), len(unique_words)))

                # 使用之前建立的映射來聚合token級別的注意力到詞級別
                for i, word_i in enumerate(unique_words):
                    # 找出屬於word_i的所有token
                    tokens_i = []
                    for j, w in enumerate(token_word_df["Word"]):
                        if w == word_i:
                            tokens_i.append(j)

                    for j, word_j in enumerate(unique_words):
                        # 找出屬於word_j的所有token
                        tokens_j = []
                        for k, w in enumerate(token_word_df["Word"]):
                            if w == word_j:
                                tokens_j.append(k)

                        # 計算這兩個詞之間的所有token對的平均注意力
                        if tokens_i and tokens_j:  # 確保兩個詞都有對應的token
                            attention_sum = 0
                            count = 0
                            for ti in tokens_i:
                                for tj in tokens_j:
                                    if ti < len(selected_attention) and tj < len(selected_attention[0]):
                                        attention_sum += selected_attention[ti, tj]
                                        count += 1

                            if count > 0:
                                word_attention[i, j] = attention_sum / count

                # 繪製詞級別注意力熱圖
                fig3, ax3 = plt.subplots(figsize=(10, 8))
                sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
                            cmap="viridis", ax=ax3)
                plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
                plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
                plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
                st.pyplot(fig3, clear_figure=True, use_container_width=True)
            else:
                st.info("詞數量不足,無法生成詞級別注意力熱圖")

            # ===============================
            # 下載 CSV
            # ===============================
            merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
            st.download_button(
                label="💾 下載 Token + Position 向量 CSV",
                data=merged_df.to_csv(index=False).encode("utf-8-sig"),
                file_name="embeddings.csv",
                mime="text/csv"
            )

            # 詞級別平均 embeddings
            st.subheader("📑 詞級別平均 Embeddings(前10維)")

            word_embeddings = {}
            for word in unique_words:
                # 找出屬於該詞的所有token索引
                token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]

                if token_indices:
                    # 計算該詞的平均 embedding
                    word_emb = tok_embeddings[token_indices].mean(dim=0)
                    word_embeddings[word] = word_emb[:10].detach().numpy()

            if word_embeddings:
                word_emb_df = pd.DataFrame.from_dict(
                    {word: values for word, values in word_embeddings.items()},
                    orient='index',
                    columns=[f"dim_{i}" for i in range(10)]
                )
                word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
                st.dataframe(word_emb_df, use_container_width=True)

                # 下載詞級別 embeddings
                st.download_button(
                    label="💾 下載詞級別向量 CSV",
                    data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
                    file_name="word_embeddings.csv",
                    mime="text/csv"
                )

        except Exception as e:
            st.error(f"處理時發生錯誤: {str(e)}")
            import traceback

            st.code(traceback.format_exc(), language="python")

# ===============================
# 說明與幫助
# ===============================
with st.expander("📖 使用說明"):
    st.markdown("""
    ### 工具功能

    這個工具可以幫助您理解 Transformer 模型如何處理中文文本:

    1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
    2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
    3. **Attention 可視化**:查看不同層和頭的注意力模式
    4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式

    ### 使用方法

    1. 選擇一個預訓練的中文模型
    2. 輸入您想分析的中文文本
    3. 點擊"開始分析"按鈕
    4. 使用滑塊選擇不同的層和注意力頭進行可視化
    5. 下載 CSV 文件以進一步分析

    ### 技術細節

    - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
    - **注意力機制**:每一層的每個注意力頭都關注不同的模式
    - **注意力熱圖**:顏色越深表示注意力越強

    ### 注意事項

    - Transformer 模型可能會將一個詞切分成多個 token
    - 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
    - 較長的文本可能需要更多處理時間
    """)