import os from dataclasses import dataclass, field import pytorch_lightning as pl import torch.nn.functional as F import threestudio from threestudio.models.exporters.base import Exporter, ExporterOutput from threestudio.systems.utils import parse_optimizer, parse_scheduler from threestudio.utils.base import Updateable, update_if_possible from threestudio.utils.config import parse_structured from threestudio.utils.misc import C, cleanup, get_device, load_module_weights from threestudio.utils.saving import SaverMixin from threestudio.utils.typing import * class BaseSystem(pl.LightningModule, Updateable, SaverMixin): @dataclass class Config: loggers: dict = field(default_factory=dict) loss: dict = field(default_factory=dict) optimizer: dict = field(default_factory=dict) scheduler: Optional[dict] = None weights: Optional[str] = None weights_ignore_modules: Optional[List[str]] = None cleanup_after_validation_step: bool = False cleanup_after_test_step: bool = False cfg: Config def __init__(self, cfg, resumed=False) -> None: super().__init__() self.cfg = parse_structured(self.Config, cfg) self._save_dir: Optional[str] = None self._resumed: bool = resumed self._resumed_eval: bool = False self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} if "loggers" in cfg: self.create_loggers(cfg.loggers) self.configure() if self.cfg.weights is not None: self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules) self.post_configure() def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None): state_dict, epoch, global_step = load_module_weights( weights, ignore_modules=ignore_modules, map_location="cpu" ) self.load_state_dict(state_dict, strict=False) # restore step-dependent states self.do_update_step(epoch, global_step, on_load_weights=True) def set_resume_status(self, current_epoch: int, global_step: int): # restore correct epoch and global step in eval self._resumed_eval = True self._resumed_eval_status["current_epoch"] = current_epoch self._resumed_eval_status["global_step"] = global_step @property def resumed(self): # whether from resumed checkpoint return self._resumed @property def true_global_step(self): if self._resumed_eval: return self._resumed_eval_status["global_step"] else: return self.global_step @property def true_current_epoch(self): if self._resumed_eval: return self._resumed_eval_status["current_epoch"] else: return self.current_epoch def configure(self) -> None: pass def post_configure(self) -> None: """ executed after weights are loaded """ pass def C(self, value: Any) -> float: return C(value, self.true_current_epoch, self.true_global_step) def configure_optimizers(self): optim = parse_optimizer(self.cfg.optimizer, self) ret = { "optimizer": optim, } if self.cfg.scheduler is not None: ret.update( { "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), } ) return ret def training_step(self, batch, batch_idx): raise NotImplementedError def validation_step(self, batch, batch_idx): raise NotImplementedError def on_validation_batch_end(self, outputs, batch, batch_idx): if self.cfg.cleanup_after_validation_step: # cleanup to save vram cleanup() def on_validation_epoch_end(self): raise NotImplementedError def test_step(self, batch, batch_idx): raise NotImplementedError def on_test_batch_end(self, outputs, batch, batch_idx): if self.cfg.cleanup_after_test_step: # cleanup to save vram cleanup() def on_test_epoch_end(self): pass def predict_step(self, batch, batch_idx): raise NotImplementedError def on_predict_batch_end(self, outputs, batch, batch_idx): if self.cfg.cleanup_after_test_step: # cleanup to save vram cleanup() def on_predict_epoch_end(self): pass def preprocess_data(self, batch, stage): pass """ Implementing on_after_batch_transfer of DataModule does the same. But on_after_batch_transfer does not support DP. """ def on_train_batch_start(self, batch, batch_idx, unused=0): self.preprocess_data(batch, "train") self.dataset = self.trainer.train_dataloader.dataset update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) self.do_update_step(self.true_current_epoch, self.true_global_step) def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): self.preprocess_data(batch, "validation") self.dataset = self.trainer.val_dataloaders.dataset update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) self.do_update_step(self.true_current_epoch, self.true_global_step) def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): self.preprocess_data(batch, "test") self.dataset = self.trainer.test_dataloaders.dataset update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) self.do_update_step(self.true_current_epoch, self.true_global_step) def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): self.preprocess_data(batch, "predict") self.dataset = self.trainer.predict_dataloaders.dataset update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) self.do_update_step(self.true_current_epoch, self.true_global_step) def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): pass def on_before_optimizer_step(self, optimizer): """ # some gradient-related debugging goes here, example: from lightning.pytorch.utilities import grad_norm norms = grad_norm(self.geometry, norm_type=2) print(norms) """ pass class BaseLift3DSystem(BaseSystem): @dataclass class Config(BaseSystem.Config): geometry_type: str = "" geometry: dict = field(default_factory=dict) geometry_convert_from: Optional[str] = None geometry_convert_inherit_texture: bool = False # used to override configurations of the previous geometry being converted from, # for example isosurface_threshold geometry_convert_override: dict = field(default_factory=dict) material_type: str = "" material: dict = field(default_factory=dict) background_type: str = "" background: dict = field(default_factory=dict) renderer_type: str = "" renderer: dict = field(default_factory=dict) guidance_type: str = "" guidance: dict = field(default_factory=dict) prompt_processor_type: str = "" prompt_processor: dict = field(default_factory=dict) # geometry export configurations, no need to specify in training exporter_type: str = "mesh-exporter" exporter: dict = field(default_factory=dict) cfg: Config def configure(self) -> None: if ( self.cfg.geometry_convert_from # from_coarse must be specified and not self.cfg.weights # not initialized from coarse when weights are specified and not self.resumed # not initialized from coarse when resumed from checkpoints ): threestudio.info("Initializing geometry from a given checkpoint ...") from threestudio.utils.config import load_config, parse_structured prev_cfg = load_config( os.path.join( os.path.dirname(self.cfg.geometry_convert_from), "../configs/parsed.yaml", ) ) # TODO: hard-coded relative path prev_system_cfg: BaseLift3DSystem.Config = parse_structured( self.Config, prev_cfg.system ) prev_geometry_cfg = prev_system_cfg.geometry prev_geometry_cfg.update(self.cfg.geometry_convert_override) prev_geometry = threestudio.find(prev_system_cfg.geometry_type)( prev_geometry_cfg ) state_dict, epoch, global_step = load_module_weights( self.cfg.geometry_convert_from, module_name="geometry", map_location="cpu", ) prev_geometry.load_state_dict(state_dict, strict=False) # restore step-dependent states prev_geometry.do_update_step(epoch, global_step, on_load_weights=True) # convert from coarse stage geometry prev_geometry = prev_geometry.to(get_device()) self.geometry = threestudio.find(self.cfg.geometry_type).create_from( prev_geometry, self.cfg.geometry, copy_net=self.cfg.geometry_convert_inherit_texture, ) del prev_geometry cleanup() else: self.geometry = threestudio.find(self.cfg.geometry_type)(self.cfg.geometry) self.material = threestudio.find(self.cfg.material_type)(self.cfg.material) self.background = threestudio.find(self.cfg.background_type)( self.cfg.background ) self.renderer = threestudio.find(self.cfg.renderer_type)( self.cfg.renderer, geometry=self.geometry, material=self.material, background=self.background, ) def on_fit_start(self) -> None: if self._save_dir is not None: threestudio.info(f"Validation results will be saved to {self._save_dir}") else: threestudio.warn( f"Saving directory not set for the system, visualization results will not be saved" ) def on_test_end(self) -> None: if self._save_dir is not None: threestudio.info(f"Test results saved to {self._save_dir}") def on_predict_start(self) -> None: self.exporter: Exporter = threestudio.find(self.cfg.exporter_type)( self.cfg.exporter, geometry=self.geometry, material=self.material, background=self.background, ) def predict_step(self, batch, batch_idx): if self.exporter.cfg.save_video: self.test_step(batch, batch_idx) def on_predict_epoch_end(self) -> None: if self.exporter.cfg.save_video: self.on_test_epoch_end() exporter_output: List[ExporterOutput] = self.exporter() for out in exporter_output: save_func_name = f"save_{out.save_type}" if not hasattr(self, save_func_name): raise ValueError(f"{save_func_name} not supported by the SaverMixin") save_func = getattr(self, save_func_name) save_func(f"it{self.true_global_step}-export/{out.save_name}", **out.params) def on_predict_end(self) -> None: if self._save_dir is not None: threestudio.info(f"Export assets saved to {self._save_dir}") def guidance_evaluation_save(self, comp_rgb, guidance_eval_out): B, size = comp_rgb.shape[:2] resize = lambda x: F.interpolate( x.permute(0, 3, 1, 2), (size, size), mode="bilinear", align_corners=False ).permute(0, 2, 3, 1) filename = f"it{self.true_global_step}-train.png" def merge12(x): return x.reshape(-1, *x.shape[2:]) self.save_image_grid( filename, [ { "type": "rgb", "img": merge12(comp_rgb), "kwargs": {"data_format": "HWC"}, }, ] + ( [ { "type": "rgb", "img": merge12(resize(guidance_eval_out["imgs_noisy"])), "kwargs": {"data_format": "HWC"}, } ] ) + ( [ { "type": "rgb", "img": merge12(resize(guidance_eval_out["imgs_1step"])), "kwargs": {"data_format": "HWC"}, } ] ) + ( [ { "type": "rgb", "img": merge12(resize(guidance_eval_out["imgs_1orig"])), "kwargs": {"data_format": "HWC"}, } ] ) + ( [ { "type": "rgb", "img": merge12(resize(guidance_eval_out["imgs_final"])), "kwargs": {"data_format": "HWC"}, } ] ), name="train_step", step=self.true_global_step, texts=guidance_eval_out["texts"], )