# %%
import io
import os
import shutil
import time
import zipfile

import modal
import streamlit as st
from dotenv import find_dotenv, load_dotenv
from llama_index.finetuning import EmbeddingQAFinetuneDataset
# FunctionTimeoutError must be imported explicitly; referencing it unqualified
# without this import raises NameError inside the except clause below.
from modal.exception import FunctionTimeoutError

from preprocessing import FileIO

env = load_dotenv(find_dotenv('env'), override=True)

# %%
training_path = 'data/training_data_300.json'
valid_path = 'data/validation_data_100.json'
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)


def finetune(model: str = 'sentence-transformers/all-mpnet-base-v2',
             savemodel: bool = False,
             outpath: str = '.'):
    """Finetune a Sentence Transformer model remotely via Modal.

    Calls the deployed Modal function ``finetune`` of the
    ``vector-search-project`` app with ``training_path`` and ``valid_path``.
    The original author notes that it runs on an A100 GPU and stores the
    model in /root/models on a Modal volume; that behavior lives in the remote
    function and is not visible here. The remote result is treated as a zip
    archive (file-like object) and, when ``savemodel`` is True, is extracted
    locally.

    Args:
        model (str): Sentence Transformer model name. If it does not contain
            'sentence-transformers', every '/' is removed and the name is
            prefixed with 'sentence-transformers/'.
        savemodel (bool, optional): Whether to extract the finetuned model
            locally.
        outpath (str, optional): Directory in which the
            ``finetuned-<model>-300`` folder is created.

    Returns:
        - "Model already exists!" if the target folder already exists (no
          finetuning is run in that case);
        - "Timeout!" if the remote Modal call times out (displayed by app.py);
        - the path of the saved model folder when ``savemodel`` is True;
        - None otherwise.
    """
    remote_finetune = modal.Function.lookup("vector-search-project", "finetune")

    if 'sentence-transformers' not in model:
        model = model.replace('/', '')
        model = f"sentence-transformers/{model}"

    fullpath = os.path.join(outpath, f"finetuned-{model.split('/')[-1]}-300")

    # Skip the (expensive) remote run entirely if a local copy already exists.
    if os.path.exists(fullpath):
        msg = "Model already exists!"
        print(msg)
        return msg

    start = time.perf_counter()
    try:
        finetuned_model = remote_finetune.remote(training_path, valid_path,
                                                 model_id=model)
    except FunctionTimeoutError:
        return "Timeout!"  # will be displayed by app.py
    elapsed = time.perf_counter() - start

    st.sidebar.write(f"Finetuning with GPU lasted {elapsed:.2f} seconds")

    if savemodel:
        # Extract the returned zip archive directly into a folder; the context
        # manager guarantees the archive handle is closed afterwards.
        with zipfile.ZipFile(finetuned_model) as archive:
            archive.extractall(fullpath)
        print(f"Model saved in {fullpath}")
        return fullpath

    return None