import json
import os
from dataclasses import dataclass, field

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from transformers import AutoTokenizer, BertForMaskedLM

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import barrier, cleanup, get_rank
from threestudio.utils.ops import shifted_cosine_decay, shifted_expotional_decay
from threestudio.utils.typing import *


def hash_prompt(model: str, prompt: str) -> str:
    import hashlib

    identifier = f"{model}-{prompt}"
    return hashlib.md5(identifier.encode()).hexdigest()


@dataclass
class DirectionConfig:
    name: str
    prompt: Callable[[str], str]
    negative_prompt: Callable[[str], str]
    condition: Callable[
        [Float[Tensor, "B"], Float[Tensor, "B"], Float[Tensor, "B"]],
        Float[Tensor, "B"],
    ]


@dataclass
class PromptProcessorOutput:
    text_embeddings: Float[Tensor, "N Nf"]
    uncond_text_embeddings: Float[Tensor, "N Nf"]
    text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    uncond_text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    directions: List[DirectionConfig]
    direction2idx: Dict[str, int]
    use_perp_neg: bool
    perp_neg_f_sb: Tuple[float, float, float]
    perp_neg_f_fsb: Tuple[float, float, float]
    perp_neg_f_fs: Tuple[float, float, float]
    perp_neg_f_sf: Tuple[float, float, float]

    def get_text_embeddings(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Float[Tensor, "BB N Nf"]:
        batch_size = elevation.shape[0]

        if view_dependent_prompting:
            # get direction index for each sample in the batch
            direction_idx = torch.zeros_like(elevation, dtype=torch.long)
            for d in self.directions:
                direction_idx[
                    d.condition(elevation, azimuth, camera_distances)
                ] = self.direction2idx[d.name]

            # select view-dependent text embeddings by direction index
            text_embeddings = self.text_embeddings_vd[direction_idx]  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx]  # type: ignore
        else:
            text_embeddings = self.text_embeddings.expand(batch_size, -1, -1)  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings.expand(  # type: ignore
                batch_size, -1, -1
            )

        # IMPORTANT: we return (cond, uncond), which is in a different order
        # than other implementations!
        return torch.cat([text_embeddings, uncond_text_embeddings], dim=0)
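    # A minimal sketch of how a guidance module is expected to consume the
    # (cond, uncond) ordering returned above (names such as `prompt_utils`,
    # `unet` and `guidance_scale` are illustrative, not part of this module):
    #
    #   text_embeddings = prompt_utils.get_text_embeddings(
    #       elevation, azimuth, camera_distances
    #   )
    #   noise_pred = unet(latents_doubled, t, encoder_hidden_states=text_embeddings)
    #   noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)  # cond comes FIRST
    #   noise_pred = noise_pred_uncond + guidance_scale * (
    #       noise_pred_text - noise_pred_uncond
    #   )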
    def get_text_embeddings_perp_neg(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Tuple[Float[Tensor, "BBBB N Nf"], Float[Tensor, "B 2"]]:
        assert (
            view_dependent_prompting
        ), "Perp-Neg only works with view-dependent prompting"

        batch_size = elevation.shape[0]

        direction_idx = torch.zeros_like(elevation, dtype=torch.long)
        for d in self.directions:
            direction_idx[
                d.condition(elevation, azimuth, camera_distances)
            ] = self.direction2idx[d.name]

        # 0 - side view
        # 1 - front view
        # 2 - back view
        # 3 - overhead view

        pos_text_embeddings = []
        neg_text_embeddings = []
        neg_guidance_weights = []
        uncond_text_embeddings = []

        side_emb = self.text_embeddings_vd[0]
        front_emb = self.text_embeddings_vd[1]
        back_emb = self.text_embeddings_vd[2]
        overhead_emb = self.text_embeddings_vd[3]

        for idx, ele, azi, dis in zip(
            direction_idx, elevation, azimuth, camera_distances
        ):
            azi = shift_azimuth_deg(azi)  # to [-180, 180)
            uncond_text_embeddings.append(
                self.uncond_text_embeddings_vd[idx]
            )  # should be ""
            if idx.item() == 3:  # overhead view
                pos_text_embeddings.append(overhead_emb)
                # dummy negatives: overhead views get zero negative guidance
                neg_text_embeddings += [
                    self.uncond_text_embeddings_vd[idx],
                    self.uncond_text_embeddings_vd[idx],
                ]
                neg_guidance_weights += [0.0, 0.0]
            else:  # interpolating views
                if torch.abs(azi) < 90:
                    # front-side interpolation
                    # 0 - complete side, 1 - complete front
                    r_inter = 1 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * front_emb + (1 - r_inter) * side_emb
                    )
                    neg_text_embeddings += [front_emb, side_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_fs, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_sf, 1 - r_inter),
                    ]
                else:
                    # side-back interpolation
                    # 0 - complete back, 1 - complete side
                    r_inter = 2.0 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * side_emb + (1 - r_inter) * back_emb
                    )
                    neg_text_embeddings += [side_emb, front_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_sb, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_fsb, r_inter),
                    ]

        text_embeddings = torch.cat(
            [
                torch.stack(pos_text_embeddings, dim=0),
                torch.stack(uncond_text_embeddings, dim=0),
                torch.stack(neg_text_embeddings, dim=0),
            ],
            dim=0,
        )

        return text_embeddings, torch.as_tensor(
            neg_guidance_weights, device=elevation.device
        ).reshape(batch_size, 2)


def shift_azimuth_deg(azimuth: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
    # shift azimuth angle (in degrees) to [-180, 180)
    return (azimuth + 180) % 360 - 180
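# Worked example for the Perp-Neg weighting above: the weights come from
# shifted_expotional_decay(a, b, c, r) = a * exp(-b * r) + c (see
# threestudio.utils.ops), with coefficients chosen so that the weight vanishes
# at a "pure" view. For the default f_sb = (1, 0.5, -0.606) defined below:
#   r = 1 (pure side view): 1 * exp(-0.5) - 0.606 ≈ 0.0005 ≈ 0
#   r = 0 (pure back view): 1 * exp(0)    - 0.606 = 0.394
# so the negative side-prompt guidance vanishes at a pure side view and is
# strongest at a pure back view (illustrative arithmetic, not executed code).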
class PromptProcessor(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        prompt: str = "a hamburger"

        # manually assigned view-dependent prompts
        prompt_front: Optional[str] = None
        prompt_side: Optional[str] = None
        prompt_back: Optional[str] = None
        prompt_overhead: Optional[str] = None
        negative_prompt: str = ""
        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        overhead_threshold: float = 60.0
        front_threshold: float = 45.0
        back_threshold: float = 45.0
        view_dependent_prompt_front: bool = False
        use_cache: bool = True
        spawn: bool = True

        # perp neg
        use_perp_neg: bool = False
        # a * e(-b * r) + c
        # a * e(-b) + c = 0
        perp_neg_f_sb: Tuple[float, float, float] = (1, 0.5, -0.606)
        perp_neg_f_fsb: Tuple[float, float, float] = (1, 0.5, +0.967)
        perp_neg_f_fs: Tuple[float, float, float] = (
            4,
            0.5,
            -2.426,
        )  # f_fs(1) = 0, a, b > 0
        perp_neg_f_sf: Tuple[float, float, float] = (4, 0.5, -2.426)

        # prompt debiasing
        use_prompt_debiasing: bool = False
        pretrained_model_name_or_path_prompt_debiasing: str = "bert-base-uncased"
        # indices of words that can potentially be removed
        prompt_debiasing_mask_ids: Optional[List[int]] = None

    cfg: Config

    @rank_zero_only
    def configure_text_encoder(self) -> None:
        raise NotImplementedError

    @rank_zero_only
    def destroy_text_encoder(self) -> None:
        raise NotImplementedError

    def configure(self) -> None:
        self._cache_dir = ".threestudio_cache/text_embeddings"  # FIXME: hard-coded path

        # view-dependent text embeddings
        self.directions: List[DirectionConfig]
        if self.cfg.view_dependent_prompt_front:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"side view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"front view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"backside view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"overhead view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]
        else:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"{s}, side view",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"{s}, front view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"{s}, back view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"{s}, overhead view",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]

        self.direction2idx = {d.name: i for i, d in enumerate(self.directions)}

        with open(os.path.join("load/prompt_library.json"), "r") as f:
            self.prompt_library = json.load(f)
        # use the provided prompt, or find a prompt in the library
        self.prompt = self.preprocess_prompt(self.cfg.prompt)
        # use the provided negative prompt
        self.negative_prompt = self.cfg.negative_prompt

        threestudio.info(
            f"Using prompt [{self.prompt}] and negative prompt [{self.negative_prompt}]"
        )

        # view-dependent prompting
        if self.cfg.use_prompt_debiasing:
            assert (
                self.cfg.prompt_side is None
                and self.cfg.prompt_back is None
                and self.cfg.prompt_overhead is None
            ), "Do not manually assign prompt_side, prompt_back or prompt_overhead when using prompt debiasing"
            prompts = self.get_debiased_prompt(self.prompt)
            self.prompts_vd = [
                d.prompt(prompt) for d, prompt in zip(self.directions, prompts)
            ]
        else:
            self.prompts_vd = [
                self.cfg.get(f"prompt_{d.name}", None) or d.prompt(self.prompt)  # type: ignore
                for d in self.directions
            ]

        prompts_vd_display = " ".join(
            [
                f"[{d.name}]:[{prompt}]"
                for prompt, d in zip(self.prompts_vd, self.directions)
            ]
        )
        threestudio.info(f"Using view-dependent prompts {prompts_vd_display}")

        self.negative_prompts_vd = [
            d.negative_prompt(self.negative_prompt) for d in self.directions
        ]

        self.prepare_text_embeddings()
        self.load_text_embeddings()

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir):
        raise NotImplementedError
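    # What a concrete spawn_func is expected to do (a sketch only, assuming a
    # diffusers-style checkpoint with "tokenizer"/"text_encoder" subfolders;
    # real implementations live in the concrete prompt processor subclasses):
    #
    #   @staticmethod
    #   def spawn_func(pretrained_model_name_or_path, prompts, cache_dir):
    #       tokenizer = AutoTokenizer.from_pretrained(
    #           pretrained_model_name_or_path, subfolder="tokenizer"
    #       )
    #       text_encoder = CLIPTextModel.from_pretrained(
    #           pretrained_model_name_or_path, subfolder="text_encoder"
    #       )
    #       with torch.no_grad():
    #           tokens = tokenizer(
    #               prompts,
    #               padding="max_length",
    #               max_length=tokenizer.model_max_length,
    #               return_tensors="pt",
    #           )
    #           text_embeddings = text_encoder(tokens.input_ids)[0]
    #       for prompt, embedding in zip(prompts, text_embeddings):
    #           torch.save(
    #               embedding,
    #               os.path.join(
    #                   cache_dir,
    #                   f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt",
    #               ),
    #           )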
    @rank_zero_only
    def prepare_text_embeddings(self):
        os.makedirs(self._cache_dir, exist_ok=True)

        all_prompts = (
            [self.prompt]
            + [self.negative_prompt]
            + self.prompts_vd
            + self.negative_prompts_vd
        )
        prompts_to_process = []
        for prompt in all_prompts:
            if self.cfg.use_cache:
                # some text embeddings are already in cache; do not process them
                cache_path = os.path.join(
                    self._cache_dir,
                    f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
                )
                if os.path.exists(cache_path):
                    threestudio.debug(
                        f"Text embeddings for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] are already in cache, skip processing."
                    )
                    continue
            prompts_to_process.append(prompt)

        if len(prompts_to_process) > 0:
            if self.cfg.spawn:
                ctx = mp.get_context("spawn")
                subprocess = ctx.Process(
                    target=self.spawn_func,
                    args=(
                        self.cfg.pretrained_model_name_or_path,
                        prompts_to_process,
                        self._cache_dir,
                    ),
                )
                subprocess.start()
                subprocess.join()
            else:
                self.spawn_func(
                    self.cfg.pretrained_model_name_or_path,
                    prompts_to_process,
                    self._cache_dir,
                )
            cleanup()

    def load_text_embeddings(self):
        # synchronize, to ensure the text embeddings have been computed and saved to cache
        barrier()
        self.text_embeddings = self.load_from_cache(self.prompt)[None, ...]
        self.uncond_text_embeddings = self.load_from_cache(self.negative_prompt)[
            None, ...
        ]
        self.text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.prompts_vd], dim=0
        )
        self.uncond_text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd],
            dim=0,
        )
        threestudio.debug("Loaded text embeddings.")
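    # Cache layout (for reference): each prompt maps to a single tensor file
    # named by hash_prompt, i.e.
    #   .threestudio_cache/text_embeddings/<md5("<model>-<prompt>")>.pt
    # load_from_cache below resolves the same path that prepare_text_embeddings
    # (via spawn_func) writes to.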
    def load_from_cache(self, prompt):
        cache_path = os.path.join(
            self._cache_dir,
            f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
        )
        if not os.path.exists(cache_path):
            raise FileNotFoundError(
                f"Text embedding file {cache_path} for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] not found."
            )
        return torch.load(cache_path, map_location=self.device)

    def preprocess_prompt(self, prompt: str) -> str:
        if prompt.startswith("lib:"):
            # find matches in the library
            candidate = None
            keywords = prompt[4:].lower().split("_")
            for library_prompt in self.prompt_library["dreamfusion"]:
                if all([k in library_prompt.lower() for k in keywords]):
                    if candidate is not None:
                        raise ValueError(
                            f"Multiple prompts matched with keywords {keywords} in library"
                        )
                    candidate = library_prompt
            if candidate is None:
                raise ValueError(
                    f"Cannot find prompt with keywords {keywords} in library"
                )
            threestudio.info("Found matched prompt in library: " + candidate)
            return candidate
        else:
            return prompt

    def get_text_embeddings(
        self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
    ) -> Tuple[Float[Tensor, "B ..."], Float[Tensor, "B ..."]]:
        raise NotImplementedError

    def get_debiased_prompt(self, prompt: str) -> List[str]:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )
        model = BertForMaskedLM.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )

        views = [d.name for d in self.directions]
        view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0]
        view_ids = view_ids[1:5]  # strip special tokens; keep the 4 view tokens

        def modulate(prompt):
            prompt_vd = f"This image is depicting a [MASK] view of {prompt}"
            tokens = tokenizer(
                prompt_vd,
                padding="max_length",
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1]

            logits = model(**tokens).logits
            logits = F.softmax(logits[0, mask_idx], dim=-1)
            logits = logits[0, view_ids]
            probes = logits / logits.sum()
            return probes

        prompts = [prompt.split(" ") for _ in range(4)]
        full_probe = modulate(prompt)
        n_words = len(prompt.split(" "))
        prompt_debiasing_mask_ids = (
            self.cfg.prompt_debiasing_mask_ids
            if self.cfg.prompt_debiasing_mask_ids is not None
            else list(range(n_words))
        )
        words_to_debias = [prompt.split(" ")[idx] for idx in prompt_debiasing_mask_ids]
        threestudio.info(f"Words that can potentially be removed: {words_to_debias}")
        for idx in prompt_debiasing_mask_ids:
            words = prompt.split(" ")
            prompt_ = " ".join(words[:idx] + words[(idx + 1) :])
            part_probe = modulate(prompt_)

            pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5)
            for i in range(pmi.shape[0]):
                if pmi[i].item() < 0.95:
                    prompts[i][idx] = ""

        debiased_prompts = [" ".join([word for word in p if word]) for p in prompts]
        for d, debiased_prompt in zip(views, debiased_prompts):
            threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]")

        del tokenizer, model
        cleanup()

        return debiased_prompts

    def __call__(self) -> PromptProcessorOutput:
        return PromptProcessorOutput(
            text_embeddings=self.text_embeddings,
            uncond_text_embeddings=self.uncond_text_embeddings,
            text_embeddings_vd=self.text_embeddings_vd,
            uncond_text_embeddings_vd=self.uncond_text_embeddings_vd,
            directions=self.directions,
            direction2idx=self.direction2idx,
            use_perp_neg=self.cfg.use_perp_neg,
            perp_neg_f_sb=self.cfg.perp_neg_f_sb,
            perp_neg_f_fsb=self.cfg.perp_neg_f_fsb,
            perp_neg_f_fs=self.cfg.perp_neg_f_fs,
            perp_neg_f_sf=self.cfg.perp_neg_f_sf,
        )
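
# Minimal sanity check for the pure helpers in this module (a sketch; it runs
# only when the file is executed directly, never on import). PromptProcessor
# itself is abstract and requires a concrete subclass implementing spawn_func
# and the text-encoder hooks.
if __name__ == "__main__":
    azi = torch.tensor([0.0, 90.0, 181.0, 270.0, 359.0])
    # expected: tensor([   0.,   90., -179.,  -90.,   -1.])
    print(shift_azimuth_deg(azi))
    # cache file stem for the default model/prompt pair
    print(hash_prompt("runwayml/stable-diffusion-v1-5", "a hamburger"))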