from pathlib import Path
import requests
import streamlit as st
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import re
import jieba
import matplotlib
import matplotlib.font_manager as fm
from transformers import AutoTokenizer, AutoModel
import os
import warnings
st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
@st.cache_resource
def ensure_local_cjk_font(font_filename="NotoSansCJKtc-Regular.otf", variant="tc"):
"""
    Download the Noto Sans CJK font on first run and register it with matplotlib.
    variant: "tc" = Traditional Chinese (Taiwan/HK), "sc" = Simplified Chinese (CN)
"""
here = Path(__file__).resolve().parent
fonts_dir = here / "fonts"
fonts_dir.mkdir(exist_ok=True)
dest = fonts_dir / font_filename
if not dest.exists():
url = {
"tc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/TraditionalChinese/NotoSansCJKtc-Regular.otf",
"sc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf",
}[variant]
r = requests.get(url, timeout=60)
r.raise_for_status()
dest.write_bytes(r.content)
print("⬇️ Downloaded font to", dest)
fm.fontManager.addfont(str(dest))
prop = fm.FontProperties(fname=str(dest))
family = prop.get_name()
matplotlib.rcParams["font.sans-serif"] = [family, "DejaVu Sans"]
matplotlib.rcParams["axes.unicode_minus"] = False
print("✅ Using CJK font (local):", dest.name, "->", family)
return prop, str(dest), family
# ===============================
# 中文字體設定(跨平台支持)
# ===============================
def setup_chinese_font():
    # 0) Prefer a locally downloaded font first (self-contained, no dependence on system fonts)
try:
prop, path, family = ensure_local_cjk_font(
font_filename="NotoSansCJKtc-Regular.otf", # หรือ NotoSansCJKsc-Regular.otf
variant="tc" # หรือ "sc"
)
return prop
    except Exception:
        pass  # if the download fails, fall through to the system font paths below
    # 1) Next: try the usual system font paths (in case fonts were installed via apt/Docker)
candidate_paths = [
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
"/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf",
"/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf",
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
]
for p in candidate_paths:
if os.path.exists(p):
fm.fontManager.addfont(p)
prop = fm.FontProperties(fname=p)
matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
matplotlib.rcParams["axes.unicode_minus"] = False
print("✅ Using system CJK font:", p, "->", prop.get_name())
return prop
    # 2) Scan the whole system (font file names vary across distros)
for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
fm.fontManager.addfont(p)
prop = fm.FontProperties(fname=p)
matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
matplotlib.rcParams["axes.unicode_minus"] = False
print("✅ Using scanned CJK font:", p, "->", prop.get_name())
return prop
    # 3) Last resort: warn and fall back to DejaVu Sans
    warnings.warn("No CJK font found; falling back to DejaVu Sans (Chinese characters may not render correctly)")
matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
matplotlib.rcParams["axes.unicode_minus"] = False
return fm.FontProperties()
zh_font = setup_chinese_font()
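# zh_font is a matplotlib FontProperties object; it is passed explicitly to the
# title/tick calls below so Chinese labels always use the registered CJK font.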
# ===============================
# 頁面設定
# ===============================
st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
# ===============================
# 模型選擇與載入
# ===============================
model_options = {
"Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
"BERT-base-Chinese": "bert-base-chinese",
"Chinese MacBERT-base": "hfl/chinese-macbert-base"
}
selected_model = st.selectbox(
"選擇模型",
list(model_options.keys()),
index=0
)
model_name = model_options[selected_model]
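# st.cache_resource keeps one shared tokenizer/model instance across Streamlit
# reruns, so the weights are downloaded and loaded only once per process.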
@st.cache_resource
def load_model(name):
with st.spinner(f"載入模型 {name} 中..."):
try:
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, output_attentions=True)
return tokenizer, model, None
except Exception as e:
return None, None, str(e)
tokenizer, model, error = load_model(model_name)
if error:
st.error(f"模型載入失敗: {error}")
st.stop()
# ===============================
# 使用者輸入
# ===============================
text = st.text_area(
"請輸入中文句子:",
"我今年35歲,目前在科技業工作,作息略不規律。",
help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
)
def normalize_text(s):
    """移除特殊符號,並將全形百分號正規化為半形"""
    s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!％%\s]", "", s)
    s = s.replace("％", "%")
    return s.strip()
# ===============================
# 主流程
# ===============================
if st.button("開始分析", type="primary"):
if not text.strip():
st.warning("請輸入有效的中文句子")
st.stop()
# 文本清理與分詞
text = normalize_text(text)
words = list(jieba.cut(text))
st.write("🔹 Jieba 分詞結果:", words)
# 不使用空格連接,直接使用原始文本
# 這樣可以避免空格導致的詞-token不匹配問題
tokenized_result = tokenizer(text, return_tensors="pt")
token_ids = tokenized_result["input_ids"][0].tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)
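    # tokens also contains the model's special tokens (e.g. [CLS]/[SEP] for BERT-style models)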
# 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
# 創建更穩健的詞-token映射
char_to_word = {}
current_pos = 0
# 為每個字符創建映射到對應詞的索引
for word_idx, word in enumerate(words):
for _ in range(len(word)):
char_to_word[current_pos] = word_idx
current_pos += 1
# 創建token到字符位置的映射
# 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
# 對於某些模型可能需要調整
# 首先找出特殊標記
special_tokens = []
for i, token in enumerate(tokens):
if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
special_tokens.append(i)
# 找出原始文本中每個token的起始位置
chars = list(text) # 將文本轉換為字符列表
token_to_char_mapping = []
token_to_word_mapping = []
# 處理特殊標記
char_pos = 0
for i, token in enumerate(tokens):
if i in special_tokens:
token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
token_to_word_mapping.append("特殊標記")
else:
# 對於中文字符,大多數模型是一個字符一個token
# 這個邏輯可能需要根據具體模型調整
if token.startswith('##'): # BERT風格的子詞
actual_token = token[2:]
elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
actual_token = token[1:]
else:
actual_token = token
# 注意:中文BERT通常每個token就是一個字符
# 所以這裡可以直接映射
if char_pos < len(chars):
token_to_char_mapping.append(char_pos)
if char_pos in char_to_word:
word_idx = char_to_word[char_pos]
token_to_word_mapping.append(words[word_idx])
else:
token_to_word_mapping.append("未知詞")
char_pos += len(actual_token)
else:
token_to_char_mapping.append(-1)
token_to_word_mapping.append("未知詞")
    # 創建詞到token的映射:對每個字符位置,找出起始於該位置的 token
    word_to_tokens = [[] for _ in range(len(words))]
    for char_idx, word_idx in char_to_word.items():
        for j, mapped_char in enumerate(token_to_char_mapping):
            if mapped_char == char_idx:
                word_to_tokens[word_idx].append(j)
                break
# 創建token-word對照表
token_word_df = pd.DataFrame({
"Token": tokens,
"Token_ID": token_ids,
"Word": token_to_word_mapping
})
# 創建word-tokens對照表
word_token_map = []
for i, word in enumerate(words):
token_indices = word_to_tokens[i]
token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
word_token_map.append({
"Word": word,
"Tokens": " ".join(token_list) if token_list else "無對應Token"
})
word_token_df = pd.DataFrame(word_token_map)
# 模型前向運算
with torch.no_grad():
try:
outputs = model(**tokenized_result)
hidden_states = outputs.last_hidden_state.squeeze(0)
attentions = outputs.attentions
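                # outputs.attentions is a tuple with one tensor per layer, each shaped
                # [batch, num_heads, seq_len, seq_len] (enabled by output_attentions=True)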
# Position & Token embeddings
position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
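                # Note: these are the raw embedding-table lookups; inside a BERT-style model the
                # token-type embeddings, LayerNorm and dropout are still applied on top of them.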
# ===============================
# 顯示 Token-Word 映射
# ===============================
st.subheader("🔤 Token與詞的對應關係")
# 顯示詞-Token映射
st.write("詞對應的Tokens:")
st.dataframe(word_token_df, use_container_width=True)
# 顯示Token-詞映射
st.write("每個Token對應的詞:")
st.dataframe(token_word_df, use_container_width=True)
# ===============================
# 顯示 Embedding(前10維)
# ===============================
st.subheader("🧩 Token Embedding(前10維)")
tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
columns=[f"dim_{i}" for i in range(10)])
tok_df.insert(0, "Token", tokens)
tok_df.insert(1, "Word", token_word_df["Word"])
st.dataframe(tok_df, use_container_width=True)
st.subheader("📍 Position Embedding(前10維)")
pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
columns=[f"dim_{i}" for i in range(10)])
pos_df.insert(0, "Token", tokens)
pos_df.insert(1, "Word", token_word_df["Word"])
st.dataframe(pos_df, use_container_width=True)
# ===============================
# Attention 可視化
# ===============================
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
col1, col2 = st.columns(2)
with col1:
layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
with col2:
head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
# 取得該層、該頭的注意力矩陣
selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
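                # Each row of the attention matrix is a softmax distribution: row i shows how
                # strongly token i attends to every token in the sequence (each row sums to 1).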
# 添加標註信息
token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
for t, w in zip(tokens, token_word_df["Word"])]
# 單頭 Attention Heatmap
st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
cmap="YlGnBu", ax=ax)
plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
st.pyplot(fig, clear_figure=True, use_container_width=True)
# 平均所有頭
st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
fig2, ax2 = plt.subplots(figsize=(12, 10))
sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
cmap="rocket_r", ax=ax2)
plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
st.pyplot(fig2, clear_figure=True, use_container_width=True)
# ===============================
# 詞的平均注意力可視化
# ===============================
st.subheader("📊 詞級別注意力熱圖")
            # 建立詞彙列表(過濾空白詞;特殊標記與未知詞本來就不在 jieba 分詞結果中)
            unique_words = [w for w in words if w.strip()]
if len(unique_words) > 1: # 確保有足夠的詞進行可視化
# 創建詞-詞注意力矩陣
word_attention = np.zeros((len(unique_words), len(unique_words)))
# 使用之前建立的映射來聚合token級別的注意力到詞級別
for i, word_i in enumerate(unique_words):
# 找出屬於word_i的所有token
tokens_i = []
for j, w in enumerate(token_word_df["Word"]):
if w == word_i:
tokens_i.append(j)
for j, word_j in enumerate(unique_words):
# 找出屬於word_j的所有token
tokens_j = []
for k, w in enumerate(token_word_df["Word"]):
if w == word_j:
tokens_j.append(k)
# 計算這兩個詞之間的所有token對的平均注意力
if tokens_i and tokens_j: # 確保兩個詞都有對應的token
attention_sum = 0
count = 0
for ti in tokens_i:
for tj in tokens_j:
if ti < len(selected_attention) and tj < len(selected_attention[0]):
attention_sum += selected_attention[ti, tj]
count += 1
if count > 0:
word_attention[i, j] = attention_sum / count
# 繪製詞級別注意力熱圖
fig3, ax3 = plt.subplots(figsize=(10, 8))
sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
cmap="viridis", ax=ax3)
plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
st.pyplot(fig3, clear_figure=True, use_container_width=True)
else:
st.info("詞數量不足,無法生成詞級別注意力熱圖")
# ===============================
# 下載 CSV
# ===============================
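            # Merge token embeddings with the position-embedding columns; iloc[:, 2:] drops
            # the duplicated Token/Word columns that add_prefix would otherwise carry over.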
merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
st.download_button(
label="💾 下載 Token + Position 向量 CSV",
data=merged_df.to_csv(index=False).encode("utf-8-sig"),
file_name="embeddings.csv",
mime="text/csv"
)
# 詞級別平均 embeddings
st.subheader("📑 詞級別平均 Embeddings(前10維)")
word_embeddings = {}
for word in unique_words:
# 找出屬於該詞的所有token索引
token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
if token_indices:
# 計算該詞的平均 embedding
word_emb = tok_embeddings[token_indices].mean(dim=0)
word_embeddings[word] = word_emb[:10].detach().numpy()
if word_embeddings:
word_emb_df = pd.DataFrame.from_dict(
{word: values for word, values in word_embeddings.items()},
orient='index',
columns=[f"dim_{i}" for i in range(10)]
)
word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
st.dataframe(word_emb_df, use_container_width=True)
# 下載詞級別 embeddings
st.download_button(
label="💾 下載詞級別向量 CSV",
data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
file_name="word_embeddings.csv",
mime="text/csv"
)
except Exception as e:
st.error(f"處理時發生錯誤: {str(e)}")
import traceback
st.code(traceback.format_exc(), language="python")
# ===============================
# 說明與幫助
# ===============================
with st.expander("📖 使用說明"):
st.markdown("""
### 工具功能
這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
3. **Attention 可視化**:查看不同層和頭的注意力模式
4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
### 使用方法
1. 選擇一個預訓練的中文模型
2. 輸入您想分析的中文文本
3. 點擊"開始分析"按鈕
4. 使用滑塊選擇不同的層和注意力頭進行可視化
5. 下載 CSV 文件以進一步分析
### 技術細節
- **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
- **注意力機制**:每一層的每個注意力頭都關注不同的模式
- **注意力熱圖**:數值越高表示注意力越強,顏色深淺請依各圖的色條(colorbar)判讀
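- **舉例**:詞「工作」在字元級中文模型中通常對應「工」與「作」兩個 token;詞級的 embedding 與注意力即由這些 token 的數值平均而得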
### 注意事項
- Transformer 模型可能會將一個詞切分成多個 token
- 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
- 較長的文本可能需要更多處理時間
""")