yonikremer's picture
testing the demo - available models and not all the supported models
3e0e787
raw
history blame
1.2 kB
import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine, UnsupportedModelNameException
from available_models import AVAILABLE_MODELS
from hanlde_form_submit import create_pipeline, on_form_submit
def test_on_form_submit():
    """Exercise ``on_form_submit`` with one valid request and two invalid ones.

    * A supported model plus a non-empty prompt must yield non-empty output.
    * An empty prompt must raise ``ValueError``.
    * A model name that is not supported must raise
      ``UnsupportedModelNameException``.
    """
    supported_model = "gpt2"
    length = 10
    question = "Answer yes or no, is the sky blue?"

    result = on_form_submit(supported_model, length, question)
    assert result is not None
    assert len(result) > 0

    # Same model and length, but an empty prompt: expected to be rejected.
    with pytest.raises(ValueError):
        on_form_submit(supported_model, length, "")

    # Valid prompt, but a model name outside the supported set.
    with pytest.raises(UnsupportedModelNameException):
        on_form_submit("unsupported_model_name", length, question)
@pytest.mark.parametrize("model_name", AVAILABLE_MODELS)
def test_create_pipeline(model_name: str):
    """Check that ``create_pipeline`` builds a correctly configured pipeline.

    Runs once per entry in ``AVAILABLE_MODELS``, requesting a group size of 5.
    The pipeline must report the requested model name. Its wrapped model must
    carry the requested group size and have ``end_of_sentence_stop`` set to
    ``False``.
    """
    group_size = 5
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, group_size)

    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == group_size
    assert pipeline.wrapped_model.end_of_sentence_stop is False

    # Explicitly unbind the pipeline. Presumably this is meant to free the
    # model between parametrized runs; the local would be released on return
    # anyway.
    del pipeline
if __name__ == "__main__":
    # Local import: only needed when the module is executed as a script.
    import sys

    # Run this module's tests rather than whatever pytest would collect from
    # the current working directory when given no paths. Forward any extra
    # command-line arguments. Propagate pytest's exit status so failures are
    # visible to the shell / CI.
    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))