| """ | |
| Image inference module for VAREdit model. | |
| Supports 2B and 8B model variants for image editing with text instructions. | |
| """ | |
| import argparse | |
| import logging | |
| from typing import Tuple, Any, Optional | |
| from torchvision.transforms.functional import to_tensor | |
| import numpy as np | |
| from PIL import Image | |
| import PIL.Image as PImage | |
| from tools.run_infinity import ( | |
| load_tokenizer, load_visual_tokenizer, load_transformer, | |
| gen_one_img, h_div_w_templates, dynamic_resolution_h_w | |
| ) | |
| import time | |
| import torch | |
def transform(pil_img, target_image_size):
    # The model currently only supports square inputs, so pad non-square
    # images onto a white square and return a crop_op that undoes the
    # padding on the generated output.
    width, height = pil_img.size
    max_dim = max(width, height)
    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    padded_image.paste(pil_img, (0, 0))

    def crop_op(image):
        # Resize the square output back to the padded resolution, then crop
        # away the padding to restore the original aspect ratio.
        image = image.resize((max_dim, max_dim), resample=PImage.LANCZOS)
        return image.crop((0, 0, width, height))

    padded_image = padded_image.resize((target_image_size, target_image_size), resample=PImage.LANCZOS)
    im = to_tensor(np.array(padded_image))
    # Normalize pixel values from [0, 1] to [-1, 1]: im + im - 1 == 2 * im - 1.
    return im.add(im).add_(-1), crop_op
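# Round-trip sketch (sizes are illustrative, not from the model): a 640x480
# input is padded to 640x640, resized to target_image_size for the model, and
# crop_op later restores the original 640x480 framing on the generated output.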
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Model configurations
MODEL_CONFIGS = {
    '2B': {
        'vae_filename': 'infinity_vae_d32reg.pth',
        'vae_type': 32,
        'model_type': 'infinity_2b',
        'apply_spatial_patchify': 0,
    },
    '8B': {
        'vae_filename': 'infinity_vae_d56_f8_14_patchify.pth',
        'vae_type': 14,
        'model_type': 'infinity_8b',
        'apply_spatial_patchify': 1,
    },
}
# Common model arguments
COMMON_ARGS = {
    'cfg_insertion_layer': 0,
    'add_lvl_embeding_only_first_block': 1,
    'use_bit_label': 1,
    'rope2d_each_sa_layer': 1,
    'rope2d_normalized_by_hw': 2,
    'use_scale_schedule_embedding': 0,
    'sampling_per_bits': 1,
    'text_channels': 2048,
    'h_div_w_template': 1.000,
    'use_flex_attn': 0,
    'cache_dir': '/dev/shm',
    'checkpoint_type': 'torch',
    'bf16': 1,
    'enable_model_cache': 0,
}
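# Note: 'cache_dir' points at /dev/shm, the Linux shared-memory tmpfs; redirect
# it on systems that lack one.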
def load_model(pretrain_root: str, model_path: str, model_size: str, image_size: int) -> Tuple[Any, ...]:
    """
    Load the model and its components.

    Args:
        pretrain_root: Root directory for pretrained models
        model_path: Path to the specific model checkpoint
        model_size: Model variant, either '2B' or '8B'
        image_size: Output resolution, either 512 or 1024

    Returns:
        Tuple of (args, model, vae, text_tokenizer, text_encoder, image_size)

    Raises:
        ValueError: If an unsupported model size or image size is specified
    """
    if model_size not in MODEL_CONFIGS:
        raise ValueError(f"Unsupported model size: {model_size}. Choose '2B' or '8B'.")
    config = MODEL_CONFIGS[model_size]

    # Build arguments
    args_dict = {
        **COMMON_ARGS,
        **config,
        'model_path': model_path,
        'vae_path': f"{pretrain_root}/{config['vae_filename']}",
        'text_encoder_ckpt': f"{pretrain_root}/flan-t5-xl",
    }
    args = argparse.Namespace(**args_dict)

    # 'pn' selects the scale schedule by pixel budget: 0.25M ~ 512^2, 1M ~ 1024^2.
    if image_size == 512:
        args.pn = "0.25M"
    elif image_size == 1024:
        args.pn = "1M"
    else:
        raise ValueError(f"Unsupported image size: {image_size}. Choose 512 or 1024.")

    logger.info(f"Loading {model_size} model from {model_path}")

    # Load components
    text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
    vae = load_visual_tokenizer(args)
    model = load_transformer(vae, args)

    logger.info("Model loaded successfully")
    return args, model, vae, text_tokenizer, text_encoder, image_size
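# Loading sketch for the smaller variant (the checkpoint filename here is an
# assumption; substitute whatever 2B checkpoint you actually have):
#   components = load_model("HiDream-ai/VAREdit", "HiDream-ai/VAREdit/2B-512.pth", "2B", 512)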
def generate_image(
    model_components: Tuple[Any, ...],
    src_img_path: str,
    instruction: str,
    cfg: float = 4.0,
    tau: float = 0.5,
    seed: int = -1,
) -> Image.Image:
    """
    Generate an edited image from a source image and a text instruction.

    Args:
        model_components: Tuple of (args, model, vae, tokenizer, text_encoder, image_size)
        src_img_path: Path to source image
        instruction: Text instruction for editing
        cfg: Classifier-free guidance scale
        tau: Temperature parameter
        seed: Random seed; -1 draws a random seed

    Returns:
        The edited image, cropped back to the source aspect ratio
    """
    args, model, vae, tokenizer, text_encoder, image_size = model_components

    assert image_size in [512, 1024], f"Invalid image size: {image_size}, expected 512 or 1024"
    pn = "0.25M" if image_size == 512 else "1M"

    # Load and preprocess source image
    try:
        with Image.open(src_img_path) as src_img:
            src_img = src_img.convert('RGB')
            src_img_tensor, crop_op = transform(src_img, image_size)
    except Exception as e:
        logger.error(f"Failed to load source image: {e}")
        raise

    # Set up generation parameters: pick the closest aspect-ratio template
    # (square, since the input was padded) and its scale schedule.
    aspect_ratio = 1.0  # h:w ratio
    h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - aspect_ratio))]
    scale_schedule = [(1, h, w) for (_, h, w) in dynamic_resolution_h_w[h_div_w_template][pn]['scales']]

    logger.info(f"Generating image with instruction: '{instruction}'")

    # Generate image; a seed of -1 requests a fresh random seed.
    if seed == -1:
        seed = np.random.randint(0, 1000000)
    torch.cuda.empty_cache()
    start_time = time.time()
    generated_image = gen_one_img(
        model, vae, tokenizer, text_encoder,
        instruction, src_img_tensor,
        g_seed=seed,
        gt_leak=0,
        gt_ls_Bl=None,
        cfg_list=cfg,
        tau_list=tau,
        scale_schedule=scale_schedule,
        cfg_insertion_layer=[args.cfg_insertion_layer],
        vae_type=args.vae_type,
        sampling_per_bits=args.sampling_per_bits,
        enable_positive_prompt=0,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    end_time = time.time()
    logger.info(f"Time taken: {end_time - start_time:.2f} seconds")
    max_memory = torch.cuda.max_memory_allocated() / 1024 ** 3
    logger.info(f"Max memory: {max_memory:.2f} GB")

    # Reverse the channel order (BGR -> RGB, matching cv2-style output) before
    # handing the array to PIL, then crop away the square padding.
    generated_image_np = generated_image.cpu().numpy()
    if generated_image_np.shape[2] == 3:
        generated_image_np = generated_image_np[..., ::-1]
    result_image = Image.fromarray(generated_image_np.astype(np.uint8))
    result_image = crop_op(result_image)
    return result_image
def main():
    """Main execution function with example usage."""
    try:
        # Load model
        model_components = load_model(
            "HiDream-ai/VAREdit",
            "HiDream-ai/VAREdit/8B-1024.pth",
            "8B",
            1024
        )

        # Generate and save the edited image ("output.png" is an arbitrary choice)
        result = generate_image(
            model_components,
            "assets/test.jpg",
            "Add glasses to this girl and change hair color to red",
            cfg=3.0,
            tau=1.0,
            seed=42
        )
        result.save("output.png")
    except Exception as e:
        logger.error(f"Inference failed: {e}")
        raise


if __name__ == "__main__":
    main()