|
import gradio as gr |
|
|
|
import os |
|
import glob |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from rxnscribe import RxnScribe |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
# Location of the pretrained RxnScribe weights on the Hugging Face Hub.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"

# hf_hub_download fetches the checkpoint (reusing the local cache when it is
# already present) and returns the path of the local copy.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# The model is constructed once, at import time, on the CPU; every call to
# predict() below reuses this single instance.
device = torch.device("cpu")
model = RxnScribe(ckpt_path, device)
|
|
|
|
|
def get_markdown(reaction):
    """Render one predicted reaction as three Markdown table cells.

    Args:
        reaction: One element of the list returned by ``model.predict_image``:
            a mapping whose ``'reactants'``, ``'conditions'`` and ``'products'``
            values are lists of entity dicts. A missing role renders as an
            empty cell.

    Returns:
        ``[reactants_md, conditions_md, products_md]``. Each entity is rendered
        as a fenced code block holding its ``'smiles'`` if present, otherwise
        as its ``'text'`` pieces joined by spaces, otherwise as its
        ``'category'``. Text and category entries are each followed by
        ``<br>`` so consecutive entities stay on separate lines.
    """
    cells = []
    for role in ('reactants', 'conditions', 'products'):
        parts = []
        for ent in reaction.get(role, []):
            if 'smiles' in ent:
                # A fenced block keeps SMILES characters such as '#', '[' or
                # '*' from being interpreted as Markdown syntax.
                parts.append("\n```\n" + ent['smiles'] + "\n```\n")
            elif 'text' in ent:
                parts.append(' '.join(ent['text']) + '<br>')
            else:
                # Terminate with '<br>' like text entries; without it, several
                # category-only entities ran together into one string.
                parts.append(ent['category'] + '<br>')
        # join() avoids repeated string concatenation inside the loop.
        cells.append(''.join(parts))
    return cells
|
|
|
|
|
def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on one uploaded diagram and format the results for the UI.

    Args:
        image: The diagram from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits before uploading anything.
        molscribe: Forwarded to ``model.predict_image``. Per the UI checkbox,
            it enables MolScribe molecule-structure recognition.
        ocr: Forwarded to ``model.predict_image``. Per the UI checkbox, it
            enables OCR text recognition.

    The defaults let ``predict`` be called with only the image, as the
    ``gr.Examples`` wiring in this app does (``inputs=[image]``).

    Returns:
        ``(pred_image, rows)``:
        - ``pred_image`` is the output of ``model.draw_predictions_combined``.
        - ``rows`` has one table row per predicted reaction:
          ``[index, reactants_md, conditions_md, products_md]``.
        Both are ``None`` when no image was provided, which clears the outputs.
    """
    # Submitting without an upload delivers None; return empty outputs rather
    # than passing a missing image into the model.
    if image is None:
        return None, None
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    rows = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
    return pred_image, rows
|
|
|
|
|
with gr.Blocks() as demo:
    # Header / usage notes shown above the interface. Reactions are drawn
    # with draw_predictions_combined into ONE output image, so the text says
    # "a single image".
    gr.Markdown("""
<center> <h1>RxnScribe</h1> </center>

Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.

The predicted reactions are visualized together in a single image.
<b style="color:red">Red boxes are <i><u style="color:red">reactants</u></i>.</b>
<b style="color:green">Green boxes are <i><u style="color:green">reaction conditions</u></i>.</b>
<b style="color:blue">Blue boxes are <i><u style="color:blue">products</u></i>.</b>

It usually takes 5-10 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).

Paper: [RxnScribe: A Sequence Generation Model for Reaction Diagram Parsing](https://pubs.acs.org/doi/10.1021/acs.jcim.3c00439)

Code: [https://github.com/thomas0809/RxnScribe](https://github.com/thomas0809/RxnScribe)

Authors: [Yujie Qian](mailto:yujieq@csail.mit.edu), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")
    with gr.Column():
        with gr.Row():
            # Delivered to predict() as a PIL image (type='pil').
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            # Optional, slower recognition passes forwarded to predict_image.
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
        btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        # One row per reaction: index plus Markdown cells from get_markdown().
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )

    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

    # Clickable sample diagrams; only the image input is populated from them.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )

# Only start the server when run as a script, so importing this module
# (e.g. to reuse `demo` or `predict`) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|