Remsky commited on
Commit
b26b502
β€’
1 Parent(s): 4289090

Swapped to bokeh, changed up samples

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. app.py +2 -3
  3. lib/graph_extract.py +2 -0
  4. lib/samples.py +4 -4
  5. lib/visualize.py +84 -2
  6. requirements.txt +2 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -5,12 +5,11 @@ import gradio as gr
5
  import spaces
6
 
7
  from lib.graph_extract import triplextract, parse_triples
8
- from lib.visualize import create_cytoscape_plot
9
  from lib.samples import snippets
10
 
11
  WORD_LIMIT = 300
12
 
13
- @spaces.GPU
14
  def process_text(text, entity_types, predicates):
15
  if not text:
16
  return None, "Please enter some text."
@@ -40,7 +39,7 @@ def process_text(text, entity_types, predicates):
40
  "No entities or relationships found. Try different text or check your input.",
41
  )
42
 
43
- fig = create_cytoscape_plot(entities, relationships)
44
  return (
45
  fig,
46
  f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
 
5
  import spaces
6
 
7
  from lib.graph_extract import triplextract, parse_triples
8
+ from lib.visualize import create_bokeh_plot #, create_plotly_plot
9
  from lib.samples import snippets
10
 
11
  WORD_LIMIT = 300
12
 
 
13
  def process_text(text, entity_types, predicates):
14
  if not text:
15
  return None, "Please enter some text."
 
39
  "No entities or relationships found. Try different text or check your input.",
40
  )
41
 
42
+ fig = create_bokeh_plot(entities, relationships)
43
  return (
44
  fig,
45
  f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
lib/graph_extract.py CHANGED
@@ -54,6 +54,8 @@ print("Model and tokenizer loaded successfully.")
54
  generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
55
  generation_config.max_length = 2048
56
  generation_config.pad_token_id = tokenizer.eos_token_id
 
 
57
  @spaces.GPU
58
  def triplextract(text, entity_types, predicates):
59
  input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
 
54
  generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
55
  generation_config.max_length = 2048
56
  generation_config.pad_token_id = tokenizer.eos_token_id
57
+
58
+
59
  @spaces.GPU
60
  def triplextract(text, entity_types, predicates):
61
  input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
lib/samples.py CHANGED
@@ -19,8 +19,8 @@ snippets = {
19
  we were all going direct to Heaven, we were all going direct the other way – in short,
20
  the period was so far like the present period, that some of its noisiest authorities
21
  insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
22
- entity_types="TIME, EMOTION, LOCATION, EVENT, OUTCOME, PLACE",
23
- predicates="WAS, HAD, WERE"
24
  ),
25
 
26
  'tech_company': Snippet(
@@ -29,8 +29,8 @@ snippets = {
29
  software, and online services. The company's flagship products include the iPhone smartphone,
30
  iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
31
  and generates annual revenue exceeding $350 billion.""",
32
- entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE, NUMBER, EVENT, SUBJECT",
33
- predicates="FOUNDED, HEADQUARTERED_IN, PRODUCES, HAS, EMPLOYEES, "
34
  ),
35
 
36
  'climate_change': Snippet(
 
19
  we were all going direct to Heaven, we were all going direct the other way – in short,
20
  the period was so far like the present period, that some of its noisiest authorities
21
  insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
22
+ entity_types="EMOTION, EVENT, OUTCOME, PLACE",
23
+ predicates="WAS, HAD, WERE, IS"
24
  ),
25
 
26
  'tech_company': Snippet(
 
29
  software, and online services. The company's flagship products include the iPhone smartphone,
30
  iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
31
  and generates annual revenue exceeding $350 billion.""",
32
+ entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE, EVENT",
33
+ predicates="FOUNDED, PRODUCES, HAS, IN, EMPLOYEES, "
34
  ),
35
 
36
  'climate_change': Snippet(
lib/visualize.py CHANGED
@@ -1,10 +1,92 @@
1
  import plotly.graph_objects as go
2
  import networkx as nx
3
 
4
- import plotly.graph_objects as go
5
  import networkx as nx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def create_cytoscape_plot(entities, relationships):
8
  G = nx.DiGraph() # Use DiGraph for directed edges
9
 
10
  for entity_id, entity_data in entities.items():
 
1
  import plotly.graph_objects as go
2
  import networkx as nx
3
 
 
4
  import networkx as nx
5
+ from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
6
+ Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
7
+ from bokeh.palettes import Spectral4
8
+ from bokeh.plotting import from_networkx
9
+
10
+ def create_bokeh_plot(entities, relationships):
11
+ # Create a NetworkX graph
12
+ G = nx.Graph()
13
+ for entity_id, entity_data in entities.items():
14
+ G.add_node(entity_id, label=f"{entity_data['value']} ({entity_data['type']})")
15
+ for source, relation, target in relationships:
16
+ G.add_edge(source, target, label=relation)
17
+
18
+ plot = Plot(width=600, height=600, # Increased size for better visibility
19
+ x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
20
+ plot.title.text = "Knowledge Graph Interaction"
21
+
22
+ # Use tooltips to show node and edge labels on hover
23
+ node_hover = HoverTool(tooltips=[("Entity", "@label")])
24
+ edge_hover = HoverTool(tooltips=[("Relation", "@label")])
25
+ plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
26
+
27
+ graph_renderer = from_networkx(G, nx.spring_layout, scale=1,k=0.5, iterations=50, center=(0, 0))
28
+
29
+ graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
30
+ graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
31
+ graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])
32
+
33
+ graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
34
+ graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
35
+ graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)
36
+
37
+ graph_renderer.selection_policy = NodesAndLinkedEdges()
38
+ graph_renderer.inspection_policy = NodesAndLinkedEdges()
39
+
40
+ plot.renderers.append(graph_renderer)
41
+
42
+ # Add node labels
43
+ x, y = zip(*graph_renderer.layout_provider.graph_layout.values())
44
+ node_labels = nx.get_node_attributes(G, 'label')
45
+ source = ColumnDataSource({'x': x, 'y': y, 'label': [node_labels[node] for node in G.nodes()]})
46
+ labels = LabelSet(x='x', y='y', text='label', source=source, background_fill_color='white',
47
+ text_font_size='8pt', background_fill_alpha=0.7)
48
+ plot.renderers.append(labels)
49
+
50
+ # Add edge labels
51
+ edge_x = []
52
+ edge_y = []
53
+ edge_labels = []
54
+ for (start_node, end_node, label) in G.edges(data='label'):
55
+ start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
56
+ end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
57
+ edge_x.append((start_x + end_x) / 2)
58
+ edge_y.append((start_y + end_y) / 2)
59
+ edge_labels.append(label)
60
+
61
+ edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_labels})
62
+ edge_labels = LabelSet(x='x', y='y', text='label', source=edge_label_source,
63
+ background_fill_color='white', text_font_size='8pt',
64
+ background_fill_alpha=0.7)
65
+ plot.renderers.append(edge_labels)
66
+
67
+ return plot
68
+
69
+ # def create_bokeh_plot(entities, relationships):
70
+ # # Create a NetworkX graph
71
+ # G = nx.Graph()
72
+ # for entity_id, entity_data in entities.items():
73
+ # G.add_node(entity_id, **entity_data)
74
+ # for source, relation, target in relationships:
75
+ # G.add_edge(source, target)
76
+
77
+ # # Create a Bokeh plot
78
+ # plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1),
79
+ # width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset")
80
+
81
+ # # Create graph renderer
82
+ # graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0))
83
+
84
+ # # Add graph renderer to plot
85
+ # plot.renderers.append(graph_renderer)
86
+
87
+ # return plot
88
 
89
+ def create_plotly_plot(entities, relationships):
90
  G = nx.DiGraph() # Use DiGraph for directed edges
91
 
92
  for entity_id, entity_data in entities.items():
requirements.txt CHANGED
@@ -4,4 +4,5 @@ matplotlib==3.7.2
4
  torch==2.0.1
5
  transformers==4.43.3
6
  accelerate==0.33.0
7
- networkx
 
 
4
  torch==2.0.1
5
  transformers==4.43.3
6
  accelerate==0.33.0
7
+ networkx
8
+ bokeh==3.5.1