Spaces:
Build error
Build error
taskswithcode
committed on
Commit
·
5fe6115
1
Parent(s):
e4cf805
Fixes
Browse files
- app.py +21 -14
- clus_app_clustypes.json +4 -0
- twc_clustering.py +82 -22
app.py
CHANGED
|
@@ -103,16 +103,16 @@ def load_model(model_name,model_class,load_model_name):
|
|
| 103 |
|
| 104 |
|
| 105 |
@st.experimental_memo
|
| 106 |
-
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster):
|
| 107 |
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
| 108 |
-
results = _cluster.cluster(None,texts,embeddings,threshold)
|
| 109 |
return results
|
| 110 |
|
| 111 |
|
| 112 |
-
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster):
|
| 113 |
with st.spinner('Computing vectors for sentences'):
|
| 114 |
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
| 115 |
-
results = cluster.cluster(None,texts,embeddings,threshold)
|
| 116 |
#st.success("Similarity computation complete")
|
| 117 |
return results
|
| 118 |
|
|
@@ -124,7 +124,7 @@ def get_model_info(model_names,model_name):
|
|
| 124 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
| 125 |
|
| 126 |
|
| 127 |
-
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model):
|
| 128 |
display_area.text("Loading model:" + model_name)
|
| 129 |
#Note. model_name may get mapped to new name in the call below for custom models
|
| 130 |
orig_model_name = model_name
|
|
@@ -140,10 +140,10 @@ def run_test(model_names,model_name,sentences,display_area,threshold,user_upload
|
|
| 140 |
display_area.text("Model " + model_name + " load complete")
|
| 141 |
try:
|
| 142 |
if (user_uploaded):
|
| 143 |
-
results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
|
| 144 |
else:
|
| 145 |
display_area.text("Computing vectors for sentences")
|
| 146 |
-
results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
|
| 147 |
display_area.text("Similarity computation complete")
|
| 148 |
return results
|
| 149 |
|
|
@@ -193,16 +193,19 @@ def init_session():
|
|
| 193 |
st.session_state["model_name"] = "ss_test"
|
| 194 |
st.session_state["threshold"] = 1.5
|
| 195 |
st.session_state["file_name"] = "default"
|
|
|
|
| 196 |
st.session_state["cluster"] = TWCClustering()
|
| 197 |
else:
|
| 198 |
print("Skipping init session")
|
| 199 |
|
| 200 |
-
def app_main(app_mode,example_files,model_name_files):
|
| 201 |
init_session()
|
| 202 |
with open(example_files) as fp:
|
| 203 |
example_file_names = json.load(fp)
|
| 204 |
with open(model_name_files) as fp:
|
| 205 |
model_names = json.load(fp)
|
|
|
|
|
|
|
| 206 |
curr_use_case = use_case[app_mode].split(".")[0]
|
| 207 |
st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
|
| 208 |
st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
|
|
@@ -215,7 +218,7 @@ def app_main(app_mode,example_files,model_name_files):
|
|
| 215 |
|
| 216 |
with st.form('twc_form'):
|
| 217 |
|
| 218 |
-
step1_line = "
|
| 219 |
if (app_mode == DOC_RETRIEVAL):
|
| 220 |
step1_line += ". The first line is treated as the query"
|
| 221 |
uploaded_file = st.file_uploader(step1_line, type=".txt")
|
|
@@ -224,14 +227,17 @@ def app_main(app_mode,example_files,model_name_files):
|
|
| 224 |
options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
|
| 225 |
st.write("")
|
| 226 |
options_arr,markdown_str = construct_model_info_for_display(model_names)
|
| 227 |
-
selection_label = '
|
| 228 |
selected_model = st.selectbox(label=selection_label,
|
| 229 |
options = options_arr, index=0, key = "twc_model")
|
| 230 |
st.write("")
|
| 231 |
custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
|
| 232 |
hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
|
| 233 |
st.markdown(hf_link_str, unsafe_allow_html=True)
|
| 234 |
-
threshold = st.number_input('
|
|
|
|
|
|
|
|
|
|
| 235 |
st.write("")
|
| 236 |
submit_button = st.form_submit_button('Run')
|
| 237 |
|
|
@@ -256,7 +262,8 @@ def app_main(app_mode,example_files,model_name_files):
|
|
| 256 |
run_model = selected_model
|
| 257 |
st.session_state["model_name"] = selected_model
|
| 258 |
st.session_state["threshold"] = threshold
|
| 259 |
-
|
|
|
|
| 260 |
display_area.empty()
|
| 261 |
with display_area.container():
|
| 262 |
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
|
@@ -269,7 +276,7 @@ def app_main(app_mode,example_files,model_name_files):
|
|
| 269 |
label="Download results as json",
|
| 270 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
| 271 |
disabled = False if st.session_state["download_ready"] != None else True,
|
| 272 |
-
file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
|
| 273 |
mime='text/json',
|
| 274 |
key ="download"
|
| 275 |
)
|
|
@@ -288,5 +295,5 @@ if __name__ == "__main__":
|
|
| 288 |
#print("comand line input:",len(sys.argv),str(sys.argv))
|
| 289 |
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
| 290 |
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
| 291 |
-
app_main("3","clus_app_examples.json","clus_app_models.json")
|
| 292 |
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@st.experimental_memo
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster,clustering_type):
    """Embed ``sentences`` and cluster them, memoised by Streamlit.

    The embeddings come from ``_model.compute_embeddings`` and are handed to
    ``_cluster.cluster`` together with ``threshold`` and ``clustering_type``;
    whatever ``cluster`` returns is returned unchanged.

    NOTE(review): per the Streamlit docs, arguments whose names start with an
    underscore (``_model``, ``_cluster``) are left out of the memo key, so
    ``model_name`` presumably exists to tell cached results of different
    models apart - confirm against the Streamlit version in use.
    """
    texts, vectors = _model.compute_embeddings(sentences, is_file=False)
    # First argument of cluster() is the output file; none is written here.
    return _cluster.cluster(None, texts, vectors, threshold, clustering_type)
|
| 110 |
|
| 111 |
|
| 112 |
+
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster,clustering_type):
    """Embed ``sentences`` and cluster them without any memoisation.

    Same computation as the cached variant, but run on every call and wrapped
    in a Streamlit spinner so the user sees progress. ``model_name`` is
    accepted for signature parity and is not read here.

    NOTE(review): the original indentation was lost in this view; the
    clustering call is assumed to run inside the spinner together with the
    embedding step - confirm against the real file.
    """
    with st.spinner('Computing vectors for sentences'):
        texts, vectors = _model.compute_embeddings(sentences, is_file=False)
        # First argument of cluster() is the output file; none is written here.
        clustered = cluster.cluster(None, texts, vectors, threshold, clustering_type)
    return clustered
|
| 118 |
|
|
|
|
| 124 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
| 125 |
|
| 126 |
|
| 127 |
+
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
|
| 128 |
display_area.text("Loading model:" + model_name)
|
| 129 |
#Note. model_name may get mapped to new name in the call below for custom models
|
| 130 |
orig_model_name = model_name
|
|
|
|
| 140 |
display_area.text("Model " + model_name + " load complete")
|
| 141 |
try:
|
| 142 |
if (user_uploaded):
|
| 143 |
+
results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
| 144 |
else:
|
| 145 |
display_area.text("Computing vectors for sentences")
|
| 146 |
+
results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
| 147 |
display_area.text("Similarity computation complete")
|
| 148 |
return results
|
| 149 |
|
|
|
|
| 193 |
st.session_state["model_name"] = "ss_test"
|
| 194 |
st.session_state["threshold"] = 1.5
|
| 195 |
st.session_state["file_name"] = "default"
|
| 196 |
+
st.session_state["overlapped"] = "overlapped"
|
| 197 |
st.session_state["cluster"] = TWCClustering()
|
| 198 |
else:
|
| 199 |
print("Skipping init session")
|
| 200 |
|
| 201 |
+
def app_main(app_mode,example_files,model_name_files,clus_types):
|
| 202 |
init_session()
|
| 203 |
with open(example_files) as fp:
|
| 204 |
example_file_names = json.load(fp)
|
| 205 |
with open(model_name_files) as fp:
|
| 206 |
model_names = json.load(fp)
|
| 207 |
+
with open(clus_types) as fp:
|
| 208 |
+
cluster_types = json.load(fp)
|
| 209 |
curr_use_case = use_case[app_mode].split(".")[0]
|
| 210 |
st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
|
| 211 |
st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
|
|
|
|
| 218 |
|
| 219 |
with st.form('twc_form'):
|
| 220 |
|
| 221 |
+
step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
|
| 222 |
if (app_mode == DOC_RETRIEVAL):
|
| 223 |
step1_line += ". The first line is treated as the query"
|
| 224 |
uploaded_file = st.file_uploader(step1_line, type=".txt")
|
|
|
|
| 227 |
options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
|
| 228 |
st.write("")
|
| 229 |
options_arr,markdown_str = construct_model_info_for_display(model_names)
|
| 230 |
+
selection_label = 'Select Model'
|
| 231 |
selected_model = st.selectbox(label=selection_label,
|
| 232 |
options = options_arr, index=0, key = "twc_model")
|
| 233 |
st.write("")
|
| 234 |
custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
|
| 235 |
hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
|
| 236 |
st.markdown(hf_link_str, unsafe_allow_html=True)
|
| 237 |
+
threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
|
| 238 |
+
st.write("")
|
| 239 |
+
clustering_type = st.selectbox(label=f'Select type of clustering',
|
| 240 |
+
options = list(dict.keys(cluster_types)), index=0, key = "twc_cluster_types")
|
| 241 |
st.write("")
|
| 242 |
submit_button = st.form_submit_button('Run')
|
| 243 |
|
|
|
|
| 262 |
run_model = selected_model
|
| 263 |
st.session_state["model_name"] = selected_model
|
| 264 |
st.session_state["threshold"] = threshold
|
| 265 |
+
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
| 266 |
+
results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
|
| 267 |
display_area.empty()
|
| 268 |
with display_area.container():
|
| 269 |
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
|
|
|
| 276 |
label="Download results as json",
|
| 277 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
| 278 |
disabled = False if st.session_state["download_ready"] != None else True,
|
| 279 |
+
file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + st.session_state["overlapped"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
|
| 280 |
mime='text/json',
|
| 281 |
key ="download"
|
| 282 |
)
|
|
|
|
| 295 |
#print("comand line input:",len(sys.argv),str(sys.argv))
|
| 296 |
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
| 297 |
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
| 298 |
+
app_main("3","clus_app_examples.json","clus_app_models.json","clus_app_clustypes.json")
|
| 299 |
|
clus_app_clustypes.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Overlapped clustering (cluster size determined by zscore)": {"type":"overlapped"},
|
| 3 |
+
"Non-overlapped clustering (overlapped clusters aggregated)":{"type":"non-overlapped"}
|
| 4 |
+
}
|
twc_clustering.py
CHANGED
|
@@ -31,27 +31,30 @@ class TWCClustering:
|
|
| 31 |
picked_arr = []
|
| 32 |
while (run_index < len(embeddings)):
|
| 33 |
if (matrix[pivot_index][run_index] >= threshold):
|
| 34 |
-
|
| 35 |
-
picked_arr.append({"index":run_index})
|
| 36 |
run_index += 1
|
| 37 |
return picked_arr
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def update_picked_dict(self,picked_dict,in_dict):
|
| 40 |
for key in in_dict:
|
| 41 |
picked_dict[key] = 1
|
| 42 |
|
| 43 |
-
def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold):
|
| 44 |
center_index = pivot_index
|
| 45 |
center_score = 0
|
| 46 |
center_dict = {}
|
| 47 |
for i in range(len(arr)):
|
| 48 |
-
node_i_index = arr[i]
|
| 49 |
running_score = 0
|
| 50 |
temp_dict = {}
|
| 51 |
for j in range(len(arr)):
|
| 52 |
-
node_j_index = arr[j]
|
| 53 |
cosine_dist = matrix[node_i_index][node_j_index]
|
| 54 |
-
if (cosine_dist < threshold):
|
| 55 |
continue
|
| 56 |
running_score += cosine_dist
|
| 57 |
temp_dict[node_j_index] = cosine_dist
|
|
@@ -80,8 +83,76 @@ class TWCClustering:
|
|
| 80 |
bucket_dict[overlap_dict[key]] += 1
|
| 81 |
sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
|
| 82 |
return sorted_d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
def cluster(self,output_file,texts,embeddings,threshold
|
|
|
|
| 85 |
matrix = self.compute_matrix(embeddings)
|
| 86 |
mean = np.mean(matrix)
|
| 87 |
std = np.std(matrix)
|
|
@@ -95,22 +166,11 @@ class TWCClustering:
|
|
| 95 |
#print("In clustering:",round(std,2),zscores)
|
| 96 |
cluster_dict = {}
|
| 97 |
cluster_dict["clusters"] = []
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
if (i in picked_dict):
|
| 103 |
-
continue
|
| 104 |
-
zscore = mean + threshold*std
|
| 105 |
-
arr = self.get_terms_above_threshold(matrix,embeddings,i,zscore)
|
| 106 |
-
cluster_info = self.find_pivot_subgraph(i,arr,matrix,zscore)
|
| 107 |
-
self.update_picked_dict(picked_dict,cluster_info["neighs"])
|
| 108 |
-
self.update_overlap_stats(overlap_dict,cluster_info)
|
| 109 |
-
cluster_dict["clusters"].append(cluster_info)
|
| 110 |
curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
|
| 111 |
-
sorted_d = OrderedDict(sorted(overlap_dict.items(), key=lambda kv: kv[1], reverse=True))
|
| 112 |
-
#print(sorted_d)
|
| 113 |
-
sorted_d = self.bucket_overlap(overlap_dict)
|
| 114 |
cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
|
| 115 |
return cluster_dict
|
| 116 |
|
|
|
|
| 31 |
picked_arr = []
|
| 32 |
while (run_index < len(embeddings)):
|
| 33 |
if (matrix[pivot_index][run_index] >= threshold):
|
| 34 |
+
picked_arr.append(run_index)
|
|
|
|
| 35 |
run_index += 1
|
| 36 |
return picked_arr
|
| 37 |
|
| 38 |
+
def update_picked_dict_arr(self, picked_dict, arr):
    """Mark every element of ``arr`` as picked.

    Sets ``picked_dict[element] = 1`` for each element of ``arr``; keys that
    already exist are overwritten with 1. ``picked_dict`` is mutated in place
    and nothing is returned.

    ``arr`` may be any iterable of hashable values (the callers in this class
    pass lists of integer indices).
    """
    # Iterate the values directly rather than indexing through
    # range(len(arr)); this also accepts iterables that are not indexable.
    for index in arr:
        picked_dict[index] = 1
|
| 41 |
+
|
| 42 |
def update_picked_dict(self, picked_dict, in_dict):
    """Flag every key of ``in_dict`` as picked.

    Each key of ``in_dict`` is stored in ``picked_dict`` with the value 1
    (the values of ``in_dict`` are ignored). ``picked_dict`` is mutated in
    place and nothing is returned.
    """
    # dict.fromkeys walks the keys of in_dict in order, exactly like a
    # "for key in in_dict" loop assigning 1 to each.
    picked_dict.update(dict.fromkeys(in_dict, 1))
|
| 45 |
|
| 46 |
+
def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold,strict_cluster = True):
|
| 47 |
center_index = pivot_index
|
| 48 |
center_score = 0
|
| 49 |
center_dict = {}
|
| 50 |
for i in range(len(arr)):
|
| 51 |
+
node_i_index = arr[i]
|
| 52 |
running_score = 0
|
| 53 |
temp_dict = {}
|
| 54 |
for j in range(len(arr)):
|
| 55 |
+
node_j_index = arr[j]
|
| 56 |
cosine_dist = matrix[node_i_index][node_j_index]
|
| 57 |
+
if ((cosine_dist < threshold) and strict_cluster):
|
| 58 |
continue
|
| 59 |
running_score += cosine_dist
|
| 60 |
temp_dict[node_j_index] = cosine_dist
|
|
|
|
| 83 |
bucket_dict[overlap_dict[key]] += 1
|
| 84 |
sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
|
| 85 |
return sorted_d
|
| 86 |
+
|
| 87 |
+
def merge_clusters(self, ref_cluster, curr_cluster):
    """Append to ``ref_cluster`` the members of ``curr_cluster`` it lacks.

    ``ref_cluster`` is extended in place, keeping its existing order and
    adding new members in the order they appear in ``curr_cluster``.
    ``curr_cluster`` is not modified and nothing is returned.

    Members must be hashable; the callers in this class pass the integer
    indices produced by ``get_terms_above_threshold``.
    """
    # Track membership in a set that grows as members are added. The earlier
    # version checked against a snapshot of ref_cluster, so a value repeated
    # inside curr_cluster was appended once per occurrence, and every check
    # was a linear scan of a list.
    seen = set(ref_cluster)
    for member in curr_cluster:
        if member not in seen:
            seen.add(member)
            ref_cluster.append(member)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def non_overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
    """Build disjoint clusters by merging candidate groups that share members.

    Steps:
      1. For every index not yet picked, collect the indices returned by
         ``get_terms_above_threshold`` for the cut-off
         ``mean + threshold * std`` and mark them as picked.
      2. Repeatedly merge any two candidate groups that have a member in
         common, until all remaining groups are pairwise disjoint.
      3. Run ``find_pivot_subgraph`` on each group with
         ``strict_cluster=False`` and append the result to
         ``cluster_dict["clusters"]``.

    Args:
        matrix: pairwise similarity matrix, indexed ``matrix[i][j]``.
        embeddings: the embedded items; only ``len(embeddings)`` is used here.
        threshold: number of standard deviations above ``mean`` for the cut-off.
        mean: mean of ``matrix`` as computed by the caller.
        std: standard deviation of ``matrix`` as computed by the caller.
        cluster_dict: dict with a ``"clusters"`` list, appended to in place.

    Returns:
        An empty dict. The caller calls ``.items()`` on the result to fill
        the "overlap" statistics, and disjoint clusters have none.
    """
    # The cut-off does not depend on the pivot, so compute it once.
    zscore = mean + threshold * std

    picked_dict = {}
    candidates = []
    for i in range(len(embeddings)):
        if i in picked_dict:
            continue
        arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
        if not arr:
            # Guard: an empty group would make arr[0] below raise IndexError.
            # NOTE(review): presumably reachable when the cut-off exceeds
            # every similarity in the pivot's row - confirm. The pivot then
            # forms a cluster of its own.
            arr = [i]
        candidates.append(arr)
        self.update_picked_dict_arr(picked_dict, arr)

    # Merge groups to create non-overlapping sets. Groups before position
    # ``ref_pos`` have already been checked against every later group and
    # found disjoint, so they stay disjoint from any union of later groups.
    # After a merge it is therefore enough to rescan from the same reference
    # group; restarting from position 0 only repeats failed comparisons.
    ref_pos = 0
    while ref_pos < len(candidates):
        ref_cluster = candidates[ref_pos]
        merged = False
        for other_pos in range(ref_pos + 1, len(candidates)):
            curr_cluster = candidates[other_pos]
            if any(member in ref_cluster for member in curr_cluster):
                self.merge_clusters(ref_cluster, curr_cluster)
                candidates.pop(other_pos)
                merged = True
                break
        if not merged:
            ref_pos += 1

    for arr in candidates:
        # strict_cluster=False keeps every member of the merged group, even
        # pairs whose similarity is below the cut-off.
        cluster_info = self.find_pivot_subgraph(arr[0], arr, matrix, zscore, strict_cluster=False)
        cluster_dict["clusters"].append(cluster_info)
    return {}
|
| 136 |
+
|
| 137 |
+
def overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
    """Build clusters that are allowed to share members.

    Every index that is not already a neighbour of an earlier cluster becomes
    a pivot. Its candidates are the indices returned by
    ``get_terms_above_threshold`` for the cut-off ``mean + threshold * std``,
    and ``find_pivot_subgraph`` (strict mode) turns them into a cluster that
    is appended to ``cluster_dict["clusters"]``.

    Returns:
        The result of ``bucket_overlap`` on the overlap statistics gathered
        via ``update_overlap_stats`` while clustering.
    """
    cutoff = mean + threshold * std
    already_picked = {}
    overlap_counts = {}

    for pivot in range(len(embeddings)):
        if pivot not in already_picked:
            members = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
            info = self.find_pivot_subgraph(pivot, members, matrix, cutoff, strict_cluster = True)
            # Neighbours of this cluster are not used as pivots later on.
            self.update_picked_dict(already_picked, info["neighs"])
            self.update_overlap_stats(overlap_counts, info)
            cluster_dict["clusters"].append(info)

    return self.bucket_overlap(overlap_counts)
|
| 152 |
+
|
| 153 |
|
| 154 |
+
def cluster(self,output_file,texts,embeddings,threshold,clustering_type):
|
| 155 |
+
is_overlapped = True if clustering_type == "overlapped" else False
|
| 156 |
matrix = self.compute_matrix(embeddings)
|
| 157 |
mean = np.mean(matrix)
|
| 158 |
std = np.std(matrix)
|
|
|
|
| 166 |
#print("In clustering:",round(std,2),zscores)
|
| 167 |
cluster_dict = {}
|
| 168 |
cluster_dict["clusters"] = []
|
| 169 |
+
if (is_overlapped):
|
| 170 |
+
sorted_d = self.overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
|
| 171 |
+
else:
|
| 172 |
+
sorted_d = self.non_overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
|
|
|
|
|
|
|
|
|
|
| 174 |
cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
|
| 175 |
return cluster_dict
|
| 176 |
|