File size: 7,041 Bytes
53ad959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Ultralytics YOLO 🚀, AGPL-3.0 license

import getpass
from typing import List

import cv2
import numpy as np
import pandas as pd

from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images


def get_table_schema(vector_size: int):
    """
    Build and return the LanceModel schema class used for the main dataset table.

    Args:
        vector_size (int): Size passed to `Vector(...)` for the `vector` field.

    Returns:
        The locally defined `Schema` class (a `LanceModel` subclass), not an instance of it.
    """
    # Imported lazily so that lancedb is only needed when this function is actually called.
    from lancedb.pydantic import LanceModel, Vector

    class Schema(LanceModel):
        im_file: str
        labels: List[str]
        cls: List[int]
        bboxes: List[List[float]]
        masks: List[List[List[int]]]
        keypoints: List[List[List[float]]]
        vector: Vector(vector_size)

    return Schema


def get_sim_index_schema():
    """
    Build and return the LanceModel schema class with the fields `idx`, `im_file`, `count` and `sim_im_files`.

    Returns:
        The locally defined `Schema` class (a `LanceModel` subclass), not an instance of it.
    """
    # Imported lazily so that lancedb is only needed when this function is actually called.
    from lancedb.pydantic import LanceModel

    class Schema(LanceModel):
        idx: int
        im_file: str
        count: int
        sim_im_files: List[str]

    return Schema


def sanitize_batch(batch, dataset_info):
    """
    Convert a batch's annotations to plain Python lists, with boxes and classes ordered by class index.

    Args:
        batch (dict): Batch holding "cls" and "bboxes" (objects exposing `.tolist()`; "cls" must also support
            `.flatten().int()`), and optionally "masks" and "keypoints". It is modified in place.
        dataset_info (dict): Must provide a "names" lookup from class index to class name.

    Returns:
        (dict): The same `batch` object, with "cls", "bboxes", "labels", "masks" and "keypoints" set to lists.
    """
    batch["cls"] = batch["cls"].flatten().int().tolist()

    # Pair every box with its class index, then order the pairs by class index.
    # list.sort is stable, so boxes sharing a class keep their original relative order.
    pairs = list(zip(batch["bboxes"].tolist(), batch["cls"]))
    pairs.sort(key=lambda pair: pair[1])

    ordered_boxes, ordered_cls = [], []
    for box, class_id in pairs:
        ordered_boxes.append(box)
        ordered_cls.append(class_id)

    batch["bboxes"] = ordered_boxes
    batch["cls"] = ordered_cls
    batch["labels"] = [dataset_info["names"][class_id] for class_id in ordered_cls]

    # Batches without masks / keypoints get an empty nested-list placeholder instead.
    for key in ("masks", "keypoints"):
        if key in batch:
            batch[key] = batch[key].tolist()
        else:
            batch[key] = [[[]]]
    return batch


def plot_query_result(similar_set, plot_labels=True):
    """
    Plot the images of a query result, optionally overlaying their boxes, masks and keypoints.

    Args:
        similar_set (pd.DataFrame | Any): Query result. A pandas DataFrame is converted with
            `to_dict(orient="list")`; any other object must provide `to_pydict()` (e.g. a pyarrow table).
            Columns are looked up by name: "im_file", "bboxes", "masks", "keypoints" and "cls". Columns other
            than "im_file" are treated as empty when they are missing.
        plot_labels (bool): Whether to overlay boxes, masks and keypoints on the images.

    Returns:
        The value returned by `plot_images` when called with `save=False, threaded=False`.

    Raises:
        ValueError: If the result set contains no image paths.
    """
    similar_set = (
        similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
    )
    # Placeholder written by sanitize_batch for batches that have no masks / keypoints.
    empty_placeholder = [[[]]]

    images = similar_set.get("im_file", [])
    if not images:
        # np.stack / np.concatenate below would raise a ValueError with a far less helpful message.
        raise ValueError("similar_set contains no images to plot")

    # The original compared this column with `is not` against a fresh list literal, which is always True.
    # Empty per-image box lists are handled inside the loop instead.
    bboxes = similar_set.get("bboxes", [])

    # Tolerate a missing or empty column instead of indexing into None / [].
    masks = similar_set.get("masks") or []
    if len(masks) > 0 and masks[0] == empty_placeholder:
        masks = []
    kpts = similar_set.get("keypoints") or []
    if len(kpts) > 0 and kpts[0] == empty_placeholder:
        kpts = []
    cls = similar_set.get("cls", [])

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # scale factor applied to box / keypoint coordinates
        imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))

        # Boxes of this image; empty when the column is missing or shorter than the image list.
        image_boxes = bboxes[i] if i < len(bboxes) else []
        if plot_labels:
            if len(image_boxes) > 0:
                box = np.array(image_boxes, dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                # Skip empty entries: they are not 3-D (or have no elements), so the slicing below and the
                # final concatenation with real keypoints would fail on them.
                if kpt.ndim == 3 and kpt.size:
                    kpt[:, :, :2] *= r
                    plot_kpts.append(kpt)
        # One batch index per box of this image (float64, same as the former `np.ones(n) * i`).
        batch_idx.append(np.full(len(image_boxes), i, dtype=np.float64))

    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    if len(cls) > 0:
        cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
    else:
        # np.concatenate rejects an empty sequence; use an empty class array instead.
        cls = np.zeros(0, dtype=np.int32)

    return plot_images(
        imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
    )


def prompt_sql_query(query, model="gpt-3.5-turbo"):
    """
    Ask an OpenAI chat model to turn a natural-language request into a SQL query over the dataset table.

    If `SETTINGS["openai_api_key"]` is empty, the key is requested interactively via `getpass` and stored in
    `SETTINGS` before the request is made.

    Args:
        query (str): Natural-language request; it is sent as the user message.
        model (str): Name of the OpenAI chat model to query. Defaults to "gpt-3.5-turbo".

    Returns:
        The message content of the first choice in the chat completion response.
    """
    check_requirements("openai>=1.6.1")
    from openai import OpenAI

    if not SETTINGS["openai_api_key"]:
        logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
        openai_api_key = getpass.getpass("OpenAI API key: ")
        SETTINGS.update({"openai_api_key": openai_api_key})
    openai = OpenAI(api_key=SETTINGS["openai_api_key"])

    # NOTE: the schema description below is static text (including the vector size of 256); it is not derived
    # from the schema of the table that is actually queried.
    messages = [
        {
            "role": "system",
            "content": """
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`

                Schema:
                im_file: string not null
                labels: list<item: string> not null
                child 0, item: string
                cls: list<item: int64> not null
                child 0, item: int64
                bboxes: list<item: list<item: double>> not null
                child 0, item: list<item: double>
                    child 0, item: double
                masks: list<item: list<item: list<item: int64>>> not null
                child 0, item: list<item: list<item: int64>>
                    child 0, item: list<item: int64>
                        child 0, item: int64
                keypoints: list<item: list<item: list<item: double>>> not null
                child 0, item: list<item: list<item: double>>
                    child 0, item: list<item: double>
                        child 0, item: double
                vector: fixed_size_list<item: float>[256] not null
                child 0, item: float

                Some details about the schema:
                - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
                    in each image
                - the "cls" column contains the integer values on these classes that map them the labels

                Example of a correct query:
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE  ARRAY_LENGTH(cls) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
             """,
        },
        {"role": "user", "content": f"{query}"},
    ]

    response = openai.chat.completions.create(model=model, messages=messages)
    return response.choices[0].message.content