File size: 850 Bytes
9b37b1f 6dd4824 9b37b1f 6dd4824 df1b7f8 9b37b1f df1b7f8 a671856 df1b7f8 a671856 df1b7f8 b426c55 9b37b1f b426c55 28a440d 32e4e72 9b37b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine
from available_models import AVAILABLE_MODELS
from hanlde_form_submit import create_pipeline, on_form_submit
def test_on_form_submit():
    """Exercise ``on_form_submit`` with a real prompt and with an empty one.

    With model "gpt2" and an output length of 10, a non-empty prompt must
    yield a non-None, non-empty result. The same call with an empty prompt
    must raise ``ValueError``.
    """
    shared_args = ("gpt2", 10)  # (model_name, output_length)

    result = on_form_submit(*shared_args, "Answer yes or no, is the sky blue?")
    assert result is not None
    assert len(result) > 0

    # Same model and length, but nothing to complete: must be rejected.
    with pytest.raises(ValueError):
        on_form_submit(*shared_args, "")
def test_create_pipeline():
    """Check that ``create_pipeline`` builds a pipeline for the requested model.

    The returned object must be a ``GroupedSamplingPipeLine`` bound to
    model "gpt2". Its wrapped model must have ``end_of_sentence_stop``
    set to ``False``.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline("gpt2")
    # isinstance verifies the type the annotation above only claims; it also
    # implies the value is not None.
    assert isinstance(pipeline, GroupedSamplingPipeLine)
    assert pipeline.model_name == "gpt2"
    assert pipeline.wrapped_model.end_of_sentence_stop is False
if __name__ == "__main__":
    # Run only this file's tests (a bare pytest.main() collects from the
    # current working directory). Also propagate pytest's exit status, so a
    # failing run exits non-zero instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|