File size: 2,590 Bytes
55b1220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# + tags=["hide_inp"]

# Markdown description of this demo; passed as `description=` to `show`
# at the bottom of the file.
# NOTE(review): the Colab badge links to examples/bash.ipynb rather than an
# agent notebook -- looks like a copy-paste from the bash example; confirm
# the intended target.
desc = """
### Agent

Chain that executes different tools based on model decisions. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb)

(Adapted from LangChain )
"""
# -

# $

from minichain import Id, prompt, OpenAI, show, transform, Mock, Break
from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool


# Tools the agent may call.
# - `agent` shows each tool's class name and `description` to the model.
# - `tool_use` finds the first tool whose class name equals the model's
#   choice and passes that tool's index in this list as `tool_num`
#   (presumably selecting the matching backend of `@prompt(tools)` -- confirm).
# Keep class names unique and the order stable.
tools = [StableDiffusionTool(), ImageCaptioningTool(), ImageToMusicTool()]


@prompt(OpenAI(stop=["Observation:"]), template_file="agent.pmpt.tpl")
def agent(model, query, history):
    """Ask the LLM which tool, if any, to use next.

    The prompt template ``agent.pmpt.tpl`` is rendered with:

    * ``tools``: a list of ``(class name, description)`` pairs, one per
      entry of the module-level ``tools`` list;
    * ``input``: the user's ``query``;
    * ``agent_scratchpad``: the accumulated ``history`` of earlier steps.

    Generation stops at ``"Observation:"`` (presumably so the observation
    text comes from the real tool rather than the model).
    """
    tool_listing = [(candidate.__class__.__name__, candidate.description)
                    for candidate in tools]
    template_vars = {
        "tools": tool_listing,
        "input": query,
        "agent_scratchpad": history,
    }
    return model(template_vars)


@transform()
def tool_parse(out):
    """Parse the agent's reply into a ``(tool, command)`` pair.

    The reply is split into lines:

    * line 0: the text after its last ``?`` (or the whole line if it has
      none), stripped, must equal ``"Yes"`` for a tool to be used;
    * line 1: the tool name, taken from after the first ``:``;
    * line 2: the tool input, taken from after the first ``:``.

    Returns:
        ``(tool, command)`` when the answer is ``"Yes"`` and both follow-up
        lines are present; otherwise ``Break()`` (minichain's sentinel --
        presumably it short-circuits the dependent tool call; confirm).
    """
    lines = out.split("\n")
    wants_tool = lines[0].split("?")[-1].strip() == "Yes"
    # A reply that says "Yes" but is cut short (missing the tool-name or
    # tool-input line) used to raise IndexError on lines[1] or lines[2].
    # Treat it like the "no tool" case instead of crashing the chain.
    if not wants_tool or len(lines) < 3:
        return Break()
    tool = lines[1].split(":", 1)[-1].strip()
    command = lines[2].split(":", 1)[-1].strip()
    return tool, command

@prompt(tools)
def tool_use(model, usage):
    """Run the agent's chosen command on the matching tool.

    ``usage`` is a ``(selector, command)`` pair. ``selector`` is compared
    with the class name of each entry in ``tools``. The index of the first
    match is forwarded to ``model`` as ``tool_num``, together with
    ``command``. If no tool matches, the model is not called and the
    literal ``("",)`` is returned.

    NOTE(review): confirm minichain handles that tuple the same way as a
    model result downstream (``append`` concatenates the observation onto
    a string).
    """
    selector, command = usage
    tool_num = next(
        (idx for idx, candidate in enumerate(tools)
         if selector == candidate.__class__.__name__),
        None,
    )
    if tool_num is None:
        return ("",)
    return model(command, tool_num=tool_num)

@transform()
def append(history, new, observation):
    """Extend the scratchpad with the latest agent output and its observation.

    The result is ``history``, then a newline, then ``new``, then the label
    ``"Observation: "``, then ``observation``. No separator is inserted
    between ``new`` and the label, so ``new`` presumably already ends with
    a newline (it is the agent's reply, cut at the stop sequence).
    """
    separator = "\n"
    label = "Observation: "
    return history + separator + new + label + observation

def run(query):
    """Run the agent loop for three rounds and return the last observation.

    Each round:

    1. asks ``agent`` for a decision, given ``query`` and the accumulated
       ``history``;
    2. parses the decision with ``tool_parse``;
    3. executes it with ``tool_use``;
    4. appends the decision and the resulting observation to ``history``.
    """
    history = ""
    for _ in range(3):
        decision = agent(query, history)
        observation = tool_use(tool_parse(decision))
        history = append(history, decision, observation)
    return observation

# $

# Build the Gradio demo around `run`.
# `subprompts` repeats the (agent, tool_use) pair three times, presumably one
# pair per iteration of the loop in `run`; keep the two counts in sync.
gradio = show(
    run,
    subprompts=[agent, tool_use] * 3,
    examples=[
        "I would please like a photo of a dog riding a skateboard. "
        "Please caption this image and create a song for it.",
        "Use an image generator tool to draw a cat.",
        "Caption the image https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png from the internet",
    ],
    out_type="markdown",
    description=desc,
    show_advanced=False,
)

if __name__ == "__main__":
    gradio.queue().launch()