File size: 3,964 Bytes
61047f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from typing import List
import numpy as np
import redis
import google.generativeai as genai
from tqdm import tqdm
import time

from redis.commands.search.field import (
    TagField,
    TextField,
    VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from sourcegraph import Sourcegraph


# Name of the RediSearch index built over the "code:" JSON documents.
INDEX_NAME = "idx:codes_vss"

# Authenticate the Gemini SDK; fails at import time if the variable is unset.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Sampling settings used for every description request made through `model`.
generation_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_mime_type="text/plain",
)

# Shared Gemini model; the system instruction steers it toward describing
# the purpose and functionality of the Python code it is given.
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
    system_instruction=(
        "You are optimized to generate accurate descriptions for given Python codes. "
        "When the user inputs the code, you must return the description according to its goal and functionality.  "
        "You are not allowed to generate additional details. "
        "The user expects at least 5 sentence-long descriptions."
    ),
)

def fetch_data(url, delay: float = 3.0):
    """Analyse a repository with ``Sourcegraph`` and annotate each node with a Gemini description.

    Args:
        url: Repository location, passed straight to ``Sourcegraph``.
        delay: Seconds to sleep after each Gemini request, to stay under the
            API rate limit. Defaults to 3, the value previously hard-coded.

    Returns:
        A dict with the same keys as ``Sourcegraph.node_data``. Each value is the
        node's metadata dict extended with:

        * ``description`` -- Gemini's description of ``value['definition']``;
        * ``uses`` -- the items returned by ``get_dependencies(key)``, joined
          with ``", "``.

        The outer dict is a shallow copy, so the inner metadata dicts held by the
        ``Sourcegraph`` instance are updated in place as well.
    """
    def get_description(code):
        # Send the code itself as the prompt. The previous version placed the
        # code in the chat history and then sent the literal AI-Studio template
        # placeholder "INSERT_INPUT_HERE" as the actual user message.
        response = model.generate_content(f"Code: {code}")
        return response.text

    repo = Sourcegraph(url)
    repo.run()
    data = dict(repo.node_data)
    for key, node in tqdm(data.items()):
        # `node` is data[key], so these assignments update the returned dict.
        node['description'] = get_description(node['definition'])
        node['uses'] = ", ".join(repo.get_dependencies(key))
        time.sleep(delay)  # throttle Gemini calls to avoid rate-limit errors
    return data

def get_embeddings(content: List, batch_size: int = 100):
    """Embed texts with Gemini's ``models/text-embedding-004`` model.

    The request is split into chunks of at most ``batch_size`` texts, because
    the batch-embedding endpoint caps how many inputs one call may carry.
    NOTE(review): 100 is the documented per-request limit at the time of
    writing; confirm against the current Gemini API docs.

    Args:
        content: Sequence of texts to embed. A single ``str`` is also accepted
            and is forwarded unchanged, so the SDK's single-input result is
            returned as before.
        batch_size: Maximum number of texts sent per API call (must be >= 1).

    Returns:
        For sequence input, a list with one embedding per text, in input
        order (an empty list for empty input).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if isinstance(content, str):
        return genai.embed_content(model='models/text-embedding-004', content=content)['embedding']
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(content)
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        result = genai.embed_content(model='models/text-embedding-004', content=batch)
        embeddings.extend(result['embedding'])
    return embeddings

def ingest_data(client: redis.Redis, data):
    """Store code metadata in Redis, attach embeddings, and (re)build the vector index.

    Steps:

    1. Every existing ``code:*`` key is deleted.
    2. Each value of ``data`` is written as a RedisJSON document under
       ``code:NNN`` (1-based, zero-padded to three digits, in ``data`` order).
    3. ``"<definition>\\n\\n<description>"`` is embedded for each record and
       stored at ``$.embeddings``.
    4. The RediSearch index ``INDEX_NAME`` is created over the ``code:`` prefix,
       replacing an existing index of that name when creation fails.

    Args:
        client: Redis connection; the ``json()`` and ``ft()`` command groups must
            be available on the server.
        data: Mapping whose values are JSON-serialisable dicts with at least
            ``definition`` and ``description`` (as produced by ``fetch_data``).
            The indexed fields also read ``name``, ``type``, ``file_name`` and
            ``uses`` when present.

    Returns:
        ``"<num_docs> documents indexed with <failures> failures"`` from FT.INFO.

    Raises:
        ValueError: If ``data`` is empty, since no vector dimension can be derived.
    """
    records = list(data.values())
    if not records:
        raise ValueError("ingest_data() requires at least one record to build the index")

    # Remove documents from earlier runs. The previous code passed the key list
    # as a single argument to DEL, which redis-py cannot encode, and a bare
    # `except` hid the error, so stale code:* documents were re-indexed.
    stale_keys = list(client.scan_iter(match="code:*"))
    if stale_keys:
        client.delete(*stale_keys)

    keys = [f"code:{i:03}" for i in range(1, len(records) + 1)]
    pipeline = client.pipeline()
    for key, record in zip(keys, records):
        pipeline.json().set(key, "$", record)
    pipeline.execute()

    # Build embedding inputs from the records just written; keys[i] <-> records[i].
    embed_inputs = [f"{record['definition']}\n\n{record['description']}" for record in records]
    embeddings = get_embeddings(embed_inputs)
    vector_dimension = len(embeddings[0])

    pipeline = client.pipeline()
    for key, embedding in zip(keys, embeddings):
        pipeline.json().set(key, "$.embeddings", embedding)
    pipeline.execute()

    schema = (
        TextField("$.name", no_stem=True, as_name="name"),
        TagField("$.type", as_name="type"),
        TextField("$.definition", no_stem=True, as_name="definition"),
        TextField("$.file_name", no_stem=True, as_name="file_name"),
        TextField("$.description", no_stem=True, as_name="description"),
        TextField("$.uses", no_stem=True, as_name="uses"),
        VectorField(
            "$.embeddings",
            "HNSW",
            {
                "TYPE": "FLOAT32",
                "DIM": vector_dimension,
                "DISTANCE_METRIC": "COSINE",
            },
            as_name="vector",
        ),
    )
    definition = IndexDefinition(prefix=["code:"], index_type=IndexType.JSON)
    index = client.ft(INDEX_NAME)
    try:
        index.create_index(fields=schema, definition=definition)
    except redis.exceptions.ResponseError:
        # Most likely the index already exists: drop it (documents are kept)
        # and recreate it with the current schema/dimension.
        index.dropindex()
        index.create_index(fields=schema, definition=definition)

    # NOTE(review): FT.CREATE indexes pre-existing keys in the background, so
    # num_docs read immediately afterwards may undercount on large datasets;
    # consider polling FT.INFO's `indexing` / `percent_indexed` fields.
    info = index.info()
    num_docs = info["num_docs"]
    indexing_failures = info["hash_indexing_failures"]
    return f"{num_docs} documents indexed with {indexing_failures} failures"