meg-huggingface committed
Commit 20fd212
Parent(s): 5c33832

Background scheduling of the evaluation.

Browse files:
- app.py (+21 -13)
- src/backend/inference_endpoint.py (+9 -4)
- src/backend/manage_requests.py (+0 -1)
- src/backend/run_toxicity_eval.py (+3 -1)
app.py
CHANGED
@@ -1,3 +1,4 @@
+from apscheduler.schedulers.background import BackgroundScheduler
 import logging
 from src.logging import configure_root_logger
 logging.getLogger("httpx").setLevel(logging.WARNING)

@@ -8,7 +9,7 @@ configure_root_logger()
 from functools import partial
 
 import gradio as gr
-
+import main_backend_toxicity
 from src.display.log_visualizer import log_file_to_html_string
 from src.display.css_html_js import dark_mode_gradio_js
 from src.envs import REFRESH_RATE, REPO_ID, QUEUE_REPO, RESULTS_REPO

@@ -32,28 +33,35 @@ links_md = f"""
 | Results Repo | [{RESULTS_REPO}](https://huggingface.co/datasets/{RESULTS_REPO}) |
 """
 
-def …
-    logger.info("…
-    run_auto_eval()
-…
+def auto_eval():
+    logger.info("Triggering Auto Eval")
+    main_backend_toxicity.run_auto_eval()
 
 reverse_order_checkbox = gr.Checkbox(label="Reverse Order", value=False)
 
-with gr.Blocks(js=dark_mode_gradio_js) as …
+with gr.Blocks(js=dark_mode_gradio_js) as backend_ui:
     gr.Markdown(intro_md)
     with gr.Tab("Application"):
-        output_html = gr.HTML(partial(log_file_to_html_string, …
+        output_html = gr.HTML(partial(log_file_to_html_string,
+                                      reverse=reverse_order_checkbox), every=10)
         with gr.Row():
-            download_button = gr.DownloadButton("Download Log File", …
+            download_button = gr.DownloadButton("Download Log File",
+                                                value=log_file)
        with gr.Accordion('Log View Configuration', open=False):
            reverse_order_checkbox.render()
         # Add a button that when pressed, triggers run_auto_eval
         button = gr.Button("Manually Run Evaluation")
     gr.Markdown(links_md)
-…
-…
-…
+    # This will run the eval before fully loading the UI,
+    # and the UI will error out if it takes longer than 30 seconds.
+    # Changing to use BackgroundScheduler instead.
+    # dummy = gr.Markdown(main_backend_toxicity.run_auto_eval(), every=REFRESH_RATE, visible=False)
+    button.click(fn=auto_eval, inputs=[], outputs=[])
 
 if __name__ == '__main__':
-…
-…
+    scheduler = BackgroundScheduler()
+    scheduler.add_job(auto_eval, "interval", seconds=REFRESH_RATE)
+    scheduler.start()
+    backend_ui.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0",
+                                                          show_error=True,
+                                                          server_port=7860)
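The app.py change above moves the evaluation trigger off the UI thread: APScheduler's BackgroundScheduler runs jobs on a background thread, so the Gradio app can launch immediately instead of blocking (and timing out) on the first evaluation. A minimal sketch of the same pattern, with a hypothetical run_eval job and a 60-second interval standing in for auto_eval and REFRESH_RATE:

    from apscheduler.schedulers.background import BackgroundScheduler
    import time

    def run_eval():
        # Stand-in for the real trigger (auto_eval -> main_backend_toxicity.run_auto_eval()).
        print("evaluation tick")

    scheduler = BackgroundScheduler()
    scheduler.add_job(run_eval, "interval", seconds=60)  # re-run every 60 seconds
    scheduler.start()

    try:
        while True:  # the main thread stays free, e.g. to serve the UI
            time.sleep(1)
    except (KeyboardInterrupt, SystemExit):
        scheduler.shutdown()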
src/backend/inference_endpoint.py
CHANGED
@@ -10,6 +10,7 @@ import requests
 logging.basicConfig(level=logging.DEBUG)
 logger = setup_logger(__name__)
 TIMEOUT = 20
+MAX_REPLICA = 3
 
 
 def create_endpoint(endpoint_name, repository, framework='pytorch',

@@ -26,7 +27,8 @@ def create_endpoint(endpoint_name, repository, framework='pytorch',
             vendor=vendor, region=region,
             type=type,
             instance_size=instance_size,
-            instance_type=instance_type)
+            instance_type=instance_type,
+            max_replica=MAX_REPLICA)
     except huggingface_hub.utils._errors.HfHubHTTPError as e:
         # Workload with the same name already exists error.
         # Use it again, just make sure it has the right settings.

@@ -38,7 +40,8 @@ def create_endpoint(endpoint_name, repository, framework='pytorch',
             framework=framework, task=task,
             accelerator=accelerator,
             instance_size=instance_size,
-            instance_type=instance_type)
+            instance_type=instance_type,
+            max_replica=MAX_REPLICA)
     except requests.exceptions.HTTPError as e:
         # Not enough compute, wrong compute, or quota exceeded
         logger.debug("Hit error:")

@@ -92,9 +95,11 @@ def update_endpoint_exception(endpoint):
     cur_instance_size = raw_info['compute']['instanceSize']
     cur_instance_type = raw_info['compute']['instanceType']
     if (cur_instance_type, cur_instance_size) == ('nvidia-l4', 'x4'):
-        endpoint.update(instance_size='x1', instance_type='nvidia-a100')
+        endpoint.update(instance_size='x1', instance_type='nvidia-a100',
+                        max_replica=MAX_REPLICA)
     elif (cur_instance_type, cur_instance_size) == ('a100', 'x1'):
-        endpoint.update(instance_size='x4', instance_type='nvidia-a10g')
+        endpoint.update(instance_size='x4', instance_type='nvidia-a10g',
+                        max_replica=MAX_REPLICA)
     else:
         logger.info(
             "Getting expensive to try to run this model without human oversight. Exiting.")
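The inference_endpoint.py change caps autoscaling by passing max_replica to the Hugging Face Inference Endpoints API. A hedged sketch of the same cap applied at creation time with huggingface_hub (the endpoint name, model repository, and hardware choices below are illustrative placeholders, and the exact keyword set can vary across huggingface_hub versions):

    from huggingface_hub import create_inference_endpoint

    MAX_REPLICA = 3  # same cap as in the diff above

    endpoint = create_inference_endpoint(
        "toxicity-eval-endpoint",        # hypothetical endpoint name
        repository="org/some-model",     # hypothetical model repository
        framework="pytorch",
        task="text-generation",
        accelerator="gpu",
        vendor="aws",
        region="us-east-1",
        type="protected",
        instance_size="x1",
        instance_type="nvidia-a10g",
        max_replica=MAX_REPLICA,         # cap on autoscaled replicas
    )
    endpoint.wait()  # block until the endpoint reports it is running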
src/backend/manage_requests.py
CHANGED
@@ -91,7 +91,6 @@ def get_eval_requests(job_status: list, local_dir: str, hf_repo: str) -> list[EvalRequest]:
         # TODO: isn't job_status the string "RUNNING"?
         if data["status"] in job_status:
             data["json_filepath"] = json_filepath
-            print(data.items())
             eval_request = EvalRequest(**data)
             eval_requests.append(eval_request)
 
src/backend/run_toxicity_eval.py
CHANGED
@@ -5,6 +5,7 @@ import time
 from datetime import datetime
 import sys
 from tqdm import tqdm
+from multiprocessing import Pool
 
 import requests
 from requests.adapters import HTTPAdapter, Retry

@@ -167,7 +168,8 @@ def main(endpoint_url, eval_request):
     ds = load_dataset("allenai/real-toxicity-prompts")
     prompts = [row['text'] for row in ds['train']['prompt']]
     # All the generated responses from the endpoint
-…
+    with Pool() as pool:
+        generated_responses = pool.map(lambda x: get_generation(endpoint_url, x), prompts[:DATASET_CUTOFF])
     att_scores_out = score_generations(prompts, generated_responses)
     logger.debug("Scores are:")
     logger.debug(att_scores_out)
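One caveat on the parallel generation added above: multiprocessing.Pool.map must pickle the callable it is given, and a lambda is not picklable, so in practice a module-level function or functools.partial is used instead. A minimal sketch under that assumption (get_generation's signature follows the diff; its body and the DATASET_CUTOFF value here are placeholders):

    from functools import partial
    from multiprocessing import Pool

    DATASET_CUTOFF = 1000  # illustrative value

    def get_generation(endpoint_url, prompt):
        # Placeholder for the real endpoint call in run_toxicity_eval.py.
        return f"response to {prompt!r} from {endpoint_url}"

    def generate_all(endpoint_url, prompts):
        # partial() binds endpoint_url and stays picklable, unlike a lambda,
        # so Pool.map can ship the callable to worker processes.
        with Pool() as pool:
            return pool.map(partial(get_generation, endpoint_url), prompts[:DATASET_CUTOFF])

    if __name__ == "__main__":
        print(generate_all("https://example-endpoint", ["a", "b", "c"]))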