import openai
import backoff
import json
import re

def initOpenAI(key):
  openai.api_key = key

  # list models
  models = openai.Model.list()

  return models  
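
# Illustrative usage (not part of the original flow): the API key would normally come
# from an environment variable or a UI field rather than being hard-coded, e.g.
#   import os
#   models = initOpenAI(os.getenv("OPENAI_API_KEY"))
#   print([m["id"] for m in models["data"]])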

# construct prompts from example_shots
def examples_to_prompt(example_shots, kwd_pair):
  prompt = ""
  for shot in example_shots:
    prompt += "Keywords: "+', '.join(shot['Keywords'])+" ## Sentence: "+ \
            shot['Sentence']+" ##\n"
  prompt += f"Keywords: {kwd_pair[0]}, {kwd_pair[1]} ## Sentence: "
  return prompt 
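
# For illustration, with the fixed_shots defined in generateTestSentences below and a
# hypothetical kwd_pair = ["doctor", "caring"], this would produce a few-shot prompt like:
#   Keywords: dog, frisbee, catch, throw ## Sentence: A dog leaps to catch a thrown frisbee ##
#   ...
#   Keywords: doctor, caring ## Sentence: 
# (note: this prompt path is currently commented out inside genChatGPT)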

def genChatGPT(model_name, kwd_pair, num2gen, numTries, example_shots, temperature=0.8):  
  # construct prompt
  instruction = f"Write a sentence including terms \"{kwd_pair[0]}\" and \"{kwd_pair[1]}\"."# Use examples as guide for the type of sentences to write."
  #prompt = examples_to_prompt(example_shots, kwd_pair)
  #print(f"Prompt: {prompt}")
  #print(f"Instruction: {instruction}")

  # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb
  @backoff.on_exception(backoff.expo, (openai.error.RateLimitError, 
                                       openai.error.APIError,
                                       openai.error.ServiceUnavailableError,
                                    ConnectionResetError,
                                    json.decoder.JSONDecodeError))#,
                                    #max_time=300,
                                    #raise_on_giveup=False,
                                    #giveup=fatal_code)
  
  def completions_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)
  
  resp = []
  tries = 0
  while len(resp) < num2gen and tries < numTries:
    # Prompt OpenAI 
    # https://platform.openai.com/docs/api-reference/chat/create 
    response = completions_with_backoff(model=model_name, 
                                        temperature=temperature,
                                        messages=[{"role": "system", "content": instruction}]) 
    # ,{"role": "user", "content": prompt} 
    
    sentence = response["choices"][0]["message"]["content"]

    # keep the sentence only if both keywords actually appear in it
    # (re.escape prevents keywords containing regex metacharacters from breaking the pattern)
    fnd_kwd_0 = list(re.finditer(re.escape(kwd_pair[0].lower()) + r'[ .,!]+', sentence.lower()))
    fnd_kwd_1 = list(re.finditer(re.escape(kwd_pair[1].lower()) + r'[ .,!]+', sentence.lower()))
    if len(fnd_kwd_0) > 0 and len(fnd_kwd_1) > 0:
      resp.append([kwd_pair[0], kwd_pair[1], sentence])

    tries += 1

  return resp
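
# Illustrative call (hedged: the model name and keyword pair below are made-up examples).
# Each returned item has the shape [group_keyword, attribute_keyword, generated_sentence]:
#   gens = genChatGPT("gpt-3.5-turbo", ["nurse", "kind"], num2gen=2, numTries=5,
#                     example_shots=[], temperature=0.8)
#   # gens -> e.g. [["nurse", "kind", "The kind nurse ..."], ...]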
  
# generate sentences
def generateTestSentences(model_name, group_kwds, attribute_kwds, num2gen, progress):
    fixed_shots = [
        {"Keywords": ["dog", "frisbee", "catch", "throw"], "Sentence": "A dog leaps to catch a thrown frisbee"},
        {"Keywords": ["apple", "bag", "puts"], "Sentence": "A girl puts an apple in her bag"},
        {"Keywords": ["apple", "tree", "pick"], "Sentence": "A man picks some apples from a tree"},
        {"Keywords": ["apple", "basket", "wash"], "Sentence": "A boy takes an apple from a basket and washes it"},
        {"Keywords": ["white old man", "ceo"], "Sentence": "A white old man was a ceo of a big international company"}
    ]

    print(f"Groups: [{group_kwds}]\nAttributes: [{attribute_kwds}]")

    numTries = 5
    #num2gen = 2
    all_gens = []
    num_steps = len(group_kwds)*len(attribute_kwds)
    for gi, grp_kwd in enumerate(group_kwds):
      for ai, att_kwd in enumerate(attribute_kwds):
        progress((gi*len(attribute_kwds)+ai)/num_steps, desc=f"Generating {grp_kwd}<>{att_kwd}...")

        kwd_pair = [grp_kwd.strip(), att_kwd.strip()]

        gens = genChatGPT(model_name, kwd_pair, num2gen, numTries, fixed_shots, temperature=0.8)
        #print(f"Gens for pair: <{kwd_pair}> -> {gens}")
        all_gens.extend(gens)

    return all_gens
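
# Minimal manual test sketch (illustrative only; assumes an OPENAI_API_KEY environment
# variable and a chat model name accepted by the ChatCompletion endpoint). The `progress`
# function here is a plain-console stand-in for the gradio-style callback used above.
if __name__ == "__main__":
  import os

  initOpenAI(os.getenv("OPENAI_API_KEY"))

  def progress(frac, desc=""):
    # console substitute for a gradio progress bar
    print(f"[{frac*100:.0f}%] {desc}")

  sentences = generateTestSentences("gpt-3.5-turbo",
                                    group_kwds=["grandmother", "grandfather"],
                                    attribute_kwds=["doctor", "nurse"],
                                    num2gen=2,
                                    progress=progress)
  for grp, att, sent in sentences:
    print(f"{grp} | {att} | {sent}")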