yonikremer commited on
Commit
df1b7f8
1 Parent(s): 70130da

added tests

Browse files
Files changed (1) hide show
  1. tests.py +34 -3
tests.py CHANGED
@@ -4,7 +4,8 @@ import shutil
4
  import pytest as pytest
5
  from grouped_sampling import GroupedSamplingPipeLine
6
 
7
- from hanlde_form_submit import create_pipeline
 
8
  from prompt_engeneering import rewrite_prompt
9
  from supported_models import get_supported_model_names
10
 
@@ -24,10 +25,40 @@ def test_get_supported_model_names():
24
  supported_model_names = get_supported_model_names()
25
  assert len(supported_model_names) > 0
26
  assert "gpt2" in supported_model_names
27
- assert all([isinstance(name, str) for name in supported_model_names])
28
 
29
 
30
- @pytest.mark.parametrize("model_name", get_supported_model_names())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def test_create_pipeline(model_name: str):
32
  pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
33
  assert pipeline is not None
 
4
  import pytest as pytest
5
  from grouped_sampling import GroupedSamplingPipeLine
6
 
7
+ from on_server_start import download_useful_models
8
+ from hanlde_form_submit import create_pipeline, on_form_submit
9
  from prompt_engeneering import rewrite_prompt
10
  from supported_models import get_supported_model_names
11
 
 
25
  supported_model_names = get_supported_model_names()
26
  assert len(supported_model_names) > 0
27
  assert "gpt2" in supported_model_names
28
+ assert all(isinstance(name, str) for name in supported_model_names)
29
 
30
 
31
+ def test_on_server_start():
32
+ if os.path.exists(HUGGING_FACE_CACHE_DIR):
33
+ shutil.rmtree(HUGGING_FACE_CACHE_DIR)
34
+ assert not os.path.exists(HUGGING_FACE_CACHE_DIR)
35
+ download_useful_models()
36
+ assert os.path.exists(HUGGING_FACE_CACHE_DIR)
37
+ assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
38
+
39
+
40
+ def test_on_form_submit():
41
+ model_name = "gpt2"
42
+ output_length = 10
43
+ prompt = "Answer yes or no, is the sky blue?"
44
+ output = on_form_submit(model_name, output_length, prompt)
45
+ assert output is not None
46
+ assert len(output) > 0
47
+ empty_prompt = ""
48
+ with pytest.raises(ValueError):
49
+ on_form_submit(model_name, output_length, empty_prompt)
50
+ unsupported_model_name = "unsupported_model_name"
51
+ with pytest.raises(ValueError):
52
+ on_form_submit(unsupported_model_name, output_length, prompt)
53
+
54
+
55
+ @pytest.mark.parametrize(
56
+ "model_name",
57
+ get_supported_model_names(
58
+ min_number_of_downloads=1000,
59
+ min_number_of_likes=100,
60
+ )
61
+ )
62
  def test_create_pipeline(model_name: str):
63
  pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
64
  assert pipeline is not None