chewing commited on
Commit
5c5acec
1 Parent(s): 4e8832f

添加tagger

Browse files
app.py CHANGED
@@ -1,5 +1,20 @@
1
  import streamlit as st
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  x = st.slider('Select a value')
 
1
  import streamlit as st
2
+ from tagger_map import Tagger as Tagger_Map
3
+ from tagger_map import zh_dict
4
 
5
+ tagger_map = Tagger_Map()
6
+ def search_text(search_sentences,topn= 5):
7
+ search_sentences = search_sentences.replace("_"," ")
8
+ search_sentences = search_sentences.strip()
9
+ if search_sentences not in zh_dict:
10
+ return ["error"]
11
+
12
+ else:
13
+ rtn0 = tagger_map.get_top_weighted_neighbors(search_sentences,topn)
14
+ rtn = []
15
+ for tag in rtn0:
16
+ rtn.append(f"{tag.replace(' ','_')}《{zh_dict[tag]}》")
17
+ return rtn
18
 
19
 
20
  x = st.slider('Select a value')
data/all.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85cf84f331fca6e0ef00f6b94f99d9e0d40330df46e32a41d8bd4a9b4b3a69bb
3
+ size 56671846
data/all_name_id_cut.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c6250b7f2bcb8ea507d5d11bc82747c6fa0959f360c403145facf7c68a46c0c
3
+ size 326486
data/all_name_id_zh.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/safe.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0169ed4790fd9b54450bb12980b29ac29d6553719f533fe2715ce82808ddfb0e
3
+ size 20330969
data/safe_name_id_cut.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:250b9c37df828de0d4ad208fc8e4886ce8e6e2f476400be3027e5dc832aba488
3
+ size 185038
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ RainbowPrint
2
+ bidict
tagger_map.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import pickle
3
+ import networkx as nx
4
+ import heapq
5
+ from bidict import bidict
6
+ from RainbowPrint import RainbowPrint as rp
7
+
8
+ zh_path = r"./data/all_name_id_zh.txt"
9
+
10
+ zh_dict = {}
11
+ with open(zh_path, "r", encoding="utf-8-sig") as f:
12
+ for line in f.readlines():
13
+ line = line.replace("\n", "")
14
+ tag, zh = line.split("|!|!|")
15
+ zh_dict[tag]=zh
16
+
17
+
18
+ class Tagger():
19
+ def __init__(self, pkl_name=r"./data/all.pkl"):
20
+ with open(pkl_name, 'rb') as f:
21
+ self.G = pickle.load(f)
22
+ with open(pkl_name.replace(".pkl", "_name_id_cut.pkl"), 'rb') as f:
23
+ self.nodes_id = pickle.load(f)
24
+
25
+ zh_dict = {}
26
+ with open(pkl_name.replace(".pkl", "_name_id_zh.txt"), "r", encoding="utf-8-sig") as f:
27
+ for line in f.readlines():
28
+ line = line.replace("\n", "")
29
+ tag, zh = line.split("|!|!|")
30
+ zh_dict[tag] = zh
31
+ self.zh_dict = zh_dict
32
+
33
+ assert len(self.G.nodes) == len(self.nodes_id.keys())
34
+
35
+
36
+ def get_top_weighted_neighbors(self, node_str, n=20):
37
+ rp.debug('map: query:', node_str)
38
+ node = self.nodes_id[node_str]
39
+ if node not in self.G:
40
+ raise ValueError(f"Node {node} is not in the graph")
41
+
42
+ if not nx.get_edge_attributes(self.G, 'weight'):
43
+ raise nx.NetworkXError("Edges do not have a 'weight' attribute")
44
+
45
+ # 创建一个小顶堆来保持前n个权重最大的邻居
46
+ min_heap = []
47
+ for nbr in self.G.neighbors(node):
48
+ edge_weight = self.G[node][nbr]['weight']
49
+ nbr_weight = self.G.nodes[nbr]['weight']
50
+ combined_weight = edge_weight / nbr_weight
51
+ if len(min_heap) < n:
52
+ heapq.heappush(min_heap, (combined_weight, nbr))
53
+ else:
54
+ heapq.heappushpop(min_heap, (combined_weight, nbr))
55
+
56
+ top_neighbors_with_weights = sorted(min_heap, key=lambda x: x[0], reverse=True)
57
+
58
+ # 仅返回邻居节点的标识
59
+ return [self.nodes_id.inverse[nbr] for _, nbr in top_neighbors_with_weights]
60
+
61
+
62
+
63
+
64
+ if __name__ == '__main__':
65
+ tagger = Tagger()
66
+ print(tagger.get_top_weighted_neighbors("doll"))