|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
import pytest |
|
import os |
|
from os.path import abspath, dirname, join |
|
import torch |
|
import warnings |
|
|
|
|
|
# Ask for the "python" protobuf implementation through the environment.
# NOTE(review): presumably this has to be set before anything imports
# protobuf for it to take effect - confirm.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Absolute path of the "src" directory that sits next to this file's parent
# directory.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))

# Insert at index 1 (not 0) so the first sys.path entry keeps its precedence.
# NOTE(review): presumably lets tests import the checkout's own sources
# without reinstalling the package - confirm.
sys.path.insert(1, git_repo_path)
|
|
|
|
|
def pytest_configure(config):
    """Turn on verbose output for the run.

    Args:
        config: the configuration object handed to this hook; its
            ``option.verbose`` attribute is overwritten with ``True``.
    """
    options = config.option
    options.verbose = True
|
|
|
|
|
def pytest_addoption(parser):
    """Register the ``--torch_ver`` and ``--cuda_ver`` command-line options.

    Both options take a string and default to ``None`` when not supplied.

    Args:
        parser: the option parser handed to this hook.
    """
    for flag in ("--torch_ver", "--cuda_ver"):
        parser.addoption(flag, default=None, type=str)
|
|
|
|
|
def validate_version(expected, found):
    """Return whether ``found`` matches ``expected`` at ``expected``'s precision.

    Only as many dot-separated components as ``expected`` contains are
    compared, so an expected ``"2.1"`` accepts a found ``"2.1.0"``.

    Args:
        expected: the version (or version prefix) that is required.
        found: the version string reported by the environment, or ``None``
            when the environment reports no version at all.

    Returns:
        ``True`` when the truncated ``found`` equals ``expected``, otherwise
        ``False`` (including when ``found`` is ``None``).
    """
    # A missing version can never satisfy an explicit expectation; report a
    # mismatch instead of failing with AttributeError on ``None.split``.
    if found is None:
        return False

    # Ignore a local build suffix such as "+cu118" unless the caller asked
    # for one explicitly; otherwise "2.1.0" would never match "2.1.0+cu118".
    if "+" not in expected:
        found = found.split("+", 1)[0]

    version_depth = expected.count(".") + 1
    truncated = ".".join(found.split(".")[:version_depth])
    return truncated == expected
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def check_environment(pytestconfig):
    """Session-scoped, auto-used check of the torch and CUDA versions.

    For each of ``--torch_ver`` and ``--cuda_ver``: when the option was not
    given, a warning is emitted; when it was given and ``validate_version``
    rejects the version reported by ``torch``, the run is stopped through
    ``pytest.exit`` with return code 2.

    Args:
        pytestconfig: the pytest configuration object the options are read from.
    """
    wanted_torch = pytestconfig.getoption("torch_ver")
    wanted_cuda = pytestconfig.getoption("cuda_ver")

    # torch version: verify when requested, otherwise only warn.
    if wanted_torch is not None:
        if not validate_version(wanted_torch, torch.__version__):
            pytest.exit(
                f"expected torch version {wanted_torch} did not match found torch version {torch.__version__}",
                returncode=2,
            )
    else:
        warnings.warn(
            "Running test without verifying torch version, please provide an expected torch version with --torch_ver"
        )

    # CUDA version: same policy, checked against torch.version.cuda.
    if wanted_cuda is not None:
        if not validate_version(wanted_cuda, torch.version.cuda):
            pytest.exit(
                f"expected cuda version {wanted_cuda} did not match found cuda version {torch.version.cuda}",
                returncode=2,
            )
    else:
        warnings.warn(
            "Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver"
        )
|
|
|
|
|
|
|
|
|
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    """Run tests whose class sets ``is_dist_test`` through the class itself.

    Registered with ``tryfirst=True``. For an item whose class has a truthy
    ``is_dist_test`` attribute, the class is instantiated, the instance is
    called with ``item._request``, and ``item.runtest`` is then replaced by
    a callable that just returns ``True``. Other items are left untouched.

    Args:
        item: the test item about to be called.
    """
    if not getattr(item.cls, "is_dist_test", False):
        return

    runner = item.cls()
    runner(item._request)
    # NOTE(review): presumably keeps the regular call phase from executing
    # the test body a second time - confirm.
    item.runtest = lambda: True
|
|
|
|
|
|
|
|
|
|
|
def pytest_runtest_teardown(item, nextitem):
    """Force-close cached pools once the last test item has been torn down.

    Acts only when the item's class has a truthy ``reuse_dist_env``
    attribute and ``nextitem`` is falsy. In that case the class is
    instantiated and ``_close_pool(pool, num_procs, force=True)`` is called
    for every ``num_procs -> pool`` entry of its ``_pool_cache``.

    Args:
        item: the test item being torn down.
        nextitem: the item scheduled to run next; falsy when there is none.
    """
    if not getattr(item.cls, "reuse_dist_env", False):
        return
    if nextitem:
        return

    owner = item.cls()
    for proc_count, cached_pool in owner._pool_cache.items():
        owner._close_pool(cached_pool, proc_count, force=True)
|
|
|
|
|
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    """Run fixtures whose function sets ``is_dist_fixture`` by calling them.

    Registered with ``tryfirst=True``. When ``fixturedef.func`` has a truthy
    ``is_dist_fixture`` attribute, ``fixturedef.func`` is called with no
    arguments and the resulting object is then called with ``request``.
    Nothing is returned in either case.

    Args:
        fixturedef: the fixture definition being set up.
        request: the request object for the fixture being set up.
    """
    if not getattr(fixturedef.func, "is_dist_fixture", False):
        return

    dist_fixture = fixturedef.func()
    dist_fixture(request)
|
|