dtyago committed on
Commit
16b83b8
1 Parent(s): 89100e4

Set the environment variable MODEL_PATH at entrypoint

Browse files
Files changed (1) hide show
  1. app/utils/download_model.py +21 -1
app/utils/download_model.py CHANGED
@@ -2,9 +2,23 @@ import os
2
  import requests
3
  from transformers import AutoModel
4
 
 
 
 
 
5
  def download_hf_model():
 
 
 
 
6
  model_name = os.getenv("HF_MODEL_NAME")
7
  model_dir = f"/home/user/data/models/{model_name}"
 
 
 
 
 
 
8
 
9
  # Authenticate with Hugging Face using the token, if available
10
  hf_token = os.getenv("HF_TOKEN")
@@ -16,6 +30,7 @@ def download_hf_model():
16
  model = AutoModel.from_pretrained(model_name)
17
  model.save_pretrained(model_dir)
18
  print(f"Model {model_name} downloaded and saved to {model_dir}")
 
19
 
20
  def download_gguf_model():
21
  model_name = os.getenv("HF_MODEL_NAME")
@@ -23,9 +38,13 @@ def download_gguf_model():
23
  os.makedirs(model_dir, exist_ok=True)
24
 
25
  model_url = os.getenv("GGUF_MODEL_URL") # Assuming URL is provided as an env variable
26
-
27
  model_file_path = os.path.join(model_dir, os.path.basename(model_url))
28
 
 
 
 
 
 
29
  print(f"Downloading model from {model_url}...")
30
  response = requests.get(model_url, stream=True)
31
  if response.status_code == 200:
@@ -34,6 +53,7 @@ def download_gguf_model():
34
  print(f"Model downloaded and saved to {model_file_path}")
35
  else:
36
  print(f"Failed to download the model. Status code: {response.status_code}")
 
37
 
38
  def download_model():
39
  model_class = os.getenv("MODEL_CLASS")
 
2
  import requests
3
  from transformers import AutoModel
4
 
5
def model_file_exists_and_valid(model_file_path):
    """Return True if *model_file_path* is a regular, non-empty file.

    Used to decide whether a model download can be skipped.

    Args:
        model_file_path: Filesystem path of the expected model file.

    Returns:
        True when the path points at a regular file whose size is greater
        than zero bytes; False when it is missing, empty, a directory, or
        cannot be stat'ed.
    """
    try:
        # isfile() rather than exists(): a directory at this path reports a
        # non-zero size and must not be mistaken for a downloaded model.
        return os.path.isfile(model_file_path) and os.path.getsize(model_file_path) > 0
    except OSError:
        # The file vanished or became unreadable between the two checks.
        return False
8
+
9
  def download_hf_model():
10
+ '''
11
+ Model File Path for HF Models: The download_hf_model function now includes a default model file path (pytorch_model.bin) check.
12
+ Adjust this path based on the expected model file type (e.g., TensorFlow or Flax models might have different names).
13
+ '''
14
  model_name = os.getenv("HF_MODEL_NAME")
15
  model_dir = f"/home/user/data/models/{model_name}"
16
+ model_file_path = os.path.join(model_dir, "pytorch_model.bin") # Assuming PyTorch model for simplicity
17
+
18
+ if model_file_exists_and_valid(model_file_path):
19
+ print(f"Model {model_name} already downloaded.")
20
+ os.environ['MODEL_PATH'] = model_file_path
21
+ return
22
 
23
  # Authenticate with Hugging Face using the token, if available
24
  hf_token = os.getenv("HF_TOKEN")
 
30
  model = AutoModel.from_pretrained(model_name)
31
  model.save_pretrained(model_dir)
32
  print(f"Model {model_name} downloaded and saved to {model_dir}")
33
+ os.environ['MODEL_PATH'] = model_file_path
34
 
35
  def download_gguf_model():
36
  model_name = os.getenv("HF_MODEL_NAME")
 
38
  os.makedirs(model_dir, exist_ok=True)
39
 
40
  model_url = os.getenv("GGUF_MODEL_URL") # Assuming URL is provided as an env variable
 
41
  model_file_path = os.path.join(model_dir, os.path.basename(model_url))
42
 
43
+ if model_file_exists_and_valid(model_file_path):
44
+ print(f"Model {model_name} already downloaded.")
45
+ os.environ['MODEL_PATH'] = model_file_path
46
+ return
47
+
48
  print(f"Downloading model from {model_url}...")
49
  response = requests.get(model_url, stream=True)
50
  if response.status_code == 200:
 
53
  print(f"Model downloaded and saved to {model_file_path}")
54
  else:
55
  print(f"Failed to download the model. Status code: {response.status_code}")
56
+ os.environ['MODEL_PATH'] = model_file_path
57
 
58
  def download_model():
59
  model_class = os.getenv("MODEL_CLASS")