import os
import requests
from transformers import AutoModel

def model_file_exists_and_valid(model_file_path):
    # Check if the model file exists and has a size greater than 0
    return os.path.exists(model_file_path) and os.path.getsize(model_file_path) > 0

def write_model_path_to_txt_file(model_file_path):
    # Write the model path to model_path.txt
    with open('/home/user/data/models/model_path.txt', 'w') as f:
        f.write(model_file_path)

def download_hf_model():
    '''
    Download a Hugging Face model and record its file path.

    Before downloading, the function checks for a default model file
    (pytorch_model.bin). Adjust this path based on the expected model file
    type (e.g., TensorFlow or Flax models use different file names).
    '''
    model_name = os.getenv("HF_MODEL_NAME")
    if not model_name:
        print("HF_MODEL_NAME is not set; cannot download a Hugging Face model.")
        return
    model_dir = f"/home/user/data/models/{model_name}"
    model_file_path = os.path.join(model_dir, "pytorch_model.bin")  # Assuming a PyTorch model for simplicity
    
    if model_file_exists_and_valid(model_file_path):
        print(f"Model {model_name} already downloaded.")
        write_model_path_to_txt_file(model_file_path)
        return
    
    # Authenticate with Hugging Face using the token, if available
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        from huggingface_hub import HfFolder
        HfFolder.save_token(hf_token)
    
    print(f"Downloading model: {model_name}...")
    model = AutoModel.from_pretrained(model_name)
    model.save_pretrained(model_dir)
    print(f"Model {model_name} downloaded and saved to {model_dir}")
    write_model_path_to_txt_file(model_file_path)
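
# Note: recent transformers releases save weights as model.safetensors by default,
# so the pytorch_model.bin check above may miss a model that was already saved and
# trigger a re-download. A minimal sketch of a broader check (assuming these two
# file names cover the formats in use; not wired into the functions above):
def find_existing_model_file(model_dir):
    # Return the first known weight file that exists and is non-empty, else None
    for name in ("pytorch_model.bin", "model.safetensors"):
        path = os.path.join(model_dir, name)
        if model_file_exists_and_valid(path):
            return path
    return None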

def download_gguf_model():
    model_name = os.getenv("HF_MODEL_NAME")
    model_dir = f"/home/user/data/models/{model_name}"
    os.makedirs(model_dir, exist_ok=True)
    
    model_url = os.getenv("GGUF_MODEL_URL")  # The download URL is expected as an env variable
    if not model_url:
        print("GGUF_MODEL_URL is not set; cannot download a GGUF model.")
        return
    model_file_path = os.path.join(model_dir, os.path.basename(model_url))
    
    if model_file_exists_and_valid(model_file_path):
        print(f"Model {model_name} already downloaded.")
        write_model_path_to_txt_file(model_file_path)
        return
    
    print(f"Downloading model from {model_url}...")
    response = requests.get(model_url, stream=True)
    if response.status_code == 200:
        with open(model_file_path, 'wb') as f:
            f.write(response.content)
        print(f"Model downloaded and saved to {model_file_path}")
    else:
        print(f"Failed to download the model. Status code: {response.status_code}")
    write_model_path_to_txt_file(model_file_path)

def download_model():
    model_class = os.getenv("MODEL_CLASS")
    if model_class == 'gguf':
        download_gguf_model()
    elif model_class == 'hf':
        download_hf_model()
    else:
        print(f"Unsupported model class: {model_class}")

if __name__ == "__main__":
    download_model()
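
# Usage sketch (model names, URL, and the script filename are illustrative assumptions):
#   MODEL_CLASS=hf HF_MODEL_NAME=bert-base-uncased python download_model.py
#   MODEL_CLASS=gguf HF_MODEL_NAME=tinyllama \
#       GGUF_MODEL_URL=https://example.com/tinyllama.Q4_0.gguf python download_model.py
# On success, the resolved model file path is written to
# /home/user/data/models/model_path.txt for downstream consumers.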