from pathlib import Path
import os
import re
import warnings

import requests
import streamlit as st
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import jieba
from transformers import AutoTokenizer, AutoModel

st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")


def ensure_local_cjk_font(font_filename="NotoSansCJKtc-Regular.otf", variant="tc"):
    """
    Download the Noto Sans CJK font on first run and register it with matplotlib.
    variant: "tc" = Traditional Chinese (Taiwan/HK), "sc" = Simplified Chinese (CN)
    """
    here = Path(__file__).resolve().parent
    fonts_dir = here / "fonts"
    fonts_dir.mkdir(exist_ok=True)
    dest = fonts_dir / font_filename
    if not dest.exists():
        url = {
            "tc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/TraditionalChinese/NotoSansCJKtc-Regular.otf",
            "sc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf",
        }[variant]
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        dest.write_bytes(r.content)
        print("⬇️ Downloaded font to", dest)
    fm.fontManager.addfont(str(dest))
    prop = fm.FontProperties(fname=str(dest))
    family = prop.get_name()
    matplotlib.rcParams["font.sans-serif"] = [family, "DejaVu Sans"]
    matplotlib.rcParams["axes.unicode_minus"] = False
    print("✅ Using CJK font (local):", dest.name, "->", family)
    return prop, str(dest), family


# ===============================
# Chinese font setup (cross-platform)
# ===============================
def setup_chinese_font():
    # 0) Prefer the locally downloaded font (self-contained, no system packages needed)
    try:
        prop, path, family = ensure_local_cjk_font(
            font_filename="NotoSansCJKtc-Regular.otf",  # or NotoSansCJKsc-Regular.otf
            variant="tc"  # or "sc"
        )
        return prop
    except Exception:
        pass  # if the download fails, fall back to the system font paths below

    # 1) Next: try the usual system paths (in case a CJK font was installed via apt/Docker)
    candidate_paths = [
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf",
        "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
    ]
    for p in candidate_paths:
        if os.path.exists(p):
            fm.fontManager.addfont(p)
            prop = fm.FontProperties(fname=p)
            matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
            matplotlib.rcParams["axes.unicode_minus"] = False
            print("✅ Using system CJK font:", p, "->", prop.get_name())
            return prop

    # 2) Scan the whole system (font file names differ across distros)
    for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
        if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
            fm.fontManager.addfont(p)
            prop = fm.FontProperties(fname=p)
            matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
            matplotlib.rcParams["axes.unicode_minus"] = False
            print("✅ Using scanned CJK font:", p, "->", prop.get_name())
            return prop

    # 3) Last resort: warn and fall back to DejaVu Sans
    warnings.warn("No Chinese font found; falling back to DejaVu Sans")
    matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
    matplotlib.rcParams["axes.unicode_minus"] = False
    return fm.FontProperties()


zh_font = setup_chinese_font()

# ===============================
# Page header
# ===============================
st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")

# ===============================
# Model selection and loading
# ===============================
model_options = {
    "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
    "BERT-base-Chinese": "bert-base-chinese",
    "Chinese MacBERT-base": "hfl/chinese-macbert-base"
}
selected_model = st.selectbox(
    "選擇模型",
    list(model_options.keys()),
    index=0
)
model_name = model_options[selected_model]


def load_model(name):
    with st.spinner(f"載入模型 {name} 中..."):
        try:
            tokenizer = AutoTokenizer.from_pretrained(name)
            model = AutoModel.from_pretrained(name, output_attentions=True)
            return tokenizer, model, None
        except Exception as e:
            return None, None, str(e)
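

# Streamlit reruns the whole script on every widget interaction, so load_model is called
# again each time. A minimal caching sketch, assuming Streamlit >= 1.18 where the
# st.cache_resource decorator is available:
#
#   @st.cache_resource
#   def load_model(name):
#       tokenizer = AutoTokenizer.from_pretrained(name)
#       model = AutoModel.from_pretrained(name, output_attentions=True)
#       return tokenizer, model, None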


tokenizer, model, error = load_model(model_name)
if error:
    st.error(f"模型載入失敗: {error}")
    st.stop()

# ===============================
# User input
# ===============================
text = st.text_area(
    "請輸入中文句子:",
    "我今年35歲,目前在科技業工作,作息略不規律。",
    help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
)


def normalize_text(s):
    """Remove special symbols and normalize full-width characters."""
    s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
    s = s.replace("%", "%").replace("。", "。 ")
    return s.strip()
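
# Example of what the whitelist above keeps and drops (derived from the character class):
# ASCII "." is not whitelisted, so e.g. normalize_text("體溫36.5度,偏高!") returns
# "體溫365度,偏高!"; the decimal point is stripped. Extend the character class if
# decimals or other ASCII punctuation should survive.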


# ===============================
# Main pipeline
# ===============================
if st.button("開始分析", type="primary"):
    if not text.strip():
        st.warning("請輸入有效的中文句子")
        st.stop()

    # Text cleaning and word segmentation
    text = normalize_text(text)
    words = list(jieba.cut(text))
    st.write("🔹 Jieba 分詞結果:", words)

    # Do not join the words with spaces; tokenize the original text directly.
    # This avoids word-token mismatches caused by inserted spaces.
    tokenized_result = tokenizer(text, return_tensors="pt")
    token_ids = tokenized_result["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # To map words and tokens accurately, locate each token in the original text
    # and build a more robust word-token mapping.
    char_to_word = {}
    current_pos = 0
    # Map each character position to the index of the word that contains it
    for word_idx, word in enumerate(words):
        for _ in range(len(word)):
            char_to_word[current_pos] = word_idx
            current_pos += 1

    # Build the token -> character-position mapping.
    # Note: this assumes character-based Chinese models (e.g. Chinese BERT/RoBERTa);
    # other tokenizers may need adjustments.

    # First locate the special tokens
    special_tokens = []
    for i, token in enumerate(tokens):
        if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
            special_tokens.append(i)

    # Find the start position of each token in the original text
    chars = list(text)  # the text as a list of characters
    token_to_char_mapping = []
    token_to_word_mapping = []

    # Walk through the tokens, handling special tokens separately
    char_pos = 0
    for i, token in enumerate(tokens):
        if i in special_tokens:
            token_to_char_mapping.append(-1)  # special tokens have no character position
            token_to_word_mapping.append("特殊標記")
        else:
            # Most Chinese models use one token per character;
            # this logic may need tweaking for other models.
            if token.startswith('##'):  # BERT-style subword
                actual_token = token[2:]
            elif token.startswith('▁') or token.startswith('Ġ'):  # other subword conventions
                actual_token = token[1:]
            else:
                actual_token = token
            # Chinese BERT usually maps one character to one token,
            # so a direct position mapping works here.
            if char_pos < len(chars):
                token_to_char_mapping.append(char_pos)
                if char_pos in char_to_word:
                    word_idx = char_to_word[char_pos]
                    token_to_word_mapping.append(words[word_idx])
                else:
                    token_to_word_mapping.append("未知詞")
                char_pos += len(actual_token)
            else:
                token_to_char_mapping.append(-1)
                token_to_word_mapping.append("未知詞")

    # Build the word -> tokens mapping
    word_to_tokens = [[] for _ in range(len(words))]
    for i, word_idx in enumerate(char_to_word.values()):
        if i < len(chars):
            # Find the token covering this character position
            for j, char_pos in enumerate(token_to_char_mapping):
                if char_pos == i:
                    word_to_tokens[word_idx].append(j)
                    break

    # Token -> word lookup table
    token_word_df = pd.DataFrame({
        "Token": tokens,
        "Token_ID": token_ids,
        "Word": token_to_word_mapping
    })

    # Word -> tokens lookup table
    word_token_map = []
    for i, word in enumerate(words):
        token_indices = word_to_tokens[i]
        token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
        word_token_map.append({
            "Word": word,
            "Tokens": " ".join(token_list) if token_list else "無對應Token"
        })
    word_token_df = pd.DataFrame(word_token_map)
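
    # Typical behaviour (not guaranteed for every tokenizer): character-level Chinese BERT
    # vocabularies emit one token per CJK character, so a Jieba word such as "科技" usually
    # maps to the two tokens 科 and 技, while [CLS]/[SEP] are tagged 特殊標記 instead of a word.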

    # Model forward pass
    with torch.no_grad():
        try:
            outputs = model(**tokenized_result)
            hidden_states = outputs.last_hidden_state.squeeze(0)
            attentions = outputs.attentions

            # Position & token embeddings
            position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
            pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
            tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
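
            # For reference: in BERT-style encoders, the representation fed into layer 0 is
            # word + position (+ token-type) embeddings followed by LayerNorm and dropout.
            # If that combined embedding were needed, one sketch would be:
            #   combined = model.embeddings(tokenized_result["input_ids"])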

            # ===============================
            # Display the token-word mapping
            # ===============================
            st.subheader("🔤 Token與詞的對應關係")

            # Word -> tokens mapping
            st.write("詞對應的Tokens:")
            st.dataframe(word_token_df, use_container_width=True)

            # Token -> word mapping
            st.write("每個Token對應的詞:")
            st.dataframe(token_word_df, use_container_width=True)

            # ===============================
            # Display embeddings (first 10 dimensions)
            # ===============================
            st.subheader("🧩 Token Embedding(前10維)")
            tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
                                  columns=[f"dim_{i}" for i in range(10)])
            tok_df.insert(0, "Token", tokens)
            tok_df.insert(1, "Word", token_word_df["Word"])
            st.dataframe(tok_df, use_container_width=True)

            st.subheader("📍 Position Embedding(前10維)")
            pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
                                  columns=[f"dim_{i}" for i in range(10)])
            pos_df.insert(0, "Token", tokens)
            pos_df.insert(1, "Word", token_word_df["Word"])
            st.dataframe(pos_df, use_container_width=True)

            # ===============================
            # Attention visualization
            # ===============================
            num_layers = len(attentions)
            num_heads = attentions[0].shape[1]

            col1, col2 = st.columns(2)
            with col1:
                layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
            with col2:
                head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)

            # Attention matrices for the selected layer and head
            selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
            mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
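
            # Shape note: attentions[l] has shape (batch, num_heads, seq_len, seq_len), and each
            # row of a head's matrix is a softmax over the key positions, so every row sums to 1.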

            # Axis labels: token plus its source word
            token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
                            for t, w in zip(tokens, token_word_df["Word"])]

            # Single-head attention heatmap
            st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
            fig, ax = plt.subplots(figsize=(12, 10))
            sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
                        cmap="YlGnBu", ax=ax)
            plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
            plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
            plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
            st.pyplot(fig, clear_figure=True, use_container_width=True)

            # Average over all heads
            st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
            fig2, ax2 = plt.subplots(figsize=(12, 10))
            sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
                        cmap="rocket_r", ax=ax2)
            plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
            plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
            plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
            st.pyplot(fig2, clear_figure=True, use_container_width=True)

            # ===============================
            # Word-level attention visualization
            # ===============================
            st.subheader("📊 詞級別注意力熱圖")

            # Build the word list (drop empty strings)
            unique_words = [w for w in words if w.strip()]

            if len(unique_words) > 1:  # need at least two words to visualize
                # Word-word attention matrix
                word_attention = np.zeros((len(unique_words), len(unique_words)))

                # Aggregate token-level attention to word level using the mapping built earlier
                for i, word_i in enumerate(unique_words):
                    # Collect all tokens belonging to word_i
                    tokens_i = []
                    for j, w in enumerate(token_word_df["Word"]):
                        if w == word_i:
                            tokens_i.append(j)
                    for j, word_j in enumerate(unique_words):
                        # Collect all tokens belonging to word_j
                        tokens_j = []
                        for k, w in enumerate(token_word_df["Word"]):
                            if w == word_j:
                                tokens_j.append(k)
                        # Average attention over all token pairs between the two words
                        if tokens_i and tokens_j:  # both words must have at least one token
                            attention_sum = 0
                            count = 0
                            for ti in tokens_i:
                                for tj in tokens_j:
                                    if ti < len(selected_attention) and tj < len(selected_attention[0]):
                                        attention_sum += selected_attention[ti, tj]
                                        count += 1
                            if count > 0:
                                word_attention[i, j] = attention_sum / count
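
                # Note: this mean-pooling of token-level scores is a simple heuristic; after
                # averaging, the rows of word_attention no longer sum to 1, so compare values
                # within a single heatmap rather than across layers or heads.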

                # Plot the word-level attention heatmap
                fig3, ax3 = plt.subplots(figsize=(10, 8))
                sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
                            cmap="viridis", ax=ax3)
                plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
                plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
                plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
                st.pyplot(fig3, clear_figure=True, use_container_width=True)
            else:
                st.info("詞數量不足,無法生成詞級別注意力熱圖")

            # ===============================
            # CSV download
            # ===============================
            merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
            st.download_button(
                label="💾 下載 Token + Position 向量 CSV",
                data=merged_df.to_csv(index=False).encode("utf-8-sig"),
                file_name="embeddings.csv",
                mime="text/csv"
            )

            # Word-level mean embeddings
            st.subheader("📑 詞級別平均 Embeddings(前10維)")
            word_embeddings = {}
            for word in unique_words:
                # Indices of all tokens belonging to this word
                token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
                if token_indices:
                    # Mean embedding for the word
                    word_emb = tok_embeddings[token_indices].mean(dim=0)
                    word_embeddings[word] = word_emb[:10].detach().numpy()

            if word_embeddings:
                word_emb_df = pd.DataFrame.from_dict(
                    {word: values for word, values in word_embeddings.items()},
                    orient='index',
                    columns=[f"dim_{i}" for i in range(10)]
                )
                word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
                st.dataframe(word_emb_df, use_container_width=True)

                # Download word-level embeddings
                st.download_button(
                    label="💾 下載詞級別向量 CSV",
                    data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
                    file_name="word_embeddings.csv",
                    mime="text/csv"
                )

        except Exception as e:
            st.error(f"處理時發生錯誤: {str(e)}")
            import traceback
            st.code(traceback.format_exc(), language="python")


# ===============================
# Help and usage notes
# ===============================
with st.expander("📖 使用說明"):
    st.markdown("""
### 工具功能
這個工具可以幫助您理解 Transformer 模型如何處理中文文本:

1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
3. **Attention 可視化**:查看不同層和頭的注意力模式
4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式

### 使用方法
1. 選擇一個預訓練的中文模型
2. 輸入您想分析的中文文本
3. 點擊"開始分析"按鈕
4. 使用滑塊選擇不同的層和注意力頭進行可視化
5. 下載 CSV 文件以進一步分析

### 技術細節
- **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
- **注意力機制**:每一層的每個注意力頭都關注不同的模式
- **注意力熱圖**:顏色越深表示注意力越強

### 注意事項
- Transformer 模型可能會將一個詞切分成多個 token
- 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
- 較長的文本可能需要更多處理時間
""")