from dataclasses import dataclass, field
import torch
import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *
from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import Scene, GaussianModel
from gaussiansplatting.arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussiansplatting.scene.cameras import Camera
from argparse import ArgumentParser, Namespace
import os
from pathlib import Path
from plyfile import PlyData, PlyElement
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud
import numpy as np
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config as diffusion_from_config_shape
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.notebooks import decode_latent_mesh
import io
from PIL import Image
import open3d as o3d


def load_ply(path, save_path):
    """Convert a PLY holding ``f_dc_*`` coefficients into a coloured point cloud.

    Reads ``x``/``y``/``z`` and ``f_dc_0``/``f_dc_1``/``f_dc_2`` from the first
    element of the PLY at ``path``, maps the coefficients to colours with
    ``sh * C0 + 0.5`` and writes the result to ``save_path`` through Open3D.

    Args:
        path: Input PLY file.
        save_path: Destination file for the coloured point cloud.
    """
    C0 = 0.28209479177387814

    def SH2RGB(sh):
        # Local definition intentionally shadows the module-level import so
        # this function does not depend on the imported implementation.
        return sh * C0 + 0.5

    vertex = PlyData.read(path).elements[0]
    xyz = np.stack(
        (np.asarray(vertex["x"]), np.asarray(vertex["y"]), np.asarray(vertex["z"])),
        axis=1,
    )

    features_dc = np.zeros((xyz.shape[0], 3, 1))
    features_dc[:, 0, 0] = np.asarray(vertex["f_dc_0"])
    features_dc[:, 1, 0] = np.asarray(vertex["f_dc_1"])
    features_dc[:, 2, 0] = np.asarray(vertex["f_dc_2"])
    color = SH2RGB(features_dc[:, :, 0])

    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(xyz)
    point_cloud.colors = o3d.utility.Vector3dVector(color)
    o3d.io.write_point_cloud(save_path, point_cloud)


def storePly(path, xyz, rgb):
    """Write positions and colours to ``path`` as a PLY ``vertex`` element.

    Normals are written as zeros. ``xyz`` and ``rgb`` are concatenated along
    axis 1, so they must be 2-D arrays with the same number of rows.

    Args:
        path: Output PLY file.
        xyz: Point positions (three columns).
        rgb: Point colours (three columns, stored as ``u1``).
    """
    # Define the dtype for the structured array
    dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)


def fetchPly(path):
    """Load a PLY ``vertex`` element into a ``BasicPointCloud``.

    Colours are read from ``red``/``green``/``blue`` and divided by 255;
    normals are read from ``nx``/``ny``/``nz``.

    Args:
        path: Input PLY file.

    Returns:
        A ``BasicPointCloud`` with ``points``, ``colors`` and ``normals``.
    """
    plydata = PlyData.read(path)
    vertices = plydata['vertex']
    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
    return BasicPointCloud(points=positions, colors=colors, normals=normals)


@threestudio.register("gaussiandreamer-system")
class GaussianDreamer(BaseLift3DSystem):
    """threestudio system that optimises a ``GaussianModel`` under guidance.

    The Gaussians are initialised from a point cloud decoded from a Shap-E
    latent (see ``shape``), densified with extra samples (see ``add_points``)
    and then trained with the loss assembled in ``training_step``.
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Multiplied by the scale returned from ``shape`` to size the
        # initial point cloud (see ``pcb``).
        radius: float = 4
        # Passed to ``GaussianModel(sh_degree=...)``.
        sh_degree: int = 0

    cfg: Config

    def configure(self) -> None:
        """Create the Gaussian model and the default (black) background."""
        self.radius = self.cfg.radius
        self.sh_degree = self.cfg.sh_degree

        self.gaussian = GaussianModel(sh_degree=self.sh_degree)
        bg_color = [0, 0, 0]
        self.background_tensor = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    def save_gif_to_file(self, images, output_file):
        """Encode ``images`` as a looping GIF and write it to ``output_file``.

        Args:
            images: Non-empty sequence of objects exposing a PIL-style
                ``save``; the first is the base frame, the rest are appended.
            output_file: Destination path, opened in binary write mode.
        """
        with io.BytesIO() as writer:
            images[0].save(
                writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0
            )
            encoded = writer.getvalue()
        with open(output_file, 'wb') as file:
            file.write(encoded)

    def shape(self):
        """Sample a Shap-E latent for the prompt and decode it to a point set.

        Side effects: stores preview renders in ``self.shapeimages``, the
        subsampled vertex count in ``self.num_pts`` and an Open3D point cloud
        in ``self.point_cloud``.

        Returns:
            ``(coords, rgb, 0.4)``: every fourth mesh vertex, the matching
            per-vertex colours, and a fixed scale factor.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        # NOTE(review): torch.load unpickles the checkpoint; only use a
        # trusted file at this path.
        model.load_state_dict(
            torch.load('./load/shapE_finetuned_with_330kdata.pth', map_location=device)['model_state_dict']
        )
        diffusion = diffusion_from_config_shape(load_config('diffusion'))

        batch_size = 1
        guidance_scale = 15.0
        prompt = str(self.cfg.prompt_processor.prompt)
        print('prompt', prompt)

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        render_mode = 'nerf'  # you can change this to 'stf'
        size = 256  # this is the size of the renders; higher values take longer to render.

        cameras = create_pan_cameras(size, device)
        self.shapeimages = decode_latent_images(xm, latents[0], cameras, rendering_mode=render_mode)

        pc = decode_latent_mesh(xm, latents[0]).tri_mesh()

        skip = 4  # keep every fourth vertex
        coords = pc.verts
        rgb = np.concatenate(
            [
                pc.vertex_channels['R'][:, None],
                pc.vertex_channels['G'][:, None],
                pc.vertex_channels['B'][:, None],
            ],
            axis=1,
        )

        coords = coords[::skip]
        rgb = rgb[::skip]

        self.num_pts = coords.shape[0]
        point_cloud = o3d.geometry.PointCloud()
        point_cloud.points = o3d.utility.Vector3dVector(coords)
        point_cloud.colors = o3d.utility.Vector3dVector(rgb)
        self.point_cloud = point_cloud

        return coords, rgb, 0.4

    def add_points(self, coords, rgb):
        """Densify a point set with random samples lying close to it.

        Draws 1,000,000 uniform samples inside the axis-aligned bounding box
        of ``coords`` and keeps those within 0.01 of their nearest input
        point. Each kept sample takes the colour of that nearest point plus
        uniform noise in ``[0, 0.2)``.

        Note: this reseeds NumPy's global RNG with 0, so results are
        repeatable but the global random state is reset.

        Args:
            coords: Input positions, one row per point.
            rgb: Input colours, indexed in step with ``coords``.

        Returns:
            ``(all_coords, all_rgb)`` with the kept samples first, followed
            by the original points.
        """
        pcd_by3d = o3d.geometry.PointCloud()
        pcd_by3d.points = o3d.utility.Vector3dVector(np.array(coords))

        bbox = pcd_by3d.get_axis_aligned_bounding_box()
        np.random.seed(0)

        num_points = 1000000
        points = np.random.uniform(
            low=np.asarray(bbox.min_bound),
            high=np.asarray(bbox.max_bound),
            size=(num_points, 3),
        )

        kdtree = o3d.geometry.KDTreeFlann(pcd_by3d)
        # Loop-invariant: fetch the reference positions once, not per sample.
        reference_points = np.asarray(pcd_by3d.points)

        points_inside = []
        color_inside = []
        for point in points:
            _, idx, _ = kdtree.search_knn_vector_3d(point, 1)
            nearest_point = reference_points[idx[0]]
            if np.linalg.norm(point - nearest_point) < 0.01:  # this threshold may need tuning
                points_inside.append(point)
                color_inside.append(rgb[idx[0]] + 0.2 * np.random.random(3))

        # reshape keeps the arrays 2-D even when no sample was accepted, so
        # the concatenations below cannot fail on a dimension mismatch.
        all_coords = np.array(points_inside).reshape(-1, 3)
        all_rgb = np.array(color_inside).reshape(-1, 3)
        all_coords = np.concatenate([all_coords, coords], axis=0)
        all_rgb = np.concatenate([all_rgb, rgb], axis=0)
        return all_coords, all_rgb

    def pcb(self):
        """Build the initial ``BasicPointCloud`` for the Gaussian model.

        Returns:
            Point cloud whose positions are the densified points scaled by
            ``self.radius * scale``, with zero normals.
        """
        # Since this data set has no colmap data, we start with random points
        coords, rgb, scale = self.shape()
        bound = self.radius * scale

        all_coords, all_rgb = self.add_points(coords, rgb)

        # Normals must have one row per point *after* densification;
        # self.num_pts only counts the points returned by shape().
        pcd = BasicPointCloud(
            points=all_coords * bound,
            colors=all_rgb,
            normals=np.zeros_like(all_coords),
        )
        return pcd

    def forward(self, batch: Dict[str, Any], renderbackground=None) -> Dict[str, Any]:
        """Render every camera in ``batch`` and stack the results.

        Side effects: refreshes ``self.viewspace_point_list``, ``self.radii``
        (element-wise maximum over the views) and ``self.visibility_filter``.

        Args:
            batch: Must provide ``c2w_3dgs``, ``fovy``, ``height`` and
                ``width``.
            renderbackground: Background passed to the renderer; defaults to
                ``self.background_tensor``.

        Returns:
            The render package of the last view, extended with ``comp_rgb``
            and ``depth`` (stacked over views) and ``opacity`` (depth divided
            by its maximum).
        """
        if renderbackground is None:
            renderbackground = self.background_tensor

        images = []
        depths = []
        self.viewspace_point_list = []

        for view_idx in range(batch['c2w_3dgs'].shape[0]):
            viewpoint_cam = Camera(
                c2w=batch['c2w_3dgs'][view_idx],
                FoVy=batch['fovy'][view_idx],
                height=batch['height'],
                width=batch['width'],
            )

            render_pkg = render(viewpoint_cam, self.gaussian, self.pipe, renderbackground)
            image = render_pkg["render"]
            viewspace_point_tensor = render_pkg["viewspace_points"]
            radii = render_pkg["radii"]
            self.viewspace_point_list.append(viewspace_point_tensor)

            if view_idx == 0:
                self.radii = radii
            else:
                self.radii = torch.max(radii, self.radii)

            # Move the leading axis to the end, matching the "HWC" data
            # format used when the images are saved.
            depth = render_pkg["depth_3dgs"].permute(1, 2, 0)
            image = image.permute(1, 2, 0)
            images.append(image)
            depths.append(depth)

        images = torch.stack(images, 0)
        depths = torch.stack(depths, 0)

        self.visibility_filter = self.radii > 0.0
        render_pkg["comp_rgb"] = images
        render_pkg["depth"] = depths
        render_pkg["opacity"] = depths / (depths.max() + 1e-5)

        return {
            **render_pkg,
        }

    def on_fit_start(self) -> None:
        """Instantiate the prompt processor and guidance before training."""
        super().on_fit_start()
        # only used in training
        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return ``{"loss": loss}``.

        The loss is the weighted sum of the guidance ``loss_sds`` term, a
        sparsity term and an opacity binary-cross-entropy term, each weighted
        by the matching entry of ``self.cfg.loss``.
        """
        self.gaussian.update_learning_rate(self.true_global_step)

        if self.true_global_step > 500:
            self.guidance.set_min_max_steps(min_step_percent=0.02, max_step_percent=0.55)

        out = self(batch)

        prompt_utils = self.prompt_processor()
        images = out["comp_rgb"]

        guidance_eval = (self.true_global_step % 200 == 0)
        guidance_out = self.guidance(
            images, prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        loss = 0.0
        loss = loss + guidance_out['loss_sds'] * self.C(self.cfg.loss['lambda_sds'])

        loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        # The clamped opacity is deliberately used as both input and target.
        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        if guidance_eval:
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    def on_before_optimizer_step(self, optimizer):
        """Accumulate densification statistics and periodically densify/prune.

        Active only while ``true_global_step < 900``. Densification and
        pruning run every 100 steps once the step exceeds 300.
        """
        with torch.no_grad():
            if self.true_global_step < 900:  # (original note: 15000)
                # Sum the screen-space gradients over all rendered views.
                viewspace_point_tensor_grad = torch.zeros_like(self.viewspace_point_list[0])
                for viewspace_points in self.viewspace_point_list:
                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_points.grad

                # Keep track of max radii in image-space for pruning
                visible = self.visibility_filter
                self.gaussian.max_radii2D[visible] = torch.max(
                    self.gaussian.max_radii2D[visible], self.radii[visible]
                )

                self.gaussian.add_densification_stats(viewspace_point_tensor_grad, visible)

                if self.true_global_step > 300 and self.true_global_step % 100 == 0:  # (original note: 500 100)
                    size_threshold = 20 if self.true_global_step > 500 else None  # (original note: 3000)
                    self.gaussian.densify_and_prune(0.0002, 0.05, self.cameras_extent, size_threshold)

    def _rgb_grid_items(self, batch, out):
        """Return the image-grid entries shared by validation and test.

        Order: the reference ``rgb`` from ``batch`` (if present), the rendered
        ``comp_rgb``, then ``comp_normal`` (if present in ``out``).
        """
        items = []
        if "rgb" in batch:
            items.append(
                {
                    "type": "rgb",
                    "img": batch["rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                }
            )
        items.append(
            {
                "type": "rgb",
                "img": out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            }
        )
        if "comp_normal" in out:
            items.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        return items

    def validation_step(self, batch, batch_idx):
        """Render the validation batch and save an image grid for it."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            self._rgb_grid_items(batch, out),
            name="validation_step",
            step=self.true_global_step,
        )
        # Optional per-validation export, currently disabled:
        # save_path = self.get_save_path(f"it{self.true_global_step}-val.ply")
        # self.gaussian.save_ply(save_path)
        # load_ply(save_path,self.get_save_path(f"it{self.true_global_step}-val-color.ply"))

    def on_validation_epoch_end(self):
        """No-op hook."""
        pass

    def test_step(self, batch, batch_idx):
        """Render the test batch on a black background and save an image grid."""
        # Set to False to also save depth and opacity panels.
        only_rgb = True
        testbackground_tensor = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda")

        out = self(batch, testbackground_tensor)

        items = self._rgb_grid_items(batch, out)
        if not only_rgb:
            if "depth" in out:
                items.append(
                    {
                        "type": "grayscale",
                        "img": out["depth"][0],
                        "kwargs": {},
                    }
                )
            items.append(
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            items,
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Assemble the test video and export the Gaussian and Shap-E assets."""
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        save_path = self.get_save_path("last.ply")
        self.gaussian.save_ply(save_path)
        # self.pointefig.savefig(self.get_save_path("pointe.png"))
        o3d.io.write_point_cloud(self.get_save_path("shape.ply"), self.point_cloud)
        self.save_gif_to_file(self.shapeimages, self.get_save_path("shape.gif"))
        load_ply(save_path, self.get_save_path(f"it{self.true_global_step}-test-color.ply"))

    def configure_optimizers(self):
        """Initialise the Gaussians from the point cloud and return their optimizer.

        Returns:
            ``{"optimizer": self.gaussian.optimizer}``.
        """
        self.parser = ArgumentParser(description="Training script parameters")

        opt = OptimizationParams(self.parser)
        point_cloud = self.pcb()
        self.cameras_extent = 4.0
        self.gaussian.create_from_pcd(point_cloud, self.cameras_extent)

        self.pipe = PipelineParams(self.parser)
        self.gaussian.training_setup(opt)

        ret = {
            "optimizer": self.gaussian.optimizer,
        }
        return ret