|
import pytest as pytest |
|
from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException |
|
|
|
from hanlde_form_submit import create_pipeline, on_form_submit |
|
|
|
|
|
def test_on_form_submit():
    """Exercise on_form_submit with one valid request and two rejected ones.

    A well-formed request for "gpt2" must produce a non-None, non-empty
    result. An empty prompt must raise ValueError, and an unknown model
    name must raise UnsupportedModelNameException.
    """
    length = 10
    prompt = "Answer yes or no, is the sky blue?"

    generated = on_form_submit("gpt2", length, prompt)
    assert generated is not None
    assert len(generated) > 0

    # (model name, prompt, expected exception) for each request that must be rejected.
    rejected_requests = (
        ("gpt2", "", ValueError),
        ("unsupported_model_name", prompt, UnsupportedModelNameException),
    )
    for bad_model, bad_prompt, expected_error in rejected_requests:
        with pytest.raises(expected_error):
            on_form_submit(bad_model, length, bad_prompt)
|
|
|
|
|
# Group size passed to create_pipeline; the wrapped model must report the same value.
_GROUP_SIZE = 5
# Only this many entries of get_full_models_list() are tested. Presumably this
# bounds runtime, since each case builds a full pipeline (TODO confirm).
_NUM_MODELS_UNDER_TEST = 3


@pytest.mark.parametrize(
    "model_name",
    get_full_models_list()[:_NUM_MODELS_UNDER_TEST],
)
def test_create_pipeline(model_name: str):
    """Check that create_pipeline builds a correctly configured pipeline.

    For each selected model name the returned pipeline must:

    - be non-None;
    - keep the requested model name;
    - use the requested group size;
    - have end_of_sentence_stop set to False.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, _GROUP_SIZE)
    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == _GROUP_SIZE
    assert pipeline.wrapped_model.end_of_sentence_stop is False
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of terminating the process.
    # Propagate it so that running this file directly signals failures to the shell/CI.
    raise SystemExit(pytest.main())
|
|