File size: 5,360 Bytes
d0bd9ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import modules.scripts as scripts
import modules.prompt_parser as prompt_parser
import itertools
import torch


def hijacked_get_learned_conditioning(model, prompts, steps):
    """Replacement for prompt_parser.get_learned_conditioning adding prompt blending.

    On first use for a given model, wraps model.get_learned_conditioning so
    that each prompt is expanded into its weighted alternatives (via
    get_weighted_prompt) and the resulting conditioning tensors are blended
    back into one tensor per prompt as a weighted sum.  Then delegates to the
    real prompt_parser.get_learned_conditioning with each prompt's ':' weight
    marks rewritten to '@' (see switch_syntax).
    """
    global real_get_learned_conditioning

    # Wrap the model-level conditioning function only once per model;
    # the marker attribute records that the wrap has been installed.
    if not hasattr(model, '__hacked'):
        real_model_func = model.get_learned_conditioning

        def hijacked_model_func(texts):
            # Expand every prompt into its list of (text, weight) alternatives.
            weighted_prompts = list(map(lambda t: get_weighted_prompt((t, 1)), texts))
            all_texts = []
            for weighted_prompt in weighted_prompts:
                for (prompt, weight) in weighted_prompt:
                    all_texts.append(prompt)

            # More expanded texts than inputs means at least one prompt used
            # the blending syntax.
            if len(all_texts) > len(texts):
                # Condition all variants in one batch, then blend each
                # prompt's variants together by weighted sum.
                all_conds = real_model_func(all_texts)
                offset = 0

                conds = []

                for weighted_prompt in weighted_prompts:
                    c = torch.zeros_like(all_conds[offset])
                    for (i, (prompt, weight)) in enumerate(weighted_prompt):
                        # c += weight * cond  for this alternative
                        c = torch.add(c, all_conds[i+offset], alpha=weight)
                    conds.append(c)
                    offset += len(weighted_prompt)
                return conds
            else:
                # No blending syntax anywhere: pass through untouched.
                return real_model_func(texts)

        model.get_learned_conditioning = hijacked_model_func
        model.__hacked = True

    switched_prompts = list(map(lambda p: switch_syntax(p), prompts))
    return real_get_learned_conditioning(model, switched_prompts, steps)


real_get_learned_conditioning = hijacked_get_learned_conditioning  # no really, overriden below


class Script(scripts.Script):
    """Registers "Prompt Blending" with the webui and installs the hijack.

    The script exposes no UI and is never shown (show() returns False);
    show() is used purely as a hook to swap
    prompt_parser.get_learned_conditioning for the hijacked version once.
    """

    def title(self):
        # Name registered with the webui scripts system.
        return "Prompt Blending"

    def show(self, is_img2img):
        """Install the hijack on first call; always hide the script."""
        global real_get_learned_conditioning
        # Only swap once: after the first call the global holds the real
        # prompt_parser function, so this equality check fails thereafter.
        if real_get_learned_conditioning == hijacked_get_learned_conditioning:
            real_get_learned_conditioning = prompt_parser.get_learned_conditioning
            prompt_parser.get_learned_conditioning = hijacked_get_learned_conditioning
        return False

    def ui(self, is_img2img):
        # No UI controls.
        return []

    def run(self, p, seeds):
        # Nothing to run; all work happens through the hijack above.
        return


# Blending syntax tokens: '{a|b@2}' blends alternatives a and b, '@' marking
# an alternative's weight.  Users write weights with ':' (REAL_MARK);
# switch_syntax rewrites those to MARK inside '{...}' groups before parsing.
OPEN = '{'
CLOSE = '}'
SEPARATE = '|'
MARK = '@'
REAL_MARK = ':'


def combine(left, right):
    """Cross two collections of (text, weight) pairs.

    Produces one pair per element of the Cartesian product of *left* and
    *right*: the texts concatenated, the weights multiplied.  Returns a lazy
    iterable (callers wrap it in list()).
    """
    return (
        (left_text + right_text, left_weight * right_weight)
        for (left_text, left_weight), (right_text, right_weight) in itertools.product(left, right)
    )


def get_weighted_prompt(prompt_weight):
    """Expand one weighted prompt into all of its blending combinations.

    Parses '{alt1|alt2@weight|...}' groups (recursively for nested groups)
    in a single left-to-right pass, building the Cartesian product of the
    alternatives of every top-level group.

    Args:
        prompt_weight: a (prompt, full_weight) tuple; prompt is the text
            (with weight marks already switched to '@') and full_weight is
            the total weight the expansion should carry.

    Returns:
        A list of (text, weight) tuples whose weights are normalized to sum
        to full_weight.
    """
    (prompt, full_weight) = prompt_weight
    results = [('', full_weight)]  # combinations built so far
    alts = []        # (text, weight) alternatives of the current top-level group
    start = 0        # index just past the last consumed character
    mark = -1        # index of the '@' in the current alternative, -1 if none
    open_count = 0   # current '{' nesting depth
    first_open = 0   # index of the '{' opening the current top-level group
    nested = False   # current alternative contains a nested '{...}' group

    for i, c in enumerate(prompt):
        add_alt = False
        do_combine = False
        if c == OPEN:
            open_count += 1
            if open_count == 1:
                first_open = i
                # Flush the literal text preceding the group into every result.
                results = list(combine(results, [(prompt[start:i], 1)]))
                start = i + 1
            else:
                nested = True

        # '@' and '|' are only significant at the top nesting level; nested
        # groups are handled by the recursive call below.
        if c == MARK and open_count == 1:
            mark = i

        if c == SEPARATE and open_count == 1:
            add_alt = True

        if c == CLOSE:
            open_count -= 1
            if open_count == 0:
                add_alt = True
                do_combine = True
        # Unterminated group: close it implicitly at the end of the prompt.
        # NOTE(review): end is later set to i, so the final character of an
        # unclosed alternative is excluded — confirm this edge case is intended.
        if i == len(prompt) - 1 and open_count > 0:
            add_alt = True
            do_combine = True

        if add_alt:
            end = i
            weight = 1
            if mark != -1:
                # Text between the '@' mark and here is the alternative's weight.
                weight_str = prompt[mark + 1:i]
                try:
                    weight = float(weight_str)
                    end = mark
                except ValueError:
                    # Unparseable weight: keep weight 1 and keep the '@' text.
                    print("warning, not a number:", weight_str)



            alt = (prompt[start:end], weight)
            # A nested group expands recursively into weighted sub-prompts.
            alts += get_weighted_prompt(alt) if nested else [alt]
            nested = False
            mark = -1
            start = i + 1

        if do_combine:
            if len(alts) <= 1:
                # A group with a single alternative (e.g. '{cool}') is treated
                # as literal text, braces included, with weight 1.
                alts = [(prompt[first_open:i + 1], 1)]

            results = list(combine(results, alts))
            alts = []

    # rest of the prompt
    results = list(combine(results, [(prompt[start:], 1)]))
    # Normalize so the combination weights sum to full_weight.
    weight_sum = sum(map(lambda r: r[1], results))
    results = list(map(lambda p: (p[0], p[1] / weight_sum * full_weight), results))

    return results


def switch_syntax(prompt):
    """Rewrite ':' weight marks inside '{...}' groups to internal '@' marks.

    The webui's attention syntaxes ('[a:b:0.5]', '(word:1.2)') also use ':',
    so a mark is only rewritten when the innermost open bracket is '{'.
    The characters here mirror the module constants REAL_MARK (':') and
    MARK ('@').

    Args:
        prompt: raw prompt string as typed by the user.

    Returns:
        The prompt with every ':' directly inside a '{...}' group replaced
        by '@'; all other characters unchanged.

    Bug fix: the original read stack[-1] without checking the stack was
    non-empty, so any ':' outside all brackets (e.g. "photo: a test")
    raised IndexError; the `stack and` guard below prevents that.
    """
    chars = list(prompt)
    stack = []  # currently open brackets, innermost last
    for i, c in enumerate(chars):
        if c in ('{', '[', '('):
            stack.append(c)
        elif c in ('}', ']', ')'):
            # Tolerate unbalanced closers: only pop when something is open.
            if stack:
                stack.pop()
        elif c == ':' and stack and stack[-1] == '{':
            chars[i] = '@'

    return "".join(chars)

# def test(p, w=1):
#     print('')
#     print(p)
#     result = get_weighted_prompt((p, w))
#     print(result)
#     print(sum(map(lambda x: x[1], result)))
#
#
# test("fantasy landscape")
# test("fantasy {landscape|city}, dark")
# test("fantasy {landscape|city}, {fire|ice} ")
# test("fantasy {landscape|city}, {fire|ice}, {dark|light} ")
# test("fantasy landscape, {{fire|lava}|ice}")
# test("fantasy landscape, {{fire@4|lava@1}|ice@2}")
# test("fantasy landscape, {{fire@error|lava@1}|ice@2}")
# test("fantasy landscape, {{fire|lava}|ice@2")
# test("fantasy landscape, {fire|lava} {cool} {ice,water}")
# test("fantasy landscape, {fire|lava} {cool} {ice,water")
# test("{lava|ice|water@5}")
# test("{fire@4|lava@1}", 5)
# test("{{fire@4|lava@1}|ice@2|water@5}")
# test("{fire|[email protected]}")