| import math |
| import random |
|
|
| import torch |
| from diffusers import DiffusionPipeline, DDPMScheduler |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
| from diffusers.image_processor import VaeImageProcessor |
| from huggingface_hub import PyTorchModelHubMixin |
| from PIL import Image |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor


class CombinedStableDiffusion(
| DiffusionPipeline, |
| PyTorchModelHubMixin |
| ): |
| """ |
| A Stable Diffusion model wrapper that provides functionality for text-to-image synthesis, |
| noise scheduling, latent space manipulation, and image decoding. |
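
    The reverse diffusion is split between two UNets: the first
    `original_unet_steps` steps use the original UNet with classifier-free
    guidance, and the remaining steps use the fine-tuned UNet without it.

    Example (a sketch; `pipe` is an assembled CombinedStableDiffusion instance):

        >>> image = pipe("a watercolor fox", num_inference_steps=40).images[0]
        >>> image.save("fox.png")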
| """ |
| def __init__( |
| self, |
| original_unet: torch.nn.Module, |
| fine_tuned_unet: torch.nn.Module, |
| scheduler: DDPMScheduler, |
| vae: torch.nn.Module, |
        tokenizer: CLIPTokenizer,
| safety_checker: StableDiffusionSafetyChecker, |
| feature_extractor: CLIPImageProcessor, |
        text_encoder: CLIPTextModel,
    ) -> None:
        super().__init__()

        self.register_modules(
| tokenizer=tokenizer, |
| text_encoder=text_encoder, |
| original_unet=original_unet, |
| fine_tuned_unet=fine_tuned_unet, |
| scheduler=scheduler, |
| vae=vae, |
| safety_checker=safety_checker, |
| feature_extractor=feature_extractor, |
| ) |
|
|
| self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
| self.image_processor = VaeImageProcessor( |
| vae_scale_factor=self.vae_scale_factor |
| ) |
|
|
    def _get_negative_prompts(self, batch_size: int) -> torch.Tensor:
        """Tokenize a batch of empty strings to use as unconditional (negative) prompts."""
        return self.tokenizer(
| [""] * batch_size, |
| max_length=self.tokenizer.model_max_length, |
| padding="max_length", |
| truncation=True, |
| return_tensors="pt", |
| ).input_ids |
|
|
| def _get_encoder_hidden_states( |
| self, tokenized_prompts: torch.Tensor, do_classifier_free_guidance: bool = False |
| ) -> torch.Tensor: |
        if do_classifier_free_guidance:
            # Prepend the unconditional (empty-prompt) tokens so that, after
            # encoding, `chunk(2)` downstream yields (unconditional, text).
            tokenized_prompts = torch.cat(
                [
                    self._get_negative_prompts(tokenized_prompts.shape[0]).to(
                        tokenized_prompts.device
                    ),
                    tokenized_prompts,
                ]
            )

        return self.text_encoder(tokenized_prompts)[0]
|
|
| def _get_unet_prediction( |
| self, |
| latent_model_input: torch.Tensor, |
| timestep: int, |
| encoder_hidden_states: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Return unet noise prediction |
| |
| Args: |
| latent_model_input (torch.Tensor): Unet latents input |
| timestep (int): noise scheduler timestep |
| encoder_hidden_states (torch.Tensor): Text encoder hidden states |
| |
| Returns: |
| torch.Tensor: noise prediction |
| """ |
| unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet |
|
|
| return unet( |
| latent_model_input, |
| timestep=timestep, |
| encoder_hidden_states=encoder_hidden_states, |
| ).sample |
|
|
| def get_noise_prediction( |
| self, |
| latents: torch.Tensor, |
| timestep_index: int, |
| encoder_hidden_states: torch.Tensor, |
| do_classifier_free_guidance: bool = False, |
| detach_main_path: bool = False, |
| ): |
| """ |
| Return noise prediction |
| |
| Args: |
| latents (torch.Tensor): Image latents |
| timestep_index (int): noise scheduler timestep index |
| encoder_hidden_states (torch.Tensor): Text encoder hidden states |
| do_classifier_free_guidance (bool) Whether to do classifier free guidance |
| detach_main_path (bool): Detach gradient |
| |
| Returns: |
| torch.Tensor: noise prediction |
| """ |
        timestep = self.scheduler.timesteps[timestep_index]

        # Duplicate the latents so the unconditional and text-conditioned
        # predictions are computed in one UNet forward pass.
        latent_model_input = self.scheduler.scale_model_input(
            sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
            timestep=timestep,
        )

        noise_pred = self._get_unet_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            if detach_main_path:
                noise_pred_text = noise_pred_text.detach()

            # Classifier-free guidance: eps = eps_uncond + s * (eps_text - eps_uncond)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        return noise_pred
|
|
| def sample_next_latents( |
| self, |
| latents: torch.Tensor, |
| timestep_index: int, |
| noise_pred: torch.Tensor, |
| return_pred_original: bool = False, |
| ) -> torch.Tensor: |
| """ |
| Return next latents prediction |
| |
| Args: |
| latents (torch.Tensor): Image latents |
| timestep_index (int): noise scheduler timestep index |
| noise_pred (torch.Tensor): noise prediction |
| return_pred_original (bool) Whether to sample original sample |
| |
| Returns: |
| torch.Tensor: latent prediction |
| """ |
| timestep = self.scheduler.timesteps[timestep_index] |
| sample = self.scheduler.step( |
| model_output=noise_pred, timestep=timestep, sample=latents |
| ) |
| return ( |
| sample.pred_original_sample if return_pred_original else sample.prev_sample |
| ) |
|
|
| def predict_next_latents( |
| self, |
| latents: torch.Tensor, |
| timestep_index: int, |
| encoder_hidden_states: torch.Tensor, |
| return_pred_original: bool = False, |
| do_classifier_free_guidance: bool = False, |
| detach_main_path: bool = False, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Predicts the next latent states during the diffusion process. |
| |
| Args: |
| latents (torch.Tensor): Current latent states. |
| timestep_index (int): Index of the current timestep. |
| encoder_hidden_states (torch.Tensor): Encoder hidden states from the text encoder. |
| return_pred_original (bool): Whether to return the predicted original sample. |
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance
            detach_main_path (bool): Whether to detach the text-conditioned
                prediction from the autograd graph
| |
| Returns: |
| tuple: Next latents and predicted noise tensor. |
| """ |
|
|
| noise_pred = self.get_noise_prediction( |
| latents=latents, |
| timestep_index=timestep_index, |
| encoder_hidden_states=encoder_hidden_states, |
| do_classifier_free_guidance=do_classifier_free_guidance, |
| detach_main_path=detach_main_path, |
| ) |
|
|
| latents = self.sample_next_latents( |
| latents=latents, |
| noise_pred=noise_pred, |
| timestep_index=timestep_index, |
| return_pred_original=return_pred_original, |
| ) |
|
|
| return latents, noise_pred |
|
|
    def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw initial Gaussian latents for `batch_size` images at `self.resolution`."""
        latent_resolution = int(self.resolution) // self.vae_scale_factor
| return torch.randn( |
| ( |
| batch_size, |
| self.original_unet.config.in_channels, |
| latent_resolution, |
| latent_resolution, |
| ), |
| device=device, |
| ) |
|
|
| def do_k_diffusion_steps( |
| self, |
| start_timestep_index: int, |
| end_timestep_index: int, |
| latents: torch.Tensor, |
| encoder_hidden_states: torch.Tensor, |
| return_pred_original: bool = False, |
| do_classifier_free_guidance: bool = False, |
| detach_main_path: bool = False, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Performs multiple diffusion steps between specified timesteps. |
| |
| Args: |
| start_timestep_index (int): Starting timestep index. |
| end_timestep_index (int): Ending timestep index. |
| latents (torch.Tensor): Initial latents. |
| encoder_hidden_states (torch.Tensor): Encoder hidden states. |
| return_pred_original (bool): Whether to return the predicted original sample. |
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance
            detach_main_path (bool): Whether to detach the text-conditioned
                prediction from the autograd graph
| |
| Returns: |
| tuple: Resulting latents and encoder hidden states. |
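
        Example (a sketch; assumes `scheduler.set_timesteps` has been called
        and `self.guidance_scale` / `self._use_original_unet` are set, as in
        `__call__`):

            latents, states = self.do_k_diffusion_steps(
                start_timestep_index=0,
                end_timestep_index=30,
                latents=latents,
                encoder_hidden_states=states,
                do_classifier_free_guidance=True,
            )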
| """ |
        assert start_timestep_index < end_timestep_index

        # All but the last step always advance to the previous latents; the
        # final step is done separately so it can return the predicted x_0.
        for timestep_index in range(start_timestep_index, end_timestep_index - 1):
| latents, _ = self.predict_next_latents( |
| latents=latents, |
| timestep_index=timestep_index, |
| encoder_hidden_states=encoder_hidden_states, |
| return_pred_original=False, |
| do_classifier_free_guidance=do_classifier_free_guidance, |
| detach_main_path=detach_main_path, |
| ) |
| res, _ = self.predict_next_latents( |
| latents=latents, |
| timestep_index=end_timestep_index - 1, |
| encoder_hidden_states=encoder_hidden_states, |
| return_pred_original=return_pred_original, |
| do_classifier_free_guidance=do_classifier_free_guidance, |
| ) |
| return res, encoder_hidden_states |
|
|
    def get_pil_image(self, raw_images: torch.Tensor) -> list[Image.Image]:
        """Convert raw decoded images to denormalized PIL images."""
        do_denormalize = [True] * raw_images.shape[0]
        images = self.image_processor.postprocess(
            raw_images, output_type="pil", do_denormalize=do_denormalize
        )
        return images
|
|
    def get_reward_image(self, raw_images: torch.Tensor) -> torch.Tensor:
        """Map decoded images from [-1, 1] to [0, 1] and preprocess them for the reward model."""
        reward_images = (raw_images / 2 + 0.5).clamp(0, 1)

        # `use_image_shifting`, `_shift_tensor_batch`, `reward_image_processor`, and
        # `resolution` are expected to be configured by the training code; the
        # assignment assumes `_shift_tensor_batch` returns the shifted batch.
        if self.use_image_shifting:
            reward_images = self._shift_tensor_batch(
                reward_images,
                dx=random.randint(0, math.ceil(self.resolution / 224)),
                dy=random.randint(0, math.ceil(self.resolution / 224)),
            )

        return self.reward_image_processor(reward_images)
|
|
    def run_safety_checker(self, image, device, dtype):
        """Run the NSFW safety checker on decoded images, mirroring the standard SD pipeline."""
        if self.safety_checker is None:
| has_nsfw_concept = None |
| else: |
| if torch.is_tensor(image): |
| feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") |
| else: |
| feature_extractor_input = self.image_processor.numpy_to_pil(image) |
| safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) |
| image, has_nsfw_concept = self.safety_checker( |
| images=image, clip_input=safety_checker_input.pixel_values.to(dtype) |
| ) |
| return image, has_nsfw_concept |
|
|
| @torch.no_grad() |
| def __call__( |
| self, |
| prompt: str | list[str], |
        num_inference_steps: int = 40,
        original_unet_steps: int = 30,
        resolution: int = 512,
        guidance_scale: float = 7.5,
        output_type: str = "pil",
        return_dict: bool = True,
        generator: torch.Generator | None = None,
| ): |
        self.guidance_scale = guidance_scale
        self.resolution = resolution
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
|
| tokenized_prompts = self.tokenizer( |
| prompt, |
| return_tensors="pt", |
| padding="max_length", |
| max_length=self.tokenizer.model_max_length, |
| truncation=True |
| ).input_ids.to(self.device) |
        # Hidden states for stage 1 include the unconditional batch for CFG;
        # stage 2 runs without guidance, so only the text batch is encoded.
        original_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts=tokenized_prompts,
            do_classifier_free_guidance=True,
        )
        fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts=tokenized_prompts,
            do_classifier_free_guidance=False,
        )
|
|
        latent_resolution = int(resolution) // self.vae_scale_factor
        latents = torch.randn(
            (
                batch_size,
                self.original_unet.config.in_channels,
                latent_resolution,
                latent_resolution,
            ),
            generator=generator,
            device=self.device,
        )
|
|
| self.scheduler.set_timesteps( |
| num_inference_steps, |
| device=self.device |
| ) |
|
|
        # Stage 1: denoise the first `original_unet_steps` steps with the
        # original UNet under classifier-free guidance.
        self._use_original_unet = True
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=0,
            end_timestep_index=original_unet_steps,
            latents=latents,
            encoder_hidden_states=original_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=True,
        )
|
|
        # Stage 2: finish the remaining steps with the fine-tuned UNet,
        # without classifier-free guidance.
        self._use_original_unet = False
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=original_unet_steps,
            end_timestep_index=num_inference_steps,
            latents=latents,
            encoder_hidden_states=fine_tuned_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=False,
        )
|
|
        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, self.device, original_encoder_hidden_states.dtype
            )
| else: |
| image = latents |
| has_nsfw_concept = None |
|
|
| if has_nsfw_concept is None: |
| do_denormalize = [True] * image.shape[0] |
| else: |
| do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
| image = self.image_processor.postprocess( |
| image, |
| output_type=output_type, |
| do_denormalize=do_denormalize |
| ) |
|
|
| |
| self.maybe_free_model_hooks() |
|
|
| if not return_dict: |
| return image, has_nsfw_concept |
|
|
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
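

# Minimal usage sketch (the component checkpoints are assumptions, not pinned
# by this module; any Stable Diffusion v1.x-compatible weights should work):
#
#     from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
#     from transformers import CLIPTextModel, CLIPTokenizer
#
#     base = "runwayml/stable-diffusion-v1-5"
#     pipe = CombinedStableDiffusion(
#         original_unet=UNet2DConditionModel.from_pretrained(base, subfolder="unet"),
#         fine_tuned_unet=UNet2DConditionModel.from_pretrained("path/to/fine_tuned_unet"),
#         scheduler=DDPMScheduler.from_pretrained(base, subfolder="scheduler"),
#         vae=AutoencoderKL.from_pretrained(base, subfolder="vae"),
#         tokenizer=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
#         text_encoder=CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
#         safety_checker=None,  # run_safety_checker handles a disabled checker
#         feature_extractor=None,
#     )
#     image = pipe("a photo of an astronaut riding a horse").images[0]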
|
|