File size: 3,256 Bytes
bd8cfdf a677d37 bd8cfdf c157f5a bd8cfdf c157f5a bd8cfdf c157f5a 1d479e0 c157f5a a677d37 c157f5a 81860c4 c157f5a bd8cfdf a677d37 bd8cfdf c157f5a bd8cfdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download
# Hugging Face Hub coordinates of the pretrained RxnScribe checkpoint.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"

# Fetch the checkpoint through the Hub client; the returned value is handed
# to RxnScribe as the checkpoint path.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# The demo runs inference on CPU; the model is built once at import time and
# shared by every request.
device = torch.device('cpu')
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
    """Render one predicted reaction as three Markdown cell strings.

    Args:
        reaction: Mapping with the keys ``'reactants'``, ``'conditions'`` and
            ``'products'``. Each value is an iterable of entity dicts.

    Returns:
        A list of three strings, ordered reactants, conditions, products.
        Every entity of a role contributes one fragment, concatenated in
        order:

        * entity has ``'smiles'``: the SMILES string in a fenced code block;
        * otherwise entity has ``'text'``: its text pieces joined by single
          spaces, followed by ``<br>``;
        * otherwise: the entity's ``'category'`` value as-is.

    Raises:
        KeyError: If a role key is missing from ``reaction``, or an entity
            has none of ``'smiles'``, ``'text'`` or ``'category'``.
    """
    output = []
    for role in ('reactants', 'conditions', 'products'):
        fragments = []
        for ent in reaction[role]:
            # 'smiles' takes precedence over 'text' when both are present.
            if 'smiles' in ent:
                fragments.append("\n```\n" + ent['smiles'] + "\n```\n")
            elif 'text' in ent:
                fragments.append(' '.join(ent['text']) + '<br>')
            else:
                fragments.append(ent['category'])
        output.append(''.join(fragments))
    return output
def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on a diagram and build the outputs shown in the UI.

    Args:
        image: The reaction diagram. The upload widget in this file is
            declared with ``type='pil'``, so the UI passes a PIL image.
        molscribe: Forwarded to ``model.predict_image`` as ``molscribe=``.
        ocr: Forwarded to ``model.predict_image`` as ``ocr=``.

    Both flags default to ``False`` so the function can be called with only
    an image, which is how the examples widget in this file registers it
    (``inputs=[image]``).

    Returns:
        A ``(pred_image, rows)`` tuple: the rendering produced by
        ``model.draw_predictions_combined`` and one table row per predicted
        reaction, ``[index, reactants_md, conditions_md, products_md]``.
    """
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    # Prefix each reaction's three Markdown cells with its 0-based index.
    rows = [[idx, *get_markdown(reaction)] for idx, reaction in enumerate(predictions)]
    return pred_image, rows
# Gradio UI: intro text, upload + options, submit button, then the annotated
# image and a per-reaction table.
# NOTE(review): the nesting of rows/columns below is reconstructed; confirm
# it against the intended layout.
# NOTE(review): `.style(...)` is the Gradio 3.x styling API; confirm the
# pinned Gradio version before upgrading.
with gr.Blocks() as demo:
    # The intro text is kept at column 0 so its rendering does not depend on
    # the Markdown component dedenting it. Blank lines separate paragraphs;
    # the one after the <center> HTML block is required for the Markdown
    # links below it to be parsed.
    # NOTE(review): the first author's mailto link was unrecoverable (its
    # address had been replaced by an e-mail-obfuscation placeholder), so the
    # name is shown as plain text; restore the link once confirmed.
    gr.Markdown("""
<center> <h1>RxnScribe</h1> </center>

Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.

The predicted reactions are visualized in separate images.
<b style="color:red">Red boxes are <i><u style="color:red">reactants</u></i>.</b>
<b style="color:green">Green boxes are <i><u style="color:green">reaction conditions</u></i>.</b>
<b style="color:blue">Blue boxes are <i><u style="color:blue">products</u></i>.</b>

It usually takes 5-10 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).

Paper: [RxnScribe: A Sequence Generation Model for Reaction Diagram Parsing](https://pubs.acs.org/doi/10.1021/acs.jcim.3c00439)

Code: [https://github.com/thomas0809/RxnScribe](https://github.com/thomas0809/RxnScribe)

Authors: Yujie Qian, Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")

    with gr.Column():
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
        btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        # One row per predicted reaction: index plus three Markdown cells.
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )

    # Event wiring must happen inside the Blocks context.
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

    # Clickable sample diagrams; only the image input is populated.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )

demo.launch()
|