{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d024645c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make `sqlite3` resolve to pysqlite3 before chromadb is imported below\n",
    "# (presumably chromadb needs a newer SQLite than the system one - confirm).\n",
    "__import__('pysqlite3')\n",
    "import sys\n",
    "import os\n",
    "sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')\n",
    "os.environ['ALLOW_RESET'] = 'True'\n",
    "\n",
    "import time\n",
    "from getpass import getpass\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "import chromadb\n",
    "from gigachat import GigaChat\n",
    "\n",
    "client = chromadb.PersistentClient(path='db')\n",
    "collection = client.get_collection(name=\"administrative_codex\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17dae6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives in this cell.\n",
    "\n",
    "# Output files. Both loops append to them and skip documents that are already\n",
    "# saved, so an interrupted run can simply be restarted.\n",
    "QUESTIONS_CSV = 'generated_questions.csv'\n",
    "ANSWERS_CSV = 'generated_additional_llm_answer.csv'\n",
    "\n",
    "# The secret is taken from the environment (or typed in interactively) so it is\n",
    "# never stored in the notebook. Keys that were previously hardcoded in this\n",
    "# notebook are exposed in its history and should be revoked.\n",
    "GIGACHAT_CREDENTIALS = os.environ.get('GIGACHAT_CREDENTIALS') or getpass('GigaChat credentials: ')\n",
    "GIGACHAT_SCOPE = 'GIGACHAT_API_PERS'\n",
    "# NOTE(review): TLS certificate verification is switched off, exactly as in the\n",
    "# earlier runs. Set to True once the required CA certificates are installed.\n",
    "GIGACHAT_VERIFY_SSL = False\n",
    "\n",
    "RETRY_DELAY_SECONDS = 5   # pause between failed GigaChat requests\n",
    "MIN_ANSWER_LENGTH = 100   # answers must be longer than this (in characters) to be saved\n",
    "FINE_KEYWORD = 'штраф'    # substring that marks a question about a fine\n",
    "\n",
    "# Appended after a document to ask for three questions about it.\n",
    "QUESTIONS_PROMPT = 'Задание: напиши в виде нумерованного списка 3 конкретных независимых друг от друга вопроса, ответ на которые можно найти в приведенном тексте. Не упоминай федеральные законы. Не упоминай КоАП.'\n",
    "# Appended after a question to ask for a law-like answer.\n",
    "ANSWER_PROMPT = 'В России. Дай подробный ответ текстом, похожим на закон, не пиши ничего лишнего.'\n",
    "\n",
    "# All document texts stored in the Chroma collection.\n",
    "codex_docs = collection.get()['documents']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e1c2a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask_gigachat(text, max_attempts=None):\n",
    "    \"\"\"Send `text` to GigaChat and return the content of the first reply choice.\n",
    "\n",
    "    A new client is opened for every attempt. A failed attempt is reported via\n",
    "    `tqdm.write` and repeated after RETRY_DELAY_SECONDS.\n",
    "\n",
    "    Args:\n",
    "        text: the full prompt to send.\n",
    "        max_attempts: give up after this many failed attempts; None (default)\n",
    "            keeps retrying until a request succeeds.\n",
    "\n",
    "    Raises:\n",
    "        Exception: the last request error, once `max_attempts` is exhausted.\n",
    "    \"\"\"\n",
    "    attempt = 0\n",
    "    while True:\n",
    "        attempt += 1\n",
    "        try:\n",
    "            with GigaChat(credentials=GIGACHAT_CREDENTIALS,\n",
    "                          scope=GIGACHAT_SCOPE,\n",
    "                          verify_ssl_certs=GIGACHAT_VERIFY_SSL) as giga:\n",
    "                return giga.chat(text).choices[0].message.content\n",
    "        except Exception as error:\n",
    "            # `except Exception` rather than a bare `except`, so that interrupting\n",
    "            # the kernel (KeyboardInterrupt) still stops the loop.\n",
    "            if max_attempts is not None and attempt >= max_attempts:\n",
    "                raise\n",
    "            tqdm.write(f'GigaChat request failed (attempt {attempt}): {error!r}; retrying in {RETRY_DELAY_SECONDS}s')\n",
    "            time.sleep(RETRY_DELAY_SECONDS)\n",
    "\n",
    "\n",
    "def load_results(csv_path, columns):\n",
    "    \"\"\"Return the rows already saved in `csv_path`.\n",
    "\n",
    "    If the file is missing or empty it is created with only a header made of\n",
    "    `columns`, and an empty DataFrame with those columns is returned.\n",
    "    \"\"\"\n",
    "    if os.path.exists(csv_path) and os.path.getsize(csv_path) > 0:\n",
    "        return pd.read_csv(csv_path)\n",
    "    empty = pd.DataFrame(columns=columns)\n",
    "    empty.to_csv(csv_path, index=False)\n",
    "    return empty\n",
    "\n",
    "\n",
    "def append_result(csv_path, row, columns):\n",
    "    \"\"\"Append `row` (a dict) to `csv_path` as one CSV line, ordered by `columns`.\n",
    "\n",
    "    The file is opened in append mode and no header is written, so it must\n",
    "    already contain its header (see `load_results`).\n",
    "    \"\"\"\n",
    "    pd.DataFrame([row], columns=columns).to_csv(csv_path, mode='a', header=False, index=False)\n",
    "\n",
    "\n",
    "def pick_fine_question(raw_questions):\n",
    "    \"\"\"Return the first question line containing FINE_KEYWORD, or '' if there is none.\n",
    "\n",
    "    `raw_questions` is the newline-separated text produced for one document.\n",
    "    Non-string values (pandas reads an empty CSV cell back as NaN) yield ''.\n",
    "    \"\"\"\n",
    "    if not isinstance(raw_questions, str):\n",
    "        return ''\n",
    "    for line in raw_questions.split('\\n'):\n",
    "        # Drop the first whitespace-separated token (presumably the list number, e.g. '1.').\n",
    "        question = ' '.join(line.split()[1:])\n",
    "        if FINE_KEYWORD in question:\n",
    "            return question\n",
    "    return ''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91549726-3c7a-44ef-8519-c1afc3adde0f",
   "metadata": {},
   "source": [
    "### Генерируем вопросы к каждому фрагменту текста"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f82948",
   "metadata": {},
   "outputs": [],
   "source": [
    "saved_questions = load_results(QUESTIONS_CSV, columns=['text', 'questions'])\n",
    "question_columns = list(saved_questions.columns)  # follow the column order already on disk\n",
    "docs_with_questions = set(saved_questions['text'])\n",
    "\n",
    "for doc in tqdm(codex_docs):\n",
    "    # Resume support: documents saved by an earlier run are not sent again.\n",
    "    if doc in docs_with_questions:\n",
    "        continue\n",
    "\n",
    "    questions = ask_gigachat(f'{doc}\\n\\n{QUESTIONS_PROMPT}')\n",
    "\n",
    "    # Saved immediately, so an interruption loses at most the request in flight.\n",
    "    append_result(QUESTIONS_CSV, {'text': doc, 'questions': questions}, question_columns)\n",
    "    docs_with_questions.add(doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f44eac2-7ce0-4d26-9f4a-41f5bfe0fa44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload from disk so this stage can be run without repeating the previous one.\n",
    "generated_questions_df = pd.read_csv(QUESTIONS_CSV)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b543f8-0b94-4c0c-8a69-9574b7c54db9",
   "metadata": {},
   "source": [
    "### Генерируем ответы к вопросам, в которых есть слово штраф"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7937078-6200-44ba-b43f-4867e947b750",
   "metadata": {},
   "outputs": [],
   "source": [
    "saved_answers = load_results(ANSWERS_CSV, columns=['text', 'question', 'llm_answer'])\n",
    "answer_columns = list(saved_answers.columns)  # follow the column order already on disk\n",
    "answered_docs = set(saved_answers['text'])\n",
    "\n",
    "doc_question_pairs = zip(generated_questions_df['text'], generated_questions_df['questions'])\n",
    "\n",
    "for doc, g_questions in tqdm(doc_question_pairs, total=len(generated_questions_df)):\n",
    "    # Resume support: documents that already have a saved answer are skipped.\n",
    "    if doc in answered_docs:\n",
    "        continue\n",
    "\n",
    "    fine_question = pick_fine_question(g_questions)\n",
    "    if not fine_question:\n",
    "        continue\n",
    "\n",
    "    text = f'Помоги, пожалуйста. {fine_question} {ANSWER_PROMPT}'\n",
    "    # Only the first line of the reply is kept.\n",
    "    llm_answer = ask_gigachat(text).split('\\n')[0]\n",
    "\n",
    "    if len(llm_answer) > MIN_ANSWER_LENGTH:\n",
    "        append_result(\n",
    "            ANSWERS_CSV,\n",
    "            {'text': doc, 'question': fine_question, 'llm_answer': llm_answer},\n",
    "            answer_columns,\n",
    "        )\n",
    "        answered_docs.add(doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fd64855-01b5-4c66-a425-b6d91b355a22",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}