import logging
import random

import gradio as gr
import networkx as nx

from lib.graph_extract import triplextract, parse_triples
from lib.visualize import create_graph, create_bokeh_plot, create_plotly_plot
from lib.samples import snippets

logger = logging.getLogger(__name__)

# Maximum number of whitespace-separated words accepted by process_text.
WORD_LIMIT = 300

# Layout names offered in the UI; passed unchanged to the plotting helpers.
LAYOUT_TYPES = ['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell']


def _split_csv(value):
    """Split a comma-separated string into stripped, non-empty items.

    ``None`` is treated like an empty string so a cleared textbox yields ``[]``
    instead of raising ``AttributeError``.
    """
    return [item.strip() for item in (value or "").split(",") if item.strip()]


def _render_figure(G, layout_type, visualization_type):
    """Render graph ``G`` with Bokeh when ``visualization_type == 'Bokeh'``, otherwise Plotly."""
    if visualization_type == 'Bokeh':
        return create_bokeh_plot(G, layout_type)
    return create_plotly_plot(G, layout_type)


def process_text(text, entity_types, predicates, layout_type, visualization_type):
    """Extract triples from ``text``, build a graph, and render it.

    Args:
        text: Free text to analyse; at most ``WORD_LIMIT`` words.
        entity_types: Comma-separated entity type names.
        predicates: Comma-separated predicate names.
        layout_type: Layout name forwarded to the plotting helper.
        visualization_type: 'Bokeh' selects Bokeh; any other value selects Plotly.

    Returns:
        A ``(graph, figure, message)`` triple. On any validation failure or
        error, ``graph`` and ``figure`` are ``None`` and ``message`` explains why.
    """
    # Reject None, "" and whitespace-only input before calling the extractor.
    if not text or not text.strip():
        return None, None, "Please enter some text."

    words = text.split()
    if len(words) > WORD_LIMIT:
        return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"

    entity_types = _split_csv(entity_types)
    predicates = _split_csv(predicates)
    if not entity_types:
        return None, None, "Please enter at least one entity type."
    if not predicates:
        return None, None, "Please enter at least one predicate."

    try:
        prediction = triplextract(text, entity_types, predicates)
        # The extractor reports failures as a string starting with "Error".
        if prediction.startswith("Error"):
            return None, None, prediction

        entities, relationships = parse_triples(prediction)
        if not entities and not relationships:
            return None, None, "No entities or relationships found. Try different text or check your input."

        G = create_graph(entities, relationships)
        fig = _render_figure(G, layout_type, visualization_type)
        output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
        return G, fig, output_text
    except Exception as e:
        # UI boundary: report the failure to the user and keep the traceback in the log.
        logger.exception("Error in process_text")
        return None, None, f"An error occurred: {str(e)}"


def update_graph(G, layout_type, visualization_type):
    """Re-render an existing graph with a new layout or backend.

    Returns:
        ``(figure, message)``; ``message`` is ``""`` on success. On failure the
        figure is ``None`` and the message describes the problem.
    """
    if G is None:
        return None, "Please process text first."
    try:
        fig = _render_figure(G, layout_type, visualization_type)
    except Exception as e:
        logger.exception("Error in update_graph")
        return None, f"An error occurred while updating the graph: {str(e)}"
    return fig, ""


def update_inputs(sample_name):
    """Return ``(text, entity_types, predicates)`` for the chosen sample.

    If ``sample_name`` is not a known sample (e.g. the dropdown was cleared),
    the three inputs are left unchanged instead of raising ``KeyError``.
    """
    if sample_name not in snippets:
        return gr.update(), gr.update(), gr.update()
    sample = snippets[sample_name]
    return sample.text_input, sample.entity_types, sample.predicates


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Knowledge Graph Extractor")

    # A random sample pre-fills the inputs on each app start.
    default_sample_name = random.choice(list(snippets.keys()))
    default_sample = snippets[default_sample_name]

    with gr.Row():
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(choices=list(snippets.keys()), label="Select Sample", value=default_sample_name)
            input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input)
            entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
            predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
            layout_type = gr.Dropdown(choices=LAYOUT_TYPES, label="Layout Type", value='spring')
            visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
            process_btn = gr.Button("Process Text")
        with gr.Column(scale=2):
            output_graph = gr.Plot(label="Knowledge Graph")
            error_message = gr.Textbox(label="Textual Output")

    # Holds the last successfully built graph so layout/backend changes can re-render it.
    graph_state = gr.State(None)

    def update_graph_wrapper(G, layout_type, visualization_type):
        """Re-render the stored graph after a layout or backend change.

        Returns ``(figure, message_update)``. On success, and when no graph has
        been built yet, the textual output is left untouched so the extraction
        summary stays visible. On failure, the error text is shown there
        instead of the plot silently disappearing.
        """
        if G is None:
            return None, gr.update()
        fig, message = update_graph(G, layout_type, visualization_type)
        return fig, (message if message else gr.update())

    sample_dropdown.change(update_inputs, inputs=[sample_dropdown], outputs=[input_text, entity_types, predicates])
    process_btn.click(
        process_text,
        inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
        outputs=[graph_state, output_graph, error_message],
    )
    layout_type.change(
        update_graph_wrapper,
        inputs=[graph_state, layout_type, visualization_type],
        outputs=[output_graph, error_message],
    )
    visualization_type.change(
        update_graph_wrapper,
        inputs=[graph_state, layout_type, visualization_type],
        outputs=[output_graph, error_message],
    )


if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)