|
import sys |
|
import os |
|
|
|
sys.path.append( |
|
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) |
|
) |
|
from megatron.neox_arguments import neox_args, deepspeed_args |
|
from inspect import getmembers, getsource |
|
from dataclasses import field, is_dataclass |
|
from itertools import tee, zip_longest |
|
import pathlib |
|
|
|
|
|
def pairwise(iterable):
    """Pair every item of *iterable* with its successor.

    s -> (s0, s1), (s1, s2), ..., (s_last, None)

    The two copies are combined with ``zip_longest``, so the final item is
    paired with ``None`` and exactly one pair is produced per input item.
    An empty input produces no pairs.
    """
    leading, trailing = tee(iterable, 2)
    # Shift the second copy one step ahead; the default keeps an empty
    # input from raising StopIteration here.
    next(trailing, None)
    return zip_longest(leading, trailing, fillvalue=None)
|
|
|
|
|
def _field_source_snippets(src, field_names):
    """Slice a dataclass's source text into one snippet per field declaration.

    A declaration is a line that starts with indentation followed by
    ``<field_name>:``. Anchoring on the line start keeps a field name that is
    merely mentioned inside a docstring sentence (or that is the tail of a
    longer identifier) from being mistaken for the declaration.

    Fields are searched in the given order, each search starting after the
    previous match. A field whose declaration is not found (for example one
    inherited from a base class, whose source is not part of *src*) is left
    out of the result.

    Returns a dict mapping field name to the source text running from its
    declaration up to the next located declaration (or the end of *src*).
    """
    import re  # local import so the helper does not rely on file-level imports

    starts = []
    cursor = 0
    for field_name in field_names:
        declaration = re.compile(
            rf"^[ \t]+{re.escape(field_name)}[ \t]*:", re.MULTILINE
        )
        match = declaration.search(src, cursor)
        if match is None:
            continue
        starts.append((field_name, match.start()))
        cursor = match.end()

    # Each snippet ends where the next located declaration begins.
    ends = [start for _, start in starts[1:]] + [len(src)]
    return {
        field_name: src[start:end]
        for (field_name, start), end in zip(starts, ends)
    }


def _field_docstring(snippet):
    """Return the first triple-quoted string in *snippet*, stripped, or ""."""
    for quote in ('"""', "'''"):
        if quote in snippet:
            return snippet.split(quote)[1].strip()
    return ""


def _type_label(field_type):
    """Return a printable name for a dataclass field's annotation.

    Plain classes are shown by ``__name__``. Annotations without a
    ``__name__``, and those named "Literal" or "Union", are shown via
    ``str()`` so that their parameters stay visible.
    """
    type_name = getattr(field_type, "__name__", None)
    if type_name is None or type_name in ("Literal", "Union"):
        return str(field_type)
    return str(type_name)


def get_docs(module):
    """Collect documentation for every dataclass class defined in *module*.

    For each dataclass the class docstring is recorded, and for each of its
    fields the name, a printable type, the default value and the
    triple-quoted string that follows the field declaration in the class
    source (used as the field's documentation).

    Args:
        module: object whose members are inspected with
            ``inspect.getmembers``; the source of each dataclass found must
            be retrievable with ``inspect.getsource``.

    Returns:
        dict mapping class name to
        ``{"doc": str, "attributes": {field_name: {"name", "type",
        "default", "doc"}}}``.
    """
    from dataclasses import MISSING

    results = {}
    for name, dcls in getmembers(module, is_dataclass):
        if not isinstance(dcls, type):
            # is_dataclass() is also true for dataclass *instances*; only
            # classes have a source definition to document.
            continue

        # Read the metadata from the class itself: no instance is needed, so
        # dataclasses with required fields can be documented as well.
        fields = dcls.__dataclass_fields__
        snippets = _field_source_snippets(getsource(dcls), list(fields))

        attributes = {}
        for field_name, field_def in fields.items():
            default = field_def.default
            if default is MISSING and field_def.default_factory is not MISSING:
                # Report what the factory produces rather than the MISSING
                # sentinel object.
                default = field_def.default_factory()
            attributes[field_name] = {
                "name": field_name,
                "type": _type_label(field_def.type),
                "default": default,
                "doc": _field_docstring(snippets.get(field_name, "")),
            }

        results[name] = {
            "doc": (dcls.__doc__ or "").strip(),
            "attributes": attributes,
        }
    return results
|
|
|
|
|
def to_md(docs, intro_str=""):
    """
    Render the docs dictionary as markdown text.

    Args:
        docs: mapping of section name to
            ``{"doc": ..., "attributes": {field_name: {"type", "default",
            "doc"}}}``, i.e. the structure built by ``get_docs``.
        intro_str: text placed before the first section.

    Returns:
        A single string: the intro, then for every section a ``##`` heading,
        its doc text and one bullet per attribute, all joined by blank lines.
    """
    # Four spaces put the "Default" and description paragraphs inside the
    # preceding "- " list item; with less indentation than the bullet's
    # content offset, markdown treats them as top-level paragraphs.
    indent = "    "

    lines = [intro_str]
    for name, doc in docs.items():
        lines.extend((f"## {name}", f"{doc['doc']}", ""))
        for field_name, field_def in doc["attributes"].items():
            lines.extend(
                (
                    f"- **{field_name}**: {field_def['type']}",
                    f"{indent}Default = {field_def['default']}",
                    f"{indent}{field_def['doc']}",
                    "",
                )
            )
    return "\n\n".join(lines)
|
|
|
|
|
if __name__ == "__main__":
    # Gather the documentation of both argument modules; entries from
    # deepspeed_args replace same-named entries from neox_args.
    docs = get_docs(neox_args)
    docs.update(get_docs(deepspeed_args))
    intro_str = """Arguments for gpt-neox. All of the following can be specified in your .yml config file(s):\n"""
    md = to_md(docs, intro_str=intro_str)

    # The markdown file is written next to this script.
    output_path = pathlib.Path(__file__).parent.resolve() / "neox_arguments.md"
    # Explicit UTF-8 so non-ASCII documentation text is written identically
    # regardless of the platform's default locale encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(md)
|
|