# Spaces: Runtime error
# (page-status residue captured together with this file; not part of the program)
import argparse | |
import contextlib | |
import logging | |
import os | |
import sys | |
import shutil | |
class ColoredFilter(logging.Filter):
    """
    A logging filter to add color to certain log levels.

    For records whose ``levelname`` is a key of ``COLORS``, the filter rewrites
    the record in place: ``levelname`` becomes the ANSI color code followed by
    the bracketed level name, and the ANSI reset code is appended to ``msg`` so
    the color ends after the message text. Records with any other level name
    are left untouched. No record is ever rejected.
    """

    # ANSI escape sequences.
    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    # Level name -> color code used to open the colored span.
    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def filter(self, record):
        """Colorize ``record`` in place when its level is known.

        Args:
            record: The ``logging.LogRecord`` being processed.

        Returns:
            Always ``True``; this filter decorates records, it never drops them.
        """
        color_start = self.COLORS.get(record.levelname)
        if color_start is not None:
            # After this rewrite the level name is no longer a COLORS key, so a
            # second pass over the same record leaves it unchanged.
            record.levelname = f"{color_start}[{record.levelname}]"
            # The reset code is appended to the unformatted message; %-style
            # args are still interpolated later by the logging machinery.
            record.msg = f"{record.msg}{self.RESET}"
        return True
def main(args, extras) -> None:
    """Set up the environment and run one train/validate/test/export pass.

    Args:
        args: Parsed command-line namespace. This function reads ``gpu``,
            ``typecheck``, ``verbose``, ``gradio``, ``config``, ``train``,
            ``validate``, ``test`` and ``export``.
        extras: Command-line tokens the parser did not recognize; forwarded to
            ``load_config`` as ``cli_args``.
    """
    # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = env_gpus_str.split(",") if env_gpus_str else []

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if env_gpus:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
        n_gpus = len(env_gpus)
    else:
        n_gpus = len(args.gpu.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # These imports are deliberately placed after the environment variables
    # above have been written.
    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        # The hook is installed before ``threestudio`` is imported below.
        install_import_hook("threestudio", "typeguard.typechecked")

    import threestudio
    from threestudio.systems.base import BaseSystem
    from threestudio.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from threestudio.utils.config import ExperimentConfig, load_config
    from threestudio.utils.misc import get_rank
    from threestudio.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        # Only the handler writing to stderr is restyled. Handlers that have no
        # ``stream`` attribute are skipped instead of raising AttributeError.
        if getattr(handler, "stream", None) != sys.stderr:
            continue
        if args.gradio:
            handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        else:
            # ColoredFilter adds the brackets (and color) to the level name itself.
            handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
            handler.addFilter(ColoredFilter())

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # set a different seed for each device
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = threestudio.find(cfg.data_type)(cfg.data)
    system: BaseSystem = threestudio.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        # In gradio mode the log is additionally written to <trial_dir>/logs.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.DEBUG if args.verbose else logging.INFO)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
        if args.gradio:
            callbacks += [
                ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
            ]
        else:
            callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        """Write each item of ``lines`` to ``file``, one per line."""
        # Explicit UTF-8 so non-ASCII command-line arguments cannot fail on a
        # restrictive locale encoding.
        with open(file, "w", encoding="utf-8") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        # Record the invoking command line and the parsed arguments.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    # if not os.path.exists( cfg.trial_dir+"/gaussiansplatting"):
    #     shutil.copytree("./gaussiansplatting", cfg.trial_dir+"/gaussiansplatting")

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        """Copy ``epoch`` and ``global_step`` from a checkpoint onto ``system``.

        Does nothing when ``ckpt_path`` is ``None``.
        """
        if ckpt_path is None:
            return
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
def build_arg_parser():
    """Build the command-line parser for this launcher.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``--config``, an optional
        ``--gpu`` (default ``"0"``), exactly one required mode flag out of
        ``--train`` / ``--validate`` / ``--test`` / ``--export``, and the
        boolean switches ``--gradio``, ``--verbose`` and ``--typecheck``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    # Exactly one run mode must be chosen.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")

    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )
    return parser


if __name__ == "__main__":
    # Unrecognized tokens are kept and handed to main() as ``extras``.
    args, extras = build_arg_parser().parse_known_args()

    if args.gradio:
        # FIXME: no effect, stdout is not captured
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        main(args, extras)