meg-huggingface committed
Commit
20fd212
1 Parent(s): 5c33832

Background scheduling of the evaluation.

app.py CHANGED
@@ -1,3 +1,4 @@
+from apscheduler.schedulers.background import BackgroundScheduler
 import logging
 from src.logging import configure_root_logger
 logging.getLogger("httpx").setLevel(logging.WARNING)
@@ -8,7 +9,7 @@ configure_root_logger()
 from functools import partial
 
 import gradio as gr
-from main_backend_toxicity import run_auto_eval
+import main_backend_toxicity
 from src.display.log_visualizer import log_file_to_html_string
 from src.display.css_html_js import dark_mode_gradio_js
 from src.envs import REFRESH_RATE, REPO_ID, QUEUE_REPO, RESULTS_REPO
@@ -32,28 +33,35 @@ links_md = f"""
 | Results Repo | [{RESULTS_REPO}](https://huggingface.co/datasets/{RESULTS_REPO}) |
 """
 
-def button_auto_eval():
-    logger.info("Manually triggering Auto Eval")
-    run_auto_eval()
-
+def auto_eval():
+    logger.info("Triggering Auto Eval")
+    main_backend_toxicity.run_auto_eval()
 
 reverse_order_checkbox = gr.Checkbox(label="Reverse Order", value=False)
 
-with gr.Blocks(js=dark_mode_gradio_js) as demo:
+with gr.Blocks(js=dark_mode_gradio_js) as backend_ui:
     gr.Markdown(intro_md)
     with gr.Tab("Application"):
-        output_html = gr.HTML(partial(log_file_to_html_string, reverse=reverse_order_checkbox), every=1)
+        output_html = gr.HTML(partial(log_file_to_html_string,
+                                      reverse=reverse_order_checkbox), every=10)
         with gr.Row():
-            download_button = gr.DownloadButton("Download Log File", value=log_file)
+            download_button = gr.DownloadButton("Download Log File",
+                                                value=log_file)
         with gr.Accordion('Log View Configuration', open=False):
            reverse_order_checkbox.render()
         # Add a button that when pressed, triggers run_auto_eval
         button = gr.Button("Manually Run Evaluation")
         gr.Markdown(links_md)
-        button.click(fn=button_auto_eval, inputs=[], outputs=[])
-
-        dummy = gr.Markdown(run_auto_eval, every=REFRESH_RATE, visible=False)
+        # This will run the eval before fully loading the UI,
+        # and the UI will error out if it takes longer than 30 seconds.
+        # Changing to use BackgroundScheduler instead.
+        # dummy = gr.Markdown(main_backend_toxicity.run_auto_eval(), every=REFRESH_RATE, visible=False)
+        button.click(fn=auto_eval, inputs=[], outputs=[])
 
 if __name__ == '__main__':
-    demo.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0",
-                                                    show_error=True, server_port=7860)
+    scheduler = BackgroundScheduler()
+    scheduler.add_job(auto_eval, "interval", seconds=REFRESH_RATE)
+    scheduler.start()
+    backend_ui.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0",
+                                                          show_error=True,
+                                                          server_port=7860)
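The app.py change replaces the hidden `gr.Markdown(run_auto_eval, every=REFRESH_RATE)` trick, which ran the evaluation inside Gradio's page-load request and errored out when it exceeded the request timeout, with an APScheduler `BackgroundScheduler` that fires `auto_eval` every `REFRESH_RATE` seconds in a worker thread, independent of the UI. A minimal sketch of that pattern, with a placeholder refresh interval and a stub standing in for `main_backend_toxicity.run_auto_eval` (both are assumptions, not the Space's real values):

```python
import time
import logging
from apscheduler.schedulers.background import BackgroundScheduler

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

REFRESH_RATE = 600  # seconds; placeholder for the value imported from src.envs


def auto_eval():
    # Stub standing in for main_backend_toxicity.run_auto_eval()
    logger.info("Triggering Auto Eval")


if __name__ == "__main__":
    # Jobs run on the scheduler's own thread pool, so the main thread
    # (the Gradio server in the Space) is never blocked by a long eval.
    scheduler = BackgroundScheduler()
    scheduler.add_job(auto_eval, "interval", seconds=REFRESH_RATE)
    scheduler.start()
    try:
        while True:            # stands in for backend_ui.queue(...).launch(...)
            time.sleep(1)
    except (KeyboardInterrupt, SystemExit):
        scheduler.shutdown()
```

One property of this pattern worth noting: with APScheduler's default `max_instances=1`, a new run is skipped while the previous one is still executing, which is usually the desired behavior for a long-running eval job.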
src/backend/inference_endpoint.py CHANGED
@@ -10,6 +10,7 @@ import requests
 logging.basicConfig(level=logging.DEBUG)
 logger = setup_logger(__name__)
 TIMEOUT = 20
+MAX_REPLICA = 3
 
 
 def create_endpoint(endpoint_name, repository, framework='pytorch',
@@ -26,7 +27,8 @@ def create_endpoint(endpoint_name, repository, framework='pytorch',
                                         vendor=vendor, region=region,
                                         type=type,
                                         instance_size=instance_size,
-                                        instance_type=instance_type)
+                                        instance_type=instance_type,
+                                        max_replica=MAX_REPLICA)
     except huggingface_hub.utils._errors.HfHubHTTPError as e:
         # Workload with the same name already exists error.
         # Use it again, just make sure it has the right settings.
@@ -38,7 +40,8 @@ def create_endpoint(endpoint_name, repository, framework='pytorch',
                             framework=framework, task=task,
                             accelerator=accelerator,
                             instance_size=instance_size,
-                            instance_type=instance_type)
+                            instance_type=instance_type,
+                            max_replica=MAX_REPLICA)
     except requests.exceptions.HTTPError as e:
         # Not enough compute, wrong compute, or quota exceeded
         logger.debug("Hit error:")
@@ -92,9 +95,11 @@ def update_endpoint_exception(endpoint):
     cur_instance_size = raw_info['compute']['instanceSize']
     cur_instance_type = raw_info['compute']['instanceType']
     if (cur_instance_type, cur_instance_size) == ('nvidia-l4', 'x4'):
-        endpoint.update(instance_size='x1', instance_type='nvidia-a100')
+        endpoint.update(instance_size='x1', instance_type='nvidia-a100',
+                        max_replica=MAX_REPLICA)
     elif (cur_instance_type, cur_instance_size) == ('a100', 'x1'):
-        endpoint.update(instance_size='x4', instance_type='nvidia-a10g')
+        endpoint.update(instance_size='x4', instance_type='nvidia-a10g',
+                        max_replica=MAX_REPLICA)
     else:
         logger.info(
             "Getting expensive to try to run this model without human oversight. Exiting.")
src/backend/manage_requests.py CHANGED
@@ -91,7 +91,6 @@ def get_eval_requests(job_status: list, local_dir: str, hf_repo: str) -> list[EvalRequest]:
             # TODO: isn't job_status the string "RUNNING"?
             if data["status"] in job_status:
                 data["json_filepath"] = json_filepath
-                print(data.items())
                 eval_request = EvalRequest(**data)
                 eval_requests.append(eval_request)
 
src/backend/run_toxicity_eval.py CHANGED
@@ -5,6 +5,7 @@ import time
 from datetime import datetime
 import sys
 from tqdm import tqdm
+from multiprocessing import Pool
 
 import requests
 from requests.adapters import HTTPAdapter, Retry
@@ -167,7 +168,8 @@ def main(endpoint_url, eval_request):
     ds = load_dataset("allenai/real-toxicity-prompts")
     prompts = [row['text'] for row in ds['train']['prompt']]
     # All the generated responses from the endpoint
-    generated_responses = map(lambda x: get_generation(endpoint_url, x), prompts[:DATASET_CUTOFF])
+    with Pool() as pool:
+        generated_responses = pool.map(lambda x: get_generation(endpoint_url, x), prompts[:DATASET_CUTOFF])
     att_scores_out = score_generations(prompts, generated_responses)
     logger.debug("Scores are:")
     logger.debug(att_scores_out)
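The run_toxicity_eval.py change swaps the lazy built-in `map(...)` for `multiprocessing.Pool.map`, so all endpoint generations are requested in parallel and fully materialized before scoring. One caveat worth flagging: `Pool.map` pickles its callable to ship it to worker processes, and lambdas are not picklable with the standard `pickle` module, so `pool.map(lambda x: ...)` will normally raise a pickling error. A hedged sketch of the usual workaround, binding `endpoint_url` with `functools.partial` to a module-level function (the `get_generation` body and `DATASET_CUTOFF` value here are stubs, not the file's real implementation):

```python
from functools import partial
from multiprocessing import Pool

DATASET_CUTOFF = 1000  # placeholder for the module's real cutoff


def get_generation(endpoint_url: str, prompt: str) -> str:
    # Stub standing in for the real HTTP call to the inference endpoint.
    return f"generation from {endpoint_url} for: {prompt!r}"


def generate_all(endpoint_url: str, prompts: list[str]) -> list[str]:
    # partial() of a module-level function is picklable, unlike a lambda.
    gen = partial(get_generation, endpoint_url)
    with Pool() as pool:
        return pool.map(gen, prompts[:DATASET_CUTOFF])


if __name__ == "__main__":
    print(generate_all("http://localhost:8080", ["first prompt", "second prompt"]))
```

Since the work here is I/O-bound HTTP requests rather than CPU-bound computation, a thread pool (`multiprocessing.pool.ThreadPool` or `concurrent.futures.ThreadPoolExecutor`) would also fit and sidesteps pickling entirely.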