"""Tests for ``create_pipeline`` and ``on_form_submit``."""

import sys

import pytest
from grouped_sampling import GroupedSamplingPipeLine

from available_models import AVAILABLE_MODELS
# NOTE(review): the module name is misspelled ("hanlde"); kept as-is because it
# must match the module that is actually imported.
from hanlde_form_submit import create_pipeline, on_form_submit


def test_on_form_submit():
    """on_form_submit returns a non-None, non-empty result for a normal prompt."""
    model_name = "gpt2"
    output_length = 10
    prompt = "Answer yes or no, is the sky blue?"
    output = on_form_submit(model_name, output_length, prompt)
    assert output is not None
    assert len(output) > 0


def test_on_form_submit_empty_prompt():
    """on_form_submit raises ValueError when given an empty prompt.

    Kept separate from the happy-path test so that a failure in one does not
    hide the result of the other.
    """
    model_name = "gpt2"
    output_length = 10
    empty_prompt = ""
    with pytest.raises(ValueError):
        on_form_submit(model_name, output_length, empty_prompt)


@pytest.mark.parametrize("model_name", AVAILABLE_MODELS)
def test_create_pipeline(model_name: str):
    """create_pipeline builds a pipeline for every entry in AVAILABLE_MODELS.

    The test requires the returned pipeline to report the requested model name,
    to use a group size of 5, and to have end_of_sentence_stop set to False.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name)
    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == 5
    assert pipeline.wrapped_model.end_of_sentence_stop is False


if __name__ == "__main__":
    # Run only this file (forwarding any extra CLI options) and propagate
    # pytest's exit code, so a failing run yields a non-zero exit status.
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))