# Ultralytics YOLO 🚀, AGPL-3.0 license

import getpass
from typing import List

import cv2
import numpy as np
import pandas as pd

from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images


def get_table_schema(vector_size):
    """
    Return a LanceModel schema for the dataset table, including an embedding column of the given size.

    Args:
        vector_size (int): Dimension passed to lancedb's `Vector` for the `vector` column.

    Returns:
        (type): A `LanceModel` subclass with the columns `im_file`, `labels`, `cls`, `bboxes`, `masks`,
            `keypoints` and `vector`.
    """
    from lancedb.pydantic import LanceModel, Vector  # imported lazily, only when a schema is requested

    class Schema(LanceModel):
        im_file: str
        labels: List[str]
        cls: List[int]
        bboxes: List[List[float]]
        masks: List[List[List[int]]]
        keypoints: List[List[List[float]]]
        vector: Vector(vector_size)

    return Schema


def get_sim_index_schema():
    """
    Return a LanceModel schema for the similarity-index table.

    Returns:
        (type): A `LanceModel` subclass with the columns `idx`, `im_file`, `count` and `sim_im_files`.
    """
    from lancedb.pydantic import LanceModel  # imported lazily, only when a schema is requested

    class Schema(LanceModel):
        idx: int
        im_file: str
        count: int
        sim_im_files: List[str]

    return Schema


def sanitize_batch(batch, dataset_info):
    """
    Convert the label fields of a batch to plain Python lists, with boxes and classes ordered by class id.

    The batch dict is modified in place and also returned.

    Args:
        batch (dict): Must contain `cls` and `bboxes` (presumably torch tensors, since `.flatten().int()` and
            `.tolist()` are called on them); may contain `masks` and `keypoints`.
        dataset_info (dict): Must contain `names`, indexable by integer class id.

    Returns:
        (dict): The same `batch` object with `cls`, `bboxes`, `labels`, `masks` and `keypoints` as lists.
    """
    batch["cls"] = batch["cls"].flatten().int().tolist()
    # Stable sort on the class id keeps boxes of the same class in their original relative order.
    box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
    batch["bboxes"] = [box for box, _ in box_cls_pair]
    batch["cls"] = [cls for _, cls in box_cls_pair]
    batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
    # NOTE(review): masks and keypoints are NOT reordered together with the boxes; confirm that nothing
    # downstream relies on them sharing the boxes' index order.
    batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
    batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
    return batch


def plot_query_result(similar_set, plot_labels=True):
    """
    Plot images from the similar set.

    Args:
        similar_set (pandas.DataFrame | pyarrow.Table): Rows to plot. A DataFrame is converted with
            `to_dict(orient="list")`; anything else must provide `to_pydict()`.
        plot_labels (bool): Whether to draw boxes, masks and keypoints on the images.

    Returns:
        The value returned by `plot_images` when called with `save=False`.

    Raises:
        ValueError: If `similar_set` contains no images.
        FileNotFoundError: If an image listed in `im_file` cannot be read.
    """
    similar_set = (
        similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
    )
    empty_masks = [[[]]]  # placeholder stored by `sanitize_batch` when a sample has no masks/keypoints
    images = similar_set.get("im_file", [])
    if not images:
        raise ValueError("similar_set contains no images to plot.")

    # Per-image emptiness of the boxes is checked inside the loop, so no sentinel comparison is needed here.
    bboxes = similar_set.get("bboxes", [])
    # Masks/keypoints are treated as absent for the whole set when the first row holds the placeholder.
    masks = similar_set.get("masks") or []
    if masks and masks[0] == empty_masks:
        masks = []
    kpts = similar_set.get("keypoints") or []
    if kpts and kpts[0] == empty_masks:
        kpts = []
    cls = similar_set.get("cls", [])

    plot_size = 640
    letterbox = LetterBox(plot_size, center=False)  # loop-invariant, build once
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        if im is None:  # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Image not found or unreadable: {imf}")
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # same scale factor is applied to the label coordinates below
        imgs.append(letterbox(image=im).transpose(2, 0, 1))  # HWC -> CHW

        img_boxes = bboxes[i] if i < len(bboxes) else []
        if plot_labels:
            if len(img_boxes) > 0:
                box = np.array(img_boxes, dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(letterbox(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)
        # One entry per box, holding the index of the image the box belongs to.
        batch_idx.append(np.full(len(img_boxes), i, dtype=np.float64))

    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
    # NOTE(review): 51 is presumably 17 keypoints x 3 values; confirm against `plot_images`.
    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)

    return plot_images(
        imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
    )


def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to turn a natural-language request into a SQL query over the dataset table.

    If no OpenAI API key is stored in `SETTINGS`, the user is prompted for one on the terminal and the key is
    written back to `SETTINGS`.

    Args:
        query (str): Natural-language description of the rows to select.

    Returns:
        (str): The content of the first choice returned by the chat completion.
    """
    check_requirements("openai>=1.6.1")
    from openai import OpenAI

    # `.get` so that a settings file without the key prompts the user instead of raising KeyError.
    if not SETTINGS.get("openai_api_key"):
        logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
        openai_api_key = getpass.getpass("OpenAI API key: ")
        SETTINGS.update({"openai_api_key": openai_api_key})
    openai = OpenAI(api_key=SETTINGS["openai_api_key"])

    messages = [
        {
            "role": "system",
            "content": """
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`

                Schema:
                im_file: string not null
                labels: list<item: string> not null
                child 0, item: string
                cls: list<item: int64> not null
                child 0, item: int64
                bboxes: list<item: list<item: double>> not null
                child 0, item: list<item: double>
                    child 0, item: double
                masks: list<item: list<item: list<item: int64>>> not null
                child 0, item: list<item: list<item: int64>>
                    child 0, item: list<item: int64>
                        child 0, item: int64
                keypoints: list<item: list<item: list<item: double>>> not null
                child 0, item: list<item: list<item: double>>
                    child 0, item: list<item: double>
                        child 0, item: double
                vector: fixed_size_list<item: float>[256] not null
                child 0, item: float

                Some details about the schema:
                - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
                    in each image
                - the "cls" column contains the integer values on these classes that map them the labels

                Example of a correct query:
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE  ARRAY_LENGTH(cls) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
             """,
        },
        {"role": "user", "content": f"{query}"},
    ]

    response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
    return response.choices[0].message.content