File size: 2,430 Bytes
30ffb9e
 
 
 
 
a6d45c6
30ffb9e
 
 
 
 
 
 
 
 
 
 
 
88b4a61
 
 
30ffb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
a6d45c6
30ffb9e
 
30eb437
 
30ffb9e
 
 
 
 
 
 
30eb437
 
 
 
30ffb9e
 
30eb437
 
 
 
 
30ffb9e
 
30eb437
 
 
 
 
 
 
 
 
30ffb9e
30eb437
 
 
30ffb9e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#%%
import os, time, io, zipfile
from preprocessing import FileIO
import shutil
import modal
import streamlit as st
from llama_index.finetuning import EmbeddingQAFinetuneDataset

from dotenv import load_dotenv, find_dotenv
env = load_dotenv(find_dotenv('env'), override=True)

#%%
# Locations of the fine-tuning datasets (JSON files). finetune() below passes
# these same path strings to the remote Modal function.
# NOTE(review): paths are relative, so presumably resolved against the
# current working directory — confirm how the app is launched.
training_path = 'data/training_data_300.json'
valid_path = 'data/validation_data_100.json'

# Both datasets are loaded eagerly when this module is imported / the cell runs.
# They are not referenced again in the visible part of this file.
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)

def finetune(model: str='sentence-transformers/all-mpnet-base-v2', 
             savemodel: bool =False, 
             outpath: str='.'):
    """Finetune a Sentence Transformer model through the deployed Modal function.

    Looks up the ``finetune`` function of the ``vector-search-project`` Modal
    app, calls it remotely with the module-level ``training_path`` and
    ``valid_path``, reports the elapsed time in the Streamlit sidebar and,
    when requested, extracts the returned zip archive into a local folder.

    NOTE(review): per the original author's notes the remote job runs on a
    Modal A100 GPU and stores the model under /root/models on a Modal volume;
    neither is visible from this file.

    Args:
        model (str): the Sentence Transformer model name. If it does not
            contain 'sentence-transformers', every '/' is removed and the
            'sentence-transformers/' prefix is prepended.
        savemodel (bool, optional): whether to extract the finetuned model
            locally. Defaults to False.
        outpath (str, optional): directory in which the model folder
            ``finetuned-<model name>-300`` is created. Defaults to '.'.

    Returns:
        - "Model already exists!" if the target folder is already present
          (checked regardless of ``savemodel``),
        - "Timeout!" if the remote call times out,
        - the path of the extracted model folder when ``savemodel`` is True,
        - None when the finetuning succeeded but ``savemodel`` is False.
    """
    # Imported here so the `except` clause below refers to a defined name;
    # the original referenced FunctionTimeoutError without importing it,
    # turning any remote failure into a NameError.
    from modal.exception import FunctionTimeoutError

    f = modal.Function.lookup("vector-search-project", "finetune")
    
    # Normalise bare model names to the 'sentence-transformers/<name>' form
    if 'sentence-transformers' not in model:
        model = model.replace('/','')
        model = f"sentence-transformers/{model}"

    fullpath = os.path.join(outpath, f"finetuned-{model.split('/')[-1]}-300")
    
    # Skip the (expensive) remote job when a local copy is already there
    if os.path.exists(fullpath):
        msg = "Model already exists!"
        print(msg)
        return msg

    start = time.perf_counter()
    try:
        finetuned_model = f.remote(training_path, valid_path, model_id=model)
    except FunctionTimeoutError:
        return "Timeout!"  # will be displayed by app.py

    elapsed = time.perf_counter() - start
    st.sidebar.write(f"Finetuning with GPU lasted {elapsed:.2f} seconds")

    if savemodel:
        # The remote function's result is handed straight to ZipFile, so it
        # is presumably an in-memory zip (file-like object) — TODO confirm.
        # Extract it as a folder; the context manager closes the archive.
        with zipfile.ZipFile(finetuned_model) as archive:
            archive.extractall(fullpath)
            
        print(f"Model saved in {fullpath}")
        return fullpath