import argparse
import contextlib
import logging
import os
import shutil
import sys


class ColoredFilter(logging.Filter):
    """A logging filter that adds ANSI color to records at known log levels.

    For a record whose ``levelname`` is a key of ``COLORS``:

    * ``levelname`` is rewritten to ``"<color>[LEVEL]"``;
    * ``msg`` gets the ``RESET`` code appended, so the color ends after the
      message.

    Records at any other level pass through unchanged. The record is modified
    in place, and the filter never drops a record.
    """

    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def filter(self, record: logging.LogRecord) -> bool:
        """Colorize ``record`` in place when its level is known; always keep it."""
        color_start = self.COLORS.get(record.levelname)
        if color_start is not None:
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True


def main(args, extras) -> None:
    """Set up devices and logging, build the system, and run the requested stage.

    Args:
        args: Known arguments parsed by the ``__main__`` parser below
            (``config``, ``gpu``, exactly one of ``train``/``validate``/
            ``test``/``export``, plus ``gradio``, ``verbose``, ``typecheck``).
        extras: Unrecognized command-line tokens, forwarded to
            ``load_config`` as ``cli_args``.
    """
    # Set CUDA_VISIBLE_DEVICES if needed, *then* import pytorch-lightning, so
    # that the GPU selection is in place before any CUDA-aware import runs.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if len(env_gpus) > 0:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or a
        # higher-level script; --gpu is ignored in that case.
        n_gpus = len(env_gpus)
    else:
        n_gpus = len(args.gpu.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        # Must be installed before `threestudio` is imported to take effect.
        install_import_hook("threestudio", "typeguard.typechecked")

    import threestudio
    from threestudio.systems.base import BaseSystem
    from threestudio.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from threestudio.utils.config import ExperimentConfig, load_config
    from threestudio.utils.misc import get_rank
    from threestudio.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        # Only re-format console handlers writing to stderr. Not every handler
        # type has a `stream` attribute (e.g. NullHandler), hence getattr.
        if getattr(handler, "stream", None) is not sys.stderr:
            continue
        if not args.gradio:
            handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
            handler.addFilter(ColoredFilter())
        else:
            # No ANSI colors in gradio mode.
            handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # set a different seed for each device
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = threestudio.find(cfg.data_type)(cfg.data)
    system: BaseSystem = threestudio.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        # Mirror the pytorch_lightning log into <trial_dir>/logs in gradio mode.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.DEBUG if args.verbose else logging.INFO)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
    if args.gradio:
        callbacks += [
            ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
        ]
    else:
        callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        # Explicit UTF-8 so non-ASCII argv entries don't depend on the locale.
        with open(file, "w", encoding="utf-8") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        # Record the exact invocation (once, on rank 0) for reproducibility.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()
        # if not os.path.exists( cfg.trial_dir+"/gaussiansplatting"):
        #     shutil.copytree("./gaussiansplatting", cfg.trial_dir+"/gaussiansplatting")

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        """Copy epoch/global_step from a checkpoint into ``system``; no-op for None."""
        if ckpt_path is None:
            return
        # NOTE(review): newer torch releases default `weights_only=True`, which
        # may reject full Lightning checkpoints — confirm against the pinned
        # torch version.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")

    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )

    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )

    # Unknown arguments are kept and passed to main() as config overrides.
    args, extras = parser.parse_known_args()

    if args.gradio:
        # FIXME: no effect, stdout is not captured
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        main(args, extras)