File size: 3,256 Bytes
bd8cfdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a677d37
bd8cfdf
 
 
 
 
 
 
 
 
 
c157f5a
bd8cfdf
c157f5a
bd8cfdf
 
 
c157f5a
 
 
 
 
 
1d479e0
 
 
c157f5a
a677d37
c157f5a
 
81860c4
 
 
c157f5a
 
 
bd8cfdf
 
 
 
 
 
 
 
a677d37
bd8cfdf
 
 
 
 
 
 
 
c157f5a
 
 
 
 
 
 
bd8cfdf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr

import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe

from huggingface_hub import hf_hub_download

# Hugging Face Hub repository and checkpoint file name of the RxnScribe weights.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"
# Fetched once at import time; the value returned by hf_hub_download is handed
# to RxnScribe as the checkpoint location.
ckpt_path = hf_hub_download(REPO_ID, FILENAME)

# Inference is pinned to the CPU; the single model instance below is shared by
# every call to predict().
device = torch.device('cpu')
model = RxnScribe(ckpt_path, device)


def get_markdown(reaction):
    """Render one predicted reaction as three markdown cell strings.

    Args:
        reaction: Mapping with the keys ``'reactants'``, ``'conditions'`` and
            ``'products'``; each value is an iterable of entity dicts.

    Returns:
        A list of three strings, in the order reactants, conditions,
        products. Each string is the concatenation of its entities'
        renderings (an empty string when the role has no entities).
    """

    def render(entity):
        # A recognised structure takes precedence: show its SMILES in a
        # fenced code block.
        if 'smiles' in entity:
            return "\n```\n" + entity['smiles'] + "\n```\n"
        # Otherwise show recognised text, space-joined and followed by an
        # HTML line break (assumes 'text' is a sequence of strings).
        if 'text' in entity:
            return ' '.join(entity['text']) + '<br>'
        # Nothing recognised: fall back to the entity's category label.
        return entity['category']

    return [
        ''.join(render(entity) for entity in reaction[role])
        for role in ('reactants', 'conditions', 'products')
    ]


def predict(image, molscribe, ocr):
    """Run the shared RxnScribe model on one diagram and build the UI outputs.

    Args:
        image: The uploaded reaction diagram, passed straight to the model.
        molscribe: Forwarded to ``model.predict_image`` as ``molscribe``.
        ocr: Forwarded to ``model.predict_image`` as ``ocr``.

    Returns:
        A ``(pred_image, rows)`` tuple: the result of
        ``model.draw_predictions_combined`` for the predictions, and one row
        per predicted reaction of the form
        ``[index, reactants_md, conditions_md, products_md]`` with a
        zero-based index.
    """
    reactions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    annotated = model.draw_predictions_combined(reactions, image=image)

    # One table row per reaction: its position followed by the three
    # markdown cells produced by get_markdown().
    rows = []
    for index, reaction in enumerate(reactions):
        rows.append([index, *get_markdown(reaction)])
    return annotated, rows


# Declarative UI definition. Components are laid out in the order they are
# created inside the nested context managers, so statement order matters here.
with gr.Blocks() as demo:
    # Introductory text shown above the controls.
    gr.Markdown("""
    <center> <h1>RxnScribe</h1> </center>

    Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram. 

    The predicted reactions are visualized in separate images. 
    <b style="color:red">Red boxes are <i><u style="color:red">reactants</u></i>.</b>
    <b style="color:green">Green boxes are <i><u style="color:green">reaction conditions</u></i>.</b>
    <b style="color:blue">Blue boxes are <i><u style="color:blue">products</u></i>.</b>

    It usually takes 5-10 seconds to process a diagram with this demo.
    Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).

    Paper: [RxnScribe: A Sequence Generation Model for Reaction Diagram Parsing](https://pubs.acs.org/doi/10.1021/acs.jcim.3c00439)
    
    Code: [https://github.com/thomas0809/RxnScribe](https://github.com/thomas0809/RxnScribe)

    Authors: [Yujie Qian](mailto:[email protected]), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
    """)
    with gr.Column():
        # Row 1: the input diagram, delivered to predict() as a PIL image.
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        # Row 2: optional post-processing switches and the submit button.
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
            btn = gr.Button("Submit").style(full_width=False)
        # Row 3: outputs -- the annotated image and a table of reactions.
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
            # Four columns matching the rows built by predict(): a numeric
            # index followed by three markdown-rendered cells.
            markdown = gr.Dataframe(
                headers=['#', 'reactant', 'condition', 'product'],
                datatype=['number'] + ['markdown'] * 3,
                wrap=False
            )

    # Submit runs predict(image, molscribe, ocr) and fills both outputs.
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

    # Every PNG under examples/ (sorted by path) is offered as a sample input.
    # Only the image is listed as an input for the examples.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )

# Runs at import time: this module is meant to be executed as the app script.
demo.launch()