"""Tests for prompt rewriting, model listing, model download, form submission
and pipeline creation.

These call the real project modules directly (no mocking).
"""
import os

import pytest
from grouped_sampling import GroupedSamplingPipeLine

from on_server_start import download_useful_models
from hanlde_form_submit import create_pipeline, on_form_submit
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names


def _default_hugging_face_cache_dir() -> str:
    """Return the Hugging Face hub cache directory for the current environment.

    Resolution order (intended to match huggingface_hub's documented
    environment variables):
    1. ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``;
    2. otherwise ``<HF_HOME>/hub``, where ``HF_HOME`` defaults to
       ``$XDG_CACHE_HOME/huggingface`` and finally ``~/.cache/huggingface``.
    """
    explicit_cache = (
        os.environ.get("HF_HUB_CACHE")
        or os.environ.get("HUGGINGFACE_HUB_CACHE")
    )
    if explicit_cache:
        return os.path.expanduser(explicit_cache)
    cache_root = os.environ.get(
        "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
    )
    hf_home = os.environ.get("HF_HOME", os.path.join(cache_root, "huggingface"))
    return os.path.join(os.path.expanduser(hf_home), "hub")


# Resolved per environment instead of hard-coding one developer's home
# directory, so the download test works on any machine / CI runner.
HUGGING_FACE_CACHE_DIR = _default_hugging_face_cache_dir()


def test_prompt_engineering():
    """rewrite_prompt wraps the query with search results, date and instructions."""
    example_prompt = "Answer yes or no, is the sky blue?"
    rewritten_prompt = rewrite_prompt(example_prompt)
    assert rewritten_prompt.startswith("Web search results:")
    assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
    assert "Current date: " in rewritten_prompt
    assert "Instructions: " in rewritten_prompt


def test_get_supported_model_names():
    """The supported-model list is non-empty, contains gpt2, and holds only strings."""
    supported_model_names = get_supported_model_names()
    assert len(supported_model_names) > 0
    assert "gpt2" in supported_model_names
    assert all(isinstance(name, str) for name in supported_model_names)


def test_on_server_start():
    """download_useful_models leaves a non-empty Hugging Face hub cache directory."""
    download_useful_models()
    assert os.path.exists(HUGGING_FACE_CACHE_DIR)
    assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0


def test_on_form_submit():
    """on_form_submit returns output for valid input and raises ValueError otherwise.

    Covers the happy path, an empty prompt, and an unsupported model name.
    """
    model_name = "gpt2"
    output_length = 10
    prompt = "Answer yes or no, is the sky blue?"
    output = on_form_submit(model_name, output_length, prompt, web_search=False)
    assert output is not None
    assert len(output) > 0

    empty_prompt = ""
    with pytest.raises(ValueError):
        on_form_submit(model_name, output_length, empty_prompt, web_search=False)

    unsupported_model_name = "unsupported_model_name"
    with pytest.raises(ValueError):
        on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)


# NOTE: the parameter list is computed at collection time, so
# get_supported_model_names runs whenever this module is imported by pytest.
@pytest.mark.parametrize(
    "model_name",
    get_supported_model_names(
        min_number_of_downloads=1000,
        min_number_of_likes=100,
    ),
)
def test_create_pipeline(model_name: str):
    """create_pipeline builds a pipeline for the model with group size 5.

    The pipeline should record the model name and have end_of_sentence_stop
    disabled.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == 5
    assert pipeline.wrapped_model.end_of_sentence_stop is False
    # Presumably here to drop the model reference promptly before the next
    # parametrized case loads another model.
    del pipeline


if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the shell / CI.
    raise SystemExit(pytest.main())