|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from megatron.neox_arguments import NeoXArgs |
|
from tests.common import BASE_CONFIG, DistributedTest |
|
|
|
|
|
def test_main_constructor():
    """Parse a deepy-style argv, then re-parse it from the DeepSpeed main args.

    The first parse consumes the launcher-style argument list.
    ``get_deepspeed_main_args`` then turns the result back into an argument
    list, which ``consume_neox_args`` parses a second time. The test passes if
    every stage, including ``configure_distributed_args``, runs without raising.
    """
    argv = ["train.py", "tests/config/test_setup.yml"]

    # First pass: parse the launcher-style argv.
    launcher_args = NeoXArgs.consume_deepy_args(argv)

    # Second pass: re-parse from the argument list produced for DeepSpeed.
    forwarded_argv = launcher_args.get_deepspeed_main_args()
    worker_args = NeoXArgs.consume_neox_args(input_args=forwarded_argv)

    worker_args.configure_distributed_args()
|
|
|
|
|
class test_constructor_from_ymls_class(DistributedTest):
    """Build ``NeoXArgs`` from a YAML config list and configure distribution."""

    # NOTE(review): presumably the process count DistributedTest uses —
    # confirm against tests.common.
    world_size = 2

    def test(self):
        """Load the test YAML config and call ``configure_distributed_args``."""
        config_paths = ["tests/config/test_setup.yml"]
        NeoXArgs.from_ymls(config_paths).configure_distributed_args()
|
|
|
|
|
def test_constructor_from_ymls():
    """Pytest entry point for ``test_constructor_from_ymls_class.test``.

    NOTE(review): this calls ``test()`` directly on a fresh instance. Whether
    that goes through DistributedTest's multi-process setup depends on
    tests.common — confirm.
    """
    test_constructor_from_ymls_class().test()
|
|
|
|
|
class test_constructor_from_dict_class(DistributedTest):
    """Build ``NeoXArgs`` from the shared ``BASE_CONFIG`` dictionary."""

    # NOTE(review): presumably the process count DistributedTest uses —
    # confirm against tests.common.
    world_size = 2

    def test(self):
        """Check that ``NeoXArgs.from_dict(BASE_CONFIG)`` does not raise."""
        # The result was previously bound to a local that was never read
        # (F841). Only the constructor call matters here, so the dead binding
        # is dropped while the call itself is kept.
        NeoXArgs.from_dict(BASE_CONFIG)
|
|
|
|
|
def test_constructor_from_dict():
    """Pytest entry point for ``test_constructor_from_dict_class.test``.

    NOTE(review): this calls ``test()`` directly on a fresh instance. Whether
    that goes through DistributedTest's multi-process setup depends on
    tests.common — confirm.
    """
    test_constructor_from_dict_class().test()
|
|