Donlagon007 commited on
Commit
3c74d1a
·
verified ·
1 Parent(s): 6883970

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +431 -408
app.py CHANGED
@@ -1,409 +1,432 @@
1
- import streamlit as st
2
- import torch
3
- import pandas as pd
4
- import numpy as np
5
- import seaborn as sns
6
- import matplotlib.pyplot as plt
7
- import re
8
- import jieba
9
- import matplotlib
10
- import matplotlib.font_manager as fm
11
- from transformers import AutoTokenizer, AutoModel
12
- import os
13
- import warnings
14
-
15
-
16
- # ===============================
17
- # 中文字體設定(跨平台支持)
18
- # ===============================
19
- def setup_chinese_font():
20
- """ใช้ฟอนต์จีนมาตรฐานของ Linux (สำหรับ Hugging Face Spaces)"""
21
- try:
22
- zh_font = fm.FontProperties(family="WenQuanYi Micro Hei")
23
- matplotlib.rcParams["font.sans-serif"] = ["WenQuanYi Micro Hei", "DejaVu Sans"]
24
- matplotlib.rcParams["axes.unicode_minus"] = False
25
- return zh_font
26
- except Exception as e:
27
- print("⚠️ Font not found, fallback to default:", e)
28
- matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
29
- return fm.FontProperties()
30
-
31
- zh_font = setup_chinese_font()
32
-
33
- # ===============================
34
- # 頁面設定
35
- # ===============================
36
- st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
37
- st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
38
-
39
- # ===============================
40
- # 模型選擇與載入
41
- # ===============================
42
- model_options = {
43
- "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
44
- "BERT-base-Chinese": "bert-base-chinese",
45
- "Chinese MacBERT-base": "hfl/chinese-macbert-base"
46
- }
47
-
48
- selected_model = st.selectbox(
49
- "選擇模型",
50
- list(model_options.keys()),
51
- index=0
52
- )
53
-
54
- model_name = model_options[selected_model]
55
-
56
-
57
- @st.cache_resource
58
- def load_model(name):
59
- with st.spinner(f"載入模型 {name} 中..."):
60
- try:
61
- tokenizer = AutoTokenizer.from_pretrained(name)
62
- model = AutoModel.from_pretrained(name, output_attentions=True)
63
- return tokenizer, model, None
64
- except Exception as e:
65
- return None, None, str(e)
66
-
67
-
68
- tokenizer, model, error = load_model(model_name)
69
-
70
- if error:
71
- st.error(f"模型載入失敗: {error}")
72
- st.stop()
73
-
74
- # ===============================
75
- # 使用者輸入
76
- # ===============================
77
- text = st.text_area(
78
- "請輸入中文句子:",
79
- "我今年35歲,目前在科技業工作,作息略不規律。",
80
- help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
81
- )
82
-
83
-
84
- def normalize_text(s):
85
- """移除特殊符號與全形字"""
86
- s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
87
- s = s.replace("%", "%").replace("。", "。 ")
88
- return s.strip()
89
-
90
-
91
- # ===============================
92
- # 主流程
93
- # ===============================
94
- if st.button("開始分析", type="primary"):
95
- if not text.strip():
96
- st.warning("請輸入有效的中文句子")
97
- st.stop()
98
-
99
- # 文本清理與分詞
100
- text = normalize_text(text)
101
- words = list(jieba.cut(text))
102
- st.write("🔹 Jieba 分詞結果:", words)
103
-
104
- # 不使用空格連接,直接使用原始文本
105
- # 這樣可以避免空格導致的詞-token不匹配問題
106
- tokenized_result = tokenizer(text, return_tensors="pt")
107
- token_ids = tokenized_result["input_ids"][0].tolist()
108
- tokens = tokenizer.convert_ids_to_tokens(token_ids)
109
-
110
- # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
111
- # 創建更穩健的詞-token映射
112
- char_to_word = {}
113
- current_pos = 0
114
-
115
- # 為每個字符創建映射到對應詞的索引
116
- for word_idx, word in enumerate(words):
117
- for _ in range(len(word)):
118
- char_to_word[current_pos] = word_idx
119
- current_pos += 1
120
-
121
- # 創建token到字符位置的映射
122
- # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
123
- # 對於某些模型可能需要調整
124
-
125
- # 首先找出特殊標記
126
- special_tokens = []
127
- for i, token in enumerate(tokens):
128
- if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
129
- special_tokens.append(i)
130
-
131
- # 找出原始文本中每個token的起始位置
132
- chars = list(text) # 將文本轉換為字符列表
133
- token_to_char_mapping = []
134
- token_to_word_mapping = []
135
-
136
- # 處理特殊標記
137
- char_pos = 0
138
- for i, token in enumerate(tokens):
139
- if i in special_tokens:
140
- token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
141
- token_to_word_mapping.append("特殊標記")
142
- else:
143
- # 對於中文字符,大多數模型是一個字符一個token
144
- # 這個邏輯可能需要根據具體模型調整
145
- if token.startswith('##'): # BERT風格的子詞
146
- actual_token = token[2:]
147
- elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
148
- actual_token = token[1:]
149
- else:
150
- actual_token = token
151
-
152
- # 注意:中文BERT通常每個token就是一個字符
153
- # 所以這裡可以直接映射
154
- if char_pos < len(chars):
155
- token_to_char_mapping.append(char_pos)
156
- if char_pos in char_to_word:
157
- word_idx = char_to_word[char_pos]
158
- token_to_word_mapping.append(words[word_idx])
159
- else:
160
- token_to_word_mapping.append("未知詞")
161
- char_pos += len(actual_token)
162
- else:
163
- token_to_char_mapping.append(-1)
164
- token_to_word_mapping.append("未知詞")
165
-
166
- # 創建詞到token的映射
167
- word_to_tokens = [[] for _ in range(len(words))]
168
- for i, word_idx in enumerate(char_to_word.values()):
169
- if i < len(chars):
170
- # 找出對應這個字符位置的token
171
- for j, char_pos in enumerate(token_to_char_mapping):
172
- if char_pos == i:
173
- word_to_tokens[word_idx].append(j)
174
- break
175
-
176
- # 創建token-word對照表
177
- token_word_df = pd.DataFrame({
178
- "Token": tokens,
179
- "Token_ID": token_ids,
180
- "Word": token_to_word_mapping
181
- })
182
-
183
- # 創建word-tokens對照表
184
- word_token_map = []
185
- for i, word in enumerate(words):
186
- token_indices = word_to_tokens[i]
187
- token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
188
- word_token_map.append({
189
- "Word": word,
190
- "Tokens": " ".join(token_list) if token_list else "無對應Token"
191
- })
192
-
193
- word_token_df = pd.DataFrame(word_token_map)
194
-
195
- # 模型前向運算
196
- with torch.no_grad():
197
- try:
198
- outputs = model(**tokenized_result)
199
-
200
- hidden_states = outputs.last_hidden_state.squeeze(0)
201
- attentions = outputs.attentions
202
-
203
- # Position & Token embeddings
204
- position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
205
- pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
206
- tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
207
-
208
- # ===============================
209
- # 顯示 Token-Word 映射
210
- # ===============================
211
- st.subheader("🔤 Token與詞的對應關係")
212
-
213
- # 顯示詞-Token映射
214
- st.write("詞對應的Tokens:")
215
- st.dataframe(word_token_df, use_container_width=True)
216
-
217
- # 顯示Token-詞映射
218
- st.write("每個Token對應的詞:")
219
- st.dataframe(token_word_df, use_container_width=True)
220
-
221
- # ===============================
222
- # 顯示 Embedding(前10維)
223
- # ===============================
224
- st.subheader("🧩 Token Embedding(前10維)")
225
- tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
226
- columns=[f"dim_{i}" for i in range(10)])
227
- tok_df.insert(0, "Token", tokens)
228
- tok_df.insert(1, "Word", token_word_df["Word"])
229
- st.dataframe(tok_df, use_container_width=True)
230
-
231
- st.subheader("📍 Position Embedding(前10維)")
232
- pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
233
- columns=[f"dim_{i}" for i in range(10)])
234
- pos_df.insert(0, "Token", tokens)
235
- pos_df.insert(1, "Word", token_word_df["Word"])
236
- st.dataframe(pos_df, use_container_width=True)
237
-
238
- # ===============================
239
- # Attention 可視化
240
- # ===============================
241
- num_layers = len(attentions)
242
- num_heads = attentions[0].shape[1]
243
-
244
- col1, col2 = st.columns(2)
245
- with col1:
246
- layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
247
- with col2:
248
- head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
249
-
250
- # 取得該層、該頭的注意力矩陣
251
- selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
252
- mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
253
-
254
- # 添加標註信息
255
- token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
256
- for t, w in zip(tokens, token_word_df["Word"])]
257
-
258
- # 單頭 Attention Heatmap
259
- st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
260
- fig, ax = plt.subplots(figsize=(12, 10))
261
- sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
262
- cmap="YlGnBu", ax=ax)
263
- plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
264
- plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
265
- plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
266
- st.pyplot(fig, clear_figure=True, use_container_width=True)
267
-
268
- # 平均所有頭
269
- st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
270
- fig2, ax2 = plt.subplots(figsize=(12, 10))
271
- sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
272
- cmap="rocket_r", ax=ax2)
273
- plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
274
- plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
275
- plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
276
- st.pyplot(fig2, clear_figure=True, use_container_width=True)
277
-
278
- # ===============================
279
- # 詞的平均注意力可視化
280
- # ===============================
281
- st.subheader("📊 詞級別注意力熱圖")
282
-
283
- # 創建詞彙列表(去除特殊標記和未知詞)
284
- unique_words = [w for w in words if w.strip()]
285
-
286
- if len(unique_words) > 1: # 確保有足夠的詞進行可視化
287
- # 創建詞-詞注意力矩陣
288
- word_attention = np.zeros((len(unique_words), len(unique_words)))
289
-
290
- # 使用之前建立的映射來聚合token級別的注意力到詞級別
291
- for i, word_i in enumerate(unique_words):
292
- # 找出屬於word_i的所有token
293
- tokens_i = []
294
- for j, w in enumerate(token_word_df["Word"]):
295
- if w == word_i:
296
- tokens_i.append(j)
297
-
298
- for j, word_j in enumerate(unique_words):
299
- # 找出屬於word_j的所有token
300
- tokens_j = []
301
- for k, w in enumerate(token_word_df["Word"]):
302
- if w == word_j:
303
- tokens_j.append(k)
304
-
305
- # 計算這兩個詞之間的所有token對的平均注意力
306
- if tokens_i and tokens_j: # 確保兩個詞都有對應的token
307
- attention_sum = 0
308
- count = 0
309
- for ti in tokens_i:
310
- for tj in tokens_j:
311
- if ti < len(selected_attention) and tj < len(selected_attention[0]):
312
- attention_sum += selected_attention[ti, tj]
313
- count += 1
314
-
315
- if count > 0:
316
- word_attention[i, j] = attention_sum / count
317
-
318
- # 繪製詞級別注意力熱圖
319
- fig3, ax3 = plt.subplots(figsize=(10, 8))
320
- sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
321
- cmap="viridis", ax=ax3)
322
- plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
323
- plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
324
- plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
325
- st.pyplot(fig3, clear_figure=True, use_container_width=True)
326
- else:
327
- st.info("詞數量不足,無法生成詞級別注意力熱圖")
328
-
329
- # ===============================
330
- # 下載 CSV
331
- # ===============================
332
- merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
333
- st.download_button(
334
- label="💾 下載 Token + Position 向量 CSV",
335
- data=merged_df.to_csv(index=False).encode("utf-8-sig"),
336
- file_name="embeddings.csv",
337
- mime="text/csv"
338
- )
339
-
340
- # 詞級別平均 embeddings
341
- st.subheader("📑 詞級別平均 Embeddings(前10維)")
342
-
343
- word_embeddings = {}
344
- for word in unique_words:
345
- # 找出屬於該詞的所有token索引
346
- token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
347
-
348
- if token_indices:
349
- # 計算該詞的平均 embedding
350
- word_emb = tok_embeddings[token_indices].mean(dim=0)
351
- word_embeddings[word] = word_emb[:10].detach().numpy()
352
-
353
- if word_embeddings:
354
- word_emb_df = pd.DataFrame.from_dict(
355
- {word: values for word, values in word_embeddings.items()},
356
- orient='index',
357
- columns=[f"dim_{i}" for i in range(10)]
358
- )
359
- word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
360
- st.dataframe(word_emb_df, use_container_width=True)
361
-
362
- # 下載詞級別 embeddings
363
- st.download_button(
364
- label="💾 下載詞級別向量 CSV",
365
- data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
366
- file_name="word_embeddings.csv",
367
- mime="text/csv"
368
- )
369
-
370
- except Exception as e:
371
- st.error(f"處理時發生錯誤: {str(e)}")
372
- import traceback
373
-
374
- st.code(traceback.format_exc(), language="python")
375
-
376
- # ===============================
377
- # 說明與幫助
378
- # ===============================
379
- with st.expander("📖 使用說明"):
380
- st.markdown("""
381
- ### 工具功能
382
-
383
- 這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
384
-
385
- 1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
386
- 2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
387
- 3. **Attention 可視化**:查看不同層和頭的注意力模式
388
- 4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
389
-
390
- ### 使用方法
391
-
392
- 1. 選擇一個預訓練的中文模型
393
- 2. 輸入您想分析的中文文本
394
- 3. 點擊"開始分析"按鈕
395
- 4. 使用滑塊選擇不同的層和注意力頭進行可視化
396
- 5. 下載 CSV 文件以進一步分析
397
-
398
- ### 技術細節
399
-
400
- - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
401
- - **注意力機制**:每一層的每個注意力頭都關注不同的模式
402
- - **注意力熱圖**:顏色越深表示注意力越強
403
-
404
- ### 注意事項
405
-
406
- - Transformer 模型可能會將一個詞切分成多個 token
407
- - 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
408
- - 較長的文本可能需要更多處理時間
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  """)
 
1
+ import streamlit as st
2
+ import torch
3
+ import pandas as pd
4
+ import numpy as np
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ import re
8
+ import jieba
9
+ import matplotlib
10
+ import matplotlib.font_manager as fm
11
+ from transformers import AutoTokenizer, AutoModel
12
+ import os
13
+ import warnings
14
+
15
+
16
+ # ===============================
17
+ # 中文字體設定(跨平台支持)
18
+ # ===============================
19
+ def setup_chinese_font():
20
+ # เส้นทางที่พบบ่อยใน Ubuntu/HF Spaces
21
+ candidate_paths = [
22
+ "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
23
+ "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
24
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", # Debian/Ubuntu
25
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf", # SC
26
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf", # TC
27
+ "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
28
+ ]
29
+
30
+ for p in candidate_paths:
31
+ if os.path.exists(p):
32
+ # ลงทะเบียนฟอนต์เข้า fontManager (สำคัญ)
33
+ fm.fontManager.addfont(p)
34
+ prop = fm.FontProperties(fname=p)
35
+ # ตั้งค่าสำรองให้ใช้ชื่อฟอนต์ที่เพิ่ง add เข้าไป
36
+ matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
37
+ matplotlib.rcParams["axes.unicode_minus"] = False
38
+ return prop
39
+
40
+ # สแกนทั้งระบบเผื่อชื่อไฟล์ต่าง distribution
41
+ for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
42
+ if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
43
+ fm.fontManager.addfont(p)
44
+ prop = fm.FontProperties(fname=p)
45
+ matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
46
+ matplotlib.rcParams["axes.unicode_minus"] = False
47
+ return prop
48
+
49
+ warnings.warn("ไม่พบฟอนต์จีน ใช้ DejaVu Sans ชั่วคราว")
50
+ matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
51
+ matplotlib.rcParams["axes.unicode_minus"] = False
52
+ return fm.FontProperties()
53
+
54
+ zh_font = setup_chinese_font()
55
+
56
+ # ===============================
57
+ # 頁面設定
58
+ # ===============================
59
+ st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
60
+ st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
61
+
62
+ # ===============================
63
+ # 模型選擇與載入
64
+ # ===============================
65
+ model_options = {
66
+ "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
67
+ "BERT-base-Chinese": "bert-base-chinese",
68
+ "Chinese MacBERT-base": "hfl/chinese-macbert-base"
69
+ }
70
+
71
+ selected_model = st.selectbox(
72
+ "選擇模型",
73
+ list(model_options.keys()),
74
+ index=0
75
+ )
76
+
77
+ model_name = model_options[selected_model]
78
+
79
+
80
+ @st.cache_resource
81
+ def load_model(name):
82
+ with st.spinner(f"載入模型 {name} 中..."):
83
+ try:
84
+ tokenizer = AutoTokenizer.from_pretrained(name)
85
+ model = AutoModel.from_pretrained(name, output_attentions=True)
86
+ return tokenizer, model, None
87
+ except Exception as e:
88
+ return None, None, str(e)
89
+
90
+
91
+ tokenizer, model, error = load_model(model_name)
92
+
93
+ if error:
94
+ st.error(f"模型載入失敗: {error}")
95
+ st.stop()
96
+
97
+ # ===============================
98
+ # 使用者輸入
99
+ # ===============================
100
+ text = st.text_area(
101
+ "請輸入中文句子:",
102
+ "我今年35歲,目前在科技業工作,作息略不規律。",
103
+ help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
104
+ )
105
+
106
+
107
+ def normalize_text(s):
108
+ """移除特殊符號與全形字"""
109
+ s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
110
+ s = s.replace("%", "%").replace("。", "。 ")
111
+ return s.strip()
112
+
113
+
114
+ # ===============================
115
+ # 主流程
116
+ # ===============================
117
+ if st.button("開始分析", type="primary"):
118
+ if not text.strip():
119
+ st.warning("請輸入有效的中文句子")
120
+ st.stop()
121
+
122
+ # 文本清理與分詞
123
+ text = normalize_text(text)
124
+ words = list(jieba.cut(text))
125
+ st.write("🔹 Jieba 分詞結果:", words)
126
+
127
+ # 不使用空格連接,直接使用原始文本
128
+ # 這樣可以避免空格導致的詞-token不匹配問題
129
+ tokenized_result = tokenizer(text, return_tensors="pt")
130
+ token_ids = tokenized_result["input_ids"][0].tolist()
131
+ tokens = tokenizer.convert_ids_to_tokens(token_ids)
132
+
133
+ # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
134
+ # 創建更穩健的詞-token映射
135
+ char_to_word = {}
136
+ current_pos = 0
137
+
138
+ # 為每個字符創建映射到對應詞的索引
139
+ for word_idx, word in enumerate(words):
140
+ for _ in range(len(word)):
141
+ char_to_word[current_pos] = word_idx
142
+ current_pos += 1
143
+
144
+ # 創建token到字符位置的映射
145
+ # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
146
+ # 對於某些模型可能需要調整
147
+
148
+ # 首先找出特殊標記
149
+ special_tokens = []
150
+ for i, token in enumerate(tokens):
151
+ if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
152
+ special_tokens.append(i)
153
+
154
+ # 找出原始文本中每個token的起始位置
155
+ chars = list(text) # 將文本轉換為字符列表
156
+ token_to_char_mapping = []
157
+ token_to_word_mapping = []
158
+
159
+ # 處理特殊標記
160
+ char_pos = 0
161
+ for i, token in enumerate(tokens):
162
+ if i in special_tokens:
163
+ token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
164
+ token_to_word_mapping.append("特殊標記")
165
+ else:
166
+ # 對於中文字符,大多數模型是一個字符一個token
167
+ # 這個邏輯可能需要根據具體模型調整
168
+ if token.startswith('##'): # BERT風格的子詞
169
+ actual_token = token[2:]
170
+ elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
171
+ actual_token = token[1:]
172
+ else:
173
+ actual_token = token
174
+
175
+ # 注意:中文BERT通常每個token就是一個字符
176
+ # 所以這裡可以直接映射
177
+ if char_pos < len(chars):
178
+ token_to_char_mapping.append(char_pos)
179
+ if char_pos in char_to_word:
180
+ word_idx = char_to_word[char_pos]
181
+ token_to_word_mapping.append(words[word_idx])
182
+ else:
183
+ token_to_word_mapping.append("未知詞")
184
+ char_pos += len(actual_token)
185
+ else:
186
+ token_to_char_mapping.append(-1)
187
+ token_to_word_mapping.append("未知詞")
188
+
189
+ # 創建詞到token的映射
190
+ word_to_tokens = [[] for _ in range(len(words))]
191
+ for i, word_idx in enumerate(char_to_word.values()):
192
+ if i < len(chars):
193
+ # 找出對應這個字符位置的token
194
+ for j, char_pos in enumerate(token_to_char_mapping):
195
+ if char_pos == i:
196
+ word_to_tokens[word_idx].append(j)
197
+ break
198
+
199
+ # 創建token-word對照表
200
+ token_word_df = pd.DataFrame({
201
+ "Token": tokens,
202
+ "Token_ID": token_ids,
203
+ "Word": token_to_word_mapping
204
+ })
205
+
206
+ # 創建word-tokens對照表
207
+ word_token_map = []
208
+ for i, word in enumerate(words):
209
+ token_indices = word_to_tokens[i]
210
+ token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
211
+ word_token_map.append({
212
+ "Word": word,
213
+ "Tokens": " ".join(token_list) if token_list else "無對應Token"
214
+ })
215
+
216
+ word_token_df = pd.DataFrame(word_token_map)
217
+
218
+ # 模型前向運算
219
+ with torch.no_grad():
220
+ try:
221
+ outputs = model(**tokenized_result)
222
+
223
+ hidden_states = outputs.last_hidden_state.squeeze(0)
224
+ attentions = outputs.attentions
225
+
226
+ # Position & Token embeddings
227
+ position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
228
+ pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
229
+ tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
230
+
231
+ # ===============================
232
+ # 顯示 Token-Word 映射
233
+ # ===============================
234
+ st.subheader("🔤 Token與詞的對應關係")
235
+
236
+ # 顯示詞-Token映射
237
+ st.write("詞對應的Tokens:")
238
+ st.dataframe(word_token_df, use_container_width=True)
239
+
240
+ # 顯示Token-詞映射
241
+ st.write("每個Token對應的詞:")
242
+ st.dataframe(token_word_df, use_container_width=True)
243
+
244
+ # ===============================
245
+ # 顯示 Embedding(前10維)
246
+ # ===============================
247
+ st.subheader("🧩 Token Embedding(前10維)")
248
+ tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
249
+ columns=[f"dim_{i}" for i in range(10)])
250
+ tok_df.insert(0, "Token", tokens)
251
+ tok_df.insert(1, "Word", token_word_df["Word"])
252
+ st.dataframe(tok_df, use_container_width=True)
253
+
254
+ st.subheader("📍 Position Embedding(前10維)")
255
+ pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
256
+ columns=[f"dim_{i}" for i in range(10)])
257
+ pos_df.insert(0, "Token", tokens)
258
+ pos_df.insert(1, "Word", token_word_df["Word"])
259
+ st.dataframe(pos_df, use_container_width=True)
260
+
261
+ # ===============================
262
+ # Attention 可視化
263
+ # ===============================
264
+ num_layers = len(attentions)
265
+ num_heads = attentions[0].shape[1]
266
+
267
+ col1, col2 = st.columns(2)
268
+ with col1:
269
+ layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
270
+ with col2:
271
+ head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
272
+
273
+ # 取得該層、該頭的注意力矩陣
274
+ selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
275
+ mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
276
+
277
+ # 添加標註信息
278
+ token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
279
+ for t, w in zip(tokens, token_word_df["Word"])]
280
+
281
+ # 單頭 Attention Heatmap
282
+ st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
283
+ fig, ax = plt.subplots(figsize=(12, 10))
284
+ sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
285
+ cmap="YlGnBu", ax=ax)
286
+ plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
287
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
288
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
289
+ st.pyplot(fig, clear_figure=True, use_container_width=True)
290
+
291
+ # 平均所有頭
292
+ st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
293
+ fig2, ax2 = plt.subplots(figsize=(12, 10))
294
+ sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
295
+ cmap="rocket_r", ax=ax2)
296
+ plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
297
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
298
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
299
+ st.pyplot(fig2, clear_figure=True, use_container_width=True)
300
+
301
+ # ===============================
302
+ # 詞的平均注意力可視化
303
+ # ===============================
304
+ st.subheader("📊 詞級別注意力熱圖")
305
+
306
+ # 創建詞彙列表(去除特殊標記和未知詞)
307
+ unique_words = [w for w in words if w.strip()]
308
+
309
+ if len(unique_words) > 1: # 確保有足夠的詞進行可視化
310
+ # 創建詞-詞注意力矩陣
311
+ word_attention = np.zeros((len(unique_words), len(unique_words)))
312
+
313
+ # 使用之前建立的映射來聚合token級別的注意力到詞級別
314
+ for i, word_i in enumerate(unique_words):
315
+ # 找出屬於word_i的所有token
316
+ tokens_i = []
317
+ for j, w in enumerate(token_word_df["Word"]):
318
+ if w == word_i:
319
+ tokens_i.append(j)
320
+
321
+ for j, word_j in enumerate(unique_words):
322
+ # 找出屬於word_j的所有token
323
+ tokens_j = []
324
+ for k, w in enumerate(token_word_df["Word"]):
325
+ if w == word_j:
326
+ tokens_j.append(k)
327
+
328
+ # 計算這兩個詞之間的所有token對的平均注意力
329
+ if tokens_i and tokens_j: # 確保兩個詞都有對應的token
330
+ attention_sum = 0
331
+ count = 0
332
+ for ti in tokens_i:
333
+ for tj in tokens_j:
334
+ if ti < len(selected_attention) and tj < len(selected_attention[0]):
335
+ attention_sum += selected_attention[ti, tj]
336
+ count += 1
337
+
338
+ if count > 0:
339
+ word_attention[i, j] = attention_sum / count
340
+
341
+ # 繪製詞級別注意力熱圖
342
+ fig3, ax3 = plt.subplots(figsize=(10, 8))
343
+ sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
344
+ cmap="viridis", ax=ax3)
345
+ plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
346
+ plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
347
+ plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
348
+ st.pyplot(fig3, clear_figure=True, use_container_width=True)
349
+ else:
350
+ st.info("詞數量不足,無法生成詞級別注意力熱圖")
351
+
352
+ # ===============================
353
+ # 下載 CSV
354
+ # ===============================
355
+ merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
356
+ st.download_button(
357
+ label="💾 下載 Token + Position 向量 CSV",
358
+ data=merged_df.to_csv(index=False).encode("utf-8-sig"),
359
+ file_name="embeddings.csv",
360
+ mime="text/csv"
361
+ )
362
+
363
+ # 詞級別平均 embeddings
364
+ st.subheader("📑 詞級別平均 Embeddings(前10維)")
365
+
366
+ word_embeddings = {}
367
+ for word in unique_words:
368
+ # 找出屬於該詞的所有token索引
369
+ token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
370
+
371
+ if token_indices:
372
+ # 計算該詞的平均 embedding
373
+ word_emb = tok_embeddings[token_indices].mean(dim=0)
374
+ word_embeddings[word] = word_emb[:10].detach().numpy()
375
+
376
+ if word_embeddings:
377
+ word_emb_df = pd.DataFrame.from_dict(
378
+ {word: values for word, values in word_embeddings.items()},
379
+ orient='index',
380
+ columns=[f"dim_{i}" for i in range(10)]
381
+ )
382
+ word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
383
+ st.dataframe(word_emb_df, use_container_width=True)
384
+
385
+ # 下載詞級別 embeddings
386
+ st.download_button(
387
+ label="💾 下載詞級別向量 CSV",
388
+ data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
389
+ file_name="word_embeddings.csv",
390
+ mime="text/csv"
391
+ )
392
+
393
+ except Exception as e:
394
+ st.error(f"處理時發生錯誤: {str(e)}")
395
+ import traceback
396
+
397
+ st.code(traceback.format_exc(), language="python")
398
+
399
+ # ===============================
400
+ # 說明與幫助
401
+ # ===============================
402
+ with st.expander("📖 使用說明"):
403
+ st.markdown("""
404
+ ### 工具功能
405
+
406
+ 這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
407
+
408
+ 1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
409
+ 2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
410
+ 3. **Attention 可視化**:查看不同層和頭的注意力模式
411
+ 4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
412
+
413
+ ### 使用方法
414
+
415
+ 1. 選擇一個預訓練的中文模型
416
+ 2. 輸入您想分析的中文文本
417
+ 3. 點擊"開始分析"按鈕
418
+ 4. 使用滑塊選擇不同的層和注意力頭進行可視化
419
+ 5. 下載 CSV 文件以進一步分析
420
+
421
+ ### 技術細節
422
+
423
+ - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
424
+ - **注意力機制**:每一層的每個注意力頭都關注不同的模式
425
+ - **注意力熱圖**���顏色越深表示注意力越強
426
+
427
+ ### 注意事項
428
+
429
+ - Transformer 模型可能會將一個詞切分成多個 token
430
+ - 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
431
+ - 較長的文本可能需要更多處理時間
432
  """)