Spaces:
Runtime error
Runtime error
Swapped to bokeh, changed up samples
Browse files- .gitignore +1 -0
- app.py +2 -3
- lib/graph_extract.py +2 -0
- lib/samples.py +4 -4
- lib/visualize.py +84 -2
- requirements.txt +2 -1
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
app.py
CHANGED
@@ -5,12 +5,11 @@ import gradio as gr
|
|
5 |
import spaces
|
6 |
|
7 |
from lib.graph_extract import triplextract, parse_triples
|
8 |
-
from lib.visualize import
|
9 |
from lib.samples import snippets
|
10 |
|
11 |
WORD_LIMIT = 300
|
12 |
|
13 |
-
@spaces.GPU
|
14 |
def process_text(text, entity_types, predicates):
|
15 |
if not text:
|
16 |
return None, "Please enter some text."
|
@@ -40,7 +39,7 @@ def process_text(text, entity_types, predicates):
|
|
40 |
"No entities or relationships found. Try different text or check your input.",
|
41 |
)
|
42 |
|
43 |
-
fig =
|
44 |
return (
|
45 |
fig,
|
46 |
f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
|
|
|
5 |
import spaces
|
6 |
|
7 |
from lib.graph_extract import triplextract, parse_triples
|
8 |
+
from lib.visualize import create_bokeh_plot #, create_plotly_plot
|
9 |
from lib.samples import snippets
|
10 |
|
11 |
WORD_LIMIT = 300
|
12 |
|
|
|
13 |
def process_text(text, entity_types, predicates):
|
14 |
if not text:
|
15 |
return None, "Please enter some text."
|
|
|
39 |
"No entities or relationships found. Try different text or check your input.",
|
40 |
)
|
41 |
|
42 |
+
fig = create_bokeh_plot(entities, relationships)
|
43 |
return (
|
44 |
fig,
|
45 |
f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
|
lib/graph_extract.py
CHANGED
@@ -54,6 +54,8 @@ print("Model and tokenizer loaded successfully.")
|
|
54 |
generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
|
55 |
generation_config.max_length = 2048
|
56 |
generation_config.pad_token_id = tokenizer.eos_token_id
|
|
|
|
|
57 |
@spaces.GPU
|
58 |
def triplextract(text, entity_types, predicates):
|
59 |
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
|
|
|
54 |
generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
|
55 |
generation_config.max_length = 2048
|
56 |
generation_config.pad_token_id = tokenizer.eos_token_id
|
57 |
+
|
58 |
+
|
59 |
@spaces.GPU
|
60 |
def triplextract(text, entity_types, predicates):
|
61 |
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
|
lib/samples.py
CHANGED
@@ -19,8 +19,8 @@ snippets = {
|
|
19 |
we were all going direct to Heaven, we were all going direct the other way β in short,
|
20 |
the period was so far like the present period, that some of its noisiest authorities
|
21 |
insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
|
22 |
-
entity_types="
|
23 |
-
predicates="WAS, HAD, WERE"
|
24 |
),
|
25 |
|
26 |
'tech_company': Snippet(
|
@@ -29,8 +29,8 @@ snippets = {
|
|
29 |
software, and online services. The company's flagship products include the iPhone smartphone,
|
30 |
iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
|
31 |
and generates annual revenue exceeding $350 billion.""",
|
32 |
-
entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE,
|
33 |
-
predicates="FOUNDED,
|
34 |
),
|
35 |
|
36 |
'climate_change': Snippet(
|
|
|
19 |
we were all going direct to Heaven, we were all going direct the other way β in short,
|
20 |
the period was so far like the present period, that some of its noisiest authorities
|
21 |
insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
|
22 |
+
entity_types="EMOTION, EVENT, OUTCOME, PLACE",
|
23 |
+
predicates="WAS, HAD, WERE, IS"
|
24 |
),
|
25 |
|
26 |
'tech_company': Snippet(
|
|
|
29 |
software, and online services. The company's flagship products include the iPhone smartphone,
|
30 |
iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
|
31 |
and generates annual revenue exceeding $350 billion.""",
|
32 |
+
entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE, EVENT",
|
33 |
+
predicates="FOUNDED, PRODUCES, HAS, IN, EMPLOYEES, "
|
34 |
),
|
35 |
|
36 |
'climate_change': Snippet(
|
lib/visualize.py
CHANGED
@@ -1,10 +1,92 @@
|
|
1 |
import plotly.graph_objects as go
|
2 |
import networkx as nx
|
3 |
|
4 |
-
import plotly.graph_objects as go
|
5 |
import networkx as nx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
def
|
8 |
G = nx.DiGraph() # Use DiGraph for directed edges
|
9 |
|
10 |
for entity_id, entity_data in entities.items():
|
|
|
1 |
import plotly.graph_objects as go
|
2 |
import networkx as nx
|
3 |
|
|
|
4 |
import networkx as nx
|
5 |
+
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
|
6 |
+
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
|
7 |
+
from bokeh.palettes import Spectral4
|
8 |
+
from bokeh.plotting import from_networkx
|
9 |
+
|
10 |
+
def create_bokeh_plot(entities, relationships):
|
11 |
+
# Create a NetworkX graph
|
12 |
+
G = nx.Graph()
|
13 |
+
for entity_id, entity_data in entities.items():
|
14 |
+
G.add_node(entity_id, label=f"{entity_data['value']} ({entity_data['type']})")
|
15 |
+
for source, relation, target in relationships:
|
16 |
+
G.add_edge(source, target, label=relation)
|
17 |
+
|
18 |
+
plot = Plot(width=600, height=600, # Increased size for better visibility
|
19 |
+
x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
|
20 |
+
plot.title.text = "Knowledge Graph Interaction"
|
21 |
+
|
22 |
+
# Use tooltips to show node and edge labels on hover
|
23 |
+
node_hover = HoverTool(tooltips=[("Entity", "@label")])
|
24 |
+
edge_hover = HoverTool(tooltips=[("Relation", "@label")])
|
25 |
+
plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
|
26 |
+
|
27 |
+
graph_renderer = from_networkx(G, nx.spring_layout, scale=1,k=0.5, iterations=50, center=(0, 0))
|
28 |
+
|
29 |
+
graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
|
30 |
+
graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
|
31 |
+
graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])
|
32 |
+
|
33 |
+
graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
|
34 |
+
graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
|
35 |
+
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)
|
36 |
+
|
37 |
+
graph_renderer.selection_policy = NodesAndLinkedEdges()
|
38 |
+
graph_renderer.inspection_policy = NodesAndLinkedEdges()
|
39 |
+
|
40 |
+
plot.renderers.append(graph_renderer)
|
41 |
+
|
42 |
+
# Add node labels
|
43 |
+
x, y = zip(*graph_renderer.layout_provider.graph_layout.values())
|
44 |
+
node_labels = nx.get_node_attributes(G, 'label')
|
45 |
+
source = ColumnDataSource({'x': x, 'y': y, 'label': [node_labels[node] for node in G.nodes()]})
|
46 |
+
labels = LabelSet(x='x', y='y', text='label', source=source, background_fill_color='white',
|
47 |
+
text_font_size='8pt', background_fill_alpha=0.7)
|
48 |
+
plot.renderers.append(labels)
|
49 |
+
|
50 |
+
# Add edge labels
|
51 |
+
edge_x = []
|
52 |
+
edge_y = []
|
53 |
+
edge_labels = []
|
54 |
+
for (start_node, end_node, label) in G.edges(data='label'):
|
55 |
+
start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
|
56 |
+
end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
|
57 |
+
edge_x.append((start_x + end_x) / 2)
|
58 |
+
edge_y.append((start_y + end_y) / 2)
|
59 |
+
edge_labels.append(label)
|
60 |
+
|
61 |
+
edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_labels})
|
62 |
+
edge_labels = LabelSet(x='x', y='y', text='label', source=edge_label_source,
|
63 |
+
background_fill_color='white', text_font_size='8pt',
|
64 |
+
background_fill_alpha=0.7)
|
65 |
+
plot.renderers.append(edge_labels)
|
66 |
+
|
67 |
+
return plot
|
68 |
+
|
69 |
+
# def create_bokeh_plot(entities, relationships):
|
70 |
+
# # Create a NetworkX graph
|
71 |
+
# G = nx.Graph()
|
72 |
+
# for entity_id, entity_data in entities.items():
|
73 |
+
# G.add_node(entity_id, **entity_data)
|
74 |
+
# for source, relation, target in relationships:
|
75 |
+
# G.add_edge(source, target)
|
76 |
+
|
77 |
+
# # Create a Bokeh plot
|
78 |
+
# plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1),
|
79 |
+
# width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset")
|
80 |
+
|
81 |
+
# # Create graph renderer
|
82 |
+
# graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0))
|
83 |
+
|
84 |
+
# # Add graph renderer to plot
|
85 |
+
# plot.renderers.append(graph_renderer)
|
86 |
+
|
87 |
+
# return plot
|
88 |
|
89 |
+
def create_plotly_plot(entities, relationships):
|
90 |
G = nx.DiGraph() # Use DiGraph for directed edges
|
91 |
|
92 |
for entity_id, entity_data in entities.items():
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ matplotlib==3.7.2
|
|
4 |
torch==2.0.1
|
5 |
transformers==4.43.3
|
6 |
accelerate==0.33.0
|
7 |
-
networkx
|
|
|
|
4 |
torch==2.0.1
|
5 |
transformers==4.43.3
|
6 |
accelerate==0.33.0
|
7 |
+
networkx
|
8 |
+
bokeh==3.5.1
|