| import random |
| import numpy as np |
| import requests |
| from io import BytesIO |
| from PIL import Image |
| from statistics import mean |
| import copy |
| import json |
| from typing import Any, Mapping |
| import open_clip |
| import torch |
|
|
| from sentence_transformers.util import (semantic_search, |
| dot_score, |
| normalize_embeddings) |
|
|
|
|
def nn_project(curr_embeds, embedding_layer, print_hits=False):
    """Snap every embedding onto its best-scoring row of ``embedding_layer``.

    The input is flattened to one query per token position, both the queries
    and the embedding table are passed through ``normalize_embeddings``, and
    ``semantic_search`` with ``dot_score`` picks the single best table row
    for each query.

    Args:
        curr_embeds: Tensor with three dimensions, unpacked as
            (batch, sequence length, embedding dim).
        embedding_layer: Callable layer exposing a ``weight`` table; it is
            also called on the selected indices to look the rows up.
        print_hits: When true, print the mean score of the selected hits.

    Returns:
        Tuple ``(projected_embeds, nn_indices)``: the looked-up embeddings
        and the selected row indices reshaped to (batch, sequence length).
    """
    # The whole projection, including the final lookup, runs without
    # gradient tracking.
    with torch.no_grad():
        batch, length, width = curr_embeds.shape

        queries = normalize_embeddings(curr_embeds.reshape((-1, width)))
        vocab = normalize_embeddings(embedding_layer.weight)

        # One chunk holding all queries, keeping only the best hit for each.
        matches = semantic_search(
            queries,
            vocab,
            query_chunk_size=queries.shape[0],
            top_k=1,
            score_function=dot_score,
        )
        top_matches = [per_query[0] for per_query in matches]

        if print_hits:
            hit_scores = [match["score"] for match in top_matches]
            print(f"mean hits:{mean(hit_scores)}")

        nn_indices = torch.tensor(
            [match["corpus_id"] for match in top_matches],
            device=queries.device,
        ).reshape((batch, length))

        projected_embeds = embedding_layer(nn_indices)

    return projected_embeds, nn_indices
|
|
def decode_ids(input_ids, tokenizer, by_token=False):
    """Decode a batch of token-id rows into strings.

    Args:
        input_ids: Tensor of token ids; it is detached, moved to the CPU and
            converted to a NumPy array, then iterated row by row.
        tokenizer: Object exposing ``decode(ids)``.
        by_token: When true, each id is decoded on its own and the pieces of
            a row are joined with ``'|'``; otherwise each row is decoded in
            a single ``decode`` call.

    Returns:
        A list with one decoded string per row.
    """
    id_rows = input_ids.detach().cpu().numpy()

    if not by_token:
        return [tokenizer.decode(row) for row in id_rows]

    return [
        '|'.join([tokenizer.decode([token_id]) for token_id in row])
        for row in id_rows
    ]
|
|
def get_target_feature(model, preprocess, tokenizer_funct, device, target_images=None, target_prompts=None):
    """Encode the optimisation targets into feature vectors.

    Images take precedence: when ``target_images`` is given, each image is
    run through ``preprocess``, batched and encoded with
    ``model.encode_image`` (``target_prompts`` is then ignored). Otherwise
    ``target_prompts`` is tokenised with ``tokenizer_funct`` and encoded
    with ``model.encode_text``.

    Args:
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning one image into a tensor.
        tokenizer_funct: Callable turning the prompts into a token tensor.
        device: Device the encoder inputs are moved to.
        target_images: Optional iterable of images.
        target_prompts: Optional prompts, used only when no images are given.

    Returns:
        The features returned by the encoder, computed without gradient
        tracking.

    Raises:
        ValueError: If neither ``target_images`` nor ``target_prompts`` is
            provided.
    """
    if target_images is None and target_prompts is None:
        raise ValueError("either target_images or target_prompts must be provided")

    # Targets are constants of the optimisation, so neither branch should
    # record an autograd graph (previously only the image branch was guarded).
    with torch.no_grad():
        if target_images is not None:
            image_batch = torch.cat(
                [preprocess(image).unsqueeze(0) for image in target_images]
            ).to(device)
            return model.encode_image(image_batch)

        token_batch = tokenizer_funct(target_prompts).to(device)
        return model.encode_text(token_batch)
|
|
def encode_text_embedding(model, text_embedding, ids, avg_text=False):
    """Run token embeddings through the model's text tower and pool them.

    The positional embedding is added, the sequence goes through
    ``model.transformer`` (with the first two axes swapped before and
    after) and ``model.ln_final``, and the result is pooled and multiplied
    by ``model.text_projection``.

    Args:
        model: Object exposing ``transformer`` (with ``get_cast_dtype``),
            ``positional_embedding``, ``attn_mask``, ``ln_final`` and
            ``text_projection``.
        text_embedding: Token embeddings, indexed as (batch, position, dim).
        ids: Token ids per row; the pooling position of a row is the
            position of its largest id (``ids.argmax(dim=-1)``).
            NOTE(review): presumably the end-of-text token has the largest
            id in the vocabulary - confirm against the tokenizer.
        avg_text: When false, pool by taking the feature at the pooling
            position. When true, pool by averaging the features at all
            positions strictly before it.

    Returns:
        The pooled features after projection, one row per batch element.
    """
    cast_dtype = model.transformer.get_cast_dtype()

    x = text_embedding + model.positional_embedding.to(cast_dtype)
    x = x.permute(1, 0, 2)
    x = model.transformer(x, attn_mask=model.attn_mask)
    x = x.permute(1, 0, 2)
    x = model.ln_final(x)

    eot_pos = ids.argmax(dim=-1)

    if avg_text:
        # Masked mean over positions [0, eot_pos) of each row. The previous
        # slice with a tensor bound only worked when the batch held a single
        # row; this gives the same result in that case and also supports
        # rows with different pooling positions.
        eot_on_x = eot_pos.to(x.device)
        positions = torch.arange(x.shape[1], device=x.device)
        keep = (positions.unsqueeze(0) < eot_on_x.unsqueeze(1)).unsqueeze(-1)
        summed = x.masked_fill(~keep, 0).sum(dim=1)
        # A pooling position of 0 yields 0/0 = NaN, matching the mean of an
        # empty slice.
        pooled = summed / eot_on_x.unsqueeze(1).to(x.dtype)
    else:
        pooled = x[torch.arange(x.shape[0]), eot_pos]

    return pooled @ model.text_projection
| |
def forward_text_embedding(model, embeddings, ids, image_features, avg_text=False, return_feature=False):
    """Score text embeddings against target features.

    Args:
        model: Model forwarded to ``encode_text_embedding``.
        embeddings: Token embeddings forwarded to ``encode_text_embedding``.
        ids: Token ids forwarded to ``encode_text_embedding``.
        image_features: Target features, one row per target. Not used when
            ``return_feature`` is true.
        avg_text: Forwarded to ``encode_text_embedding``.
        return_feature: When true, return the raw text features instead of
            similarity matrices.

    Returns:
        The text features when ``return_feature`` is true. Otherwise a pair
        ``(logits_per_image, logits_per_text)``: the dot products between
        the L2-normalised target rows and text rows, and the transpose of
        that matrix. No extra scaling is applied here.
    """
    text_features = encode_text_embedding(model, embeddings, ids, avg_text=avg_text)
    if return_feature:
        return text_features

    # Divide each row by its own norm along dim 1 before taking dot products.
    unit_targets = image_features / image_features.norm(dim=1, keepdim=True)
    unit_texts = text_features / text_features.norm(dim=1, keepdim=True)

    target_to_text = unit_targets @ unit_texts.t()
    return target_to_text, target_to_text.t()
| |
| def initialize_prompt(tokenizer, token_embedding, args, device, original_prompt): |
| prompt_len = args["prompt_len"] |
|
|
| |
| tokens = tokenizer.encode(original_prompt) |
| if len(tokens) > prompt_len: |
| tokens = tokens[:prompt_len] |
| if len(tokens) < prompt_len: |
| tokens += [0] * (prompt_len - len(tokens)) |
| |
| prompt_ids = torch.tensor([tokens] * args["prompt_bs"]).to(device) |
| |
| prompt_embeds = token_embedding(prompt_ids).detach() |
| prompt_embeds.requires_grad = True |
|
|
| |
| template_text = "{}" |
| padded_template_text = template_text.format(" ".join(["<start_of_text>"] * prompt_len)) |
| dummy_ids = tokenizer.encode(padded_template_text) |
|
|
| |
| dummy_ids = [i if i != 49406 else -1 for i in dummy_ids] |
| dummy_ids = [49406] + dummy_ids + [49407] |
| dummy_ids += [0] * (77 - len(dummy_ids)) |
| dummy_ids = torch.tensor([dummy_ids] * args["prompt_bs"]).to(device) |
|
|
| |
| tmp_dummy_ids = copy.deepcopy(dummy_ids) |
| tmp_dummy_ids[tmp_dummy_ids == -1] = 0 |
| dummy_embeds = token_embedding(tmp_dummy_ids).detach() |
| dummy_embeds.requires_grad = False |
| |
| return prompt_embeds, dummy_embeds, dummy_ids |
|
|
def optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, original_prompt):
    """Optimise a discrete text prompt towards the target features.

    Each iteration projects the continuous prompt embeddings onto their
    nearest token embeddings (``nn_project``), scores the projected prompts
    against all targets to track the best prompt so far, then takes the
    gradient of the loss at the projected point and applies it to the
    continuous embeddings with AdamW.

    Args:
        model: Model forwarded to ``forward_text_embedding``.
        tokenizer: Tokenizer used to seed the prompt and decode token ids.
        token_embedding: Embedding layer used for seeding and projection.
        all_target_features: Target features, one row per target.
        args: Mapping providing ``"iter"``, ``"lr"``, ``"weight_decay"``,
            ``"print_step"``, ``"batch_size"``, ``"loss_weight"`` plus the
            keys read by ``initialize_prompt``. ``"batch_size"`` of ``None``
            uses every target in each step; otherwise a random subset of
            that size is drawn per step. ``"print_step"`` only controls
            whether the final summary line is printed.
        device: Device forwarded to ``initialize_prompt``.
        original_prompt: Text used to seed the prompt.

    Returns:
        The decoded text of the best prompt found, or ``""`` if no
        iteration improved on the initial sentinel score.
    """
    opt_iters = args["iter"]
    lr = args["lr"]
    weight_decay = args["weight_decay"]
    print_step = args["print_step"]
    batch_size = args["batch_size"]
    print_new_best = True

    prompt_embeds, dummy_embeds, dummy_ids = initialize_prompt(
        tokenizer, token_embedding, args, device, original_prompt
    )
    p_dim = prompt_embeds.shape[-1]

    # Template positions that receive the learned tokens; constant across
    # iterations, so computed once.
    slot_mask = dummy_ids == -1

    input_optimizer = torch.optim.AdamW([prompt_embeds], lr=lr, weight_decay=weight_decay)

    loss_weight = args["loss_weight"]
    # Multiplying both sides of the comparison below by loss_weight flips
    # its direction for a negative weight, so the sentinel starts on the
    # losing side for either sign.
    best_sim = -1000 * loss_weight
    best_text = ""

    for step in range(opt_iters):
        # Pick the targets used for this step's gradient.
        if batch_size is None:
            target_features = all_target_features
        else:
            sampled = torch.randperm(len(all_target_features))[:batch_size]
            target_features = all_target_features[sampled]

        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embedding, print_hits=False)

        # Score the projected (discrete) prompts against every target.
        with torch.no_grad():
            padded_embeds = dummy_embeds.detach().clone()
            padded_embeds[slot_mask] = projected_embeds.reshape(-1, p_dim)
            logits_per_image, _ = forward_text_embedding(
                model, padded_embeds, dummy_ids, all_target_features
            )
            scores_per_prompt = logits_per_image.mean(dim=0)
            universal_cosim_score = scores_per_prompt.max().item()
            best_indx = scores_per_prompt.argmax().item()

        # Fresh leaf holding the projected values: the gradient is taken at
        # the projected point but applied to the continuous embeddings.
        tmp_embeds = projected_embeds.detach().clone().requires_grad_(True)

        padded_embeds = dummy_embeds.detach().clone()
        padded_embeds[slot_mask] = tmp_embeds.reshape(-1, p_dim)

        logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, target_features)
        loss = (1 - logits_per_image.mean()) * loss_weight

        prompt_embeds.grad, = torch.autograd.grad(loss, [tmp_embeds])

        input_optimizer.step()
        input_optimizer.zero_grad()

        if best_sim * loss_weight < universal_cosim_score * loss_weight:
            best_sim = universal_cosim_score
            # Decode only the winning prompt, and only when it is needed.
            best_text = decode_ids(nn_indices[best_indx:best_indx + 1], tokenizer)[0]
            if print_new_best:
                print(f"step: {step}, new best cosine sim: {best_sim}, new best prompt: {best_text}")

    if print_step is not None:
        print(f"best cosine sim: {best_sim}, best prompt: {best_text}")

    return best_text
|
|
|
|
def optimize_prompt(model, preprocess, args, device, target_images=None, target_prompts=None):
    """Learn a text prompt for the given targets and return it.

    Target features are computed by ``get_target_feature`` from
    ``target_images`` only, and ``target_prompts`` is handed to
    ``optimize_prompt_loop`` as the prompt that seeds the optimisation.
    NOTE(review): ``target_prompts`` is not forwarded to
    ``get_target_feature`` here - confirm that callers always supply
    ``target_images``.

    Args:
        model: Model exposing ``token_embedding``; also forwarded to the
            helpers.
        preprocess: Image preprocessing callable for ``get_target_feature``.
        args: Mapping providing ``"clip_model"`` plus the keys read by
            ``optimize_prompt_loop``.
        device: Device forwarded to the helpers.
        target_images: Images whose features are the optimisation targets.
        target_prompts: Seed prompt forwarded as ``original_prompt``.

    Returns:
        The value returned by ``optimize_prompt_loop``.
    """
    embedding_layer = model.token_embedding
    # Two tokenizer handles: the module-level tokenizer object goes to the
    # optimisation loop, the model-specific callable to get_target_feature.
    bpe_tokenizer = open_clip.tokenizer._tokenizer
    text_tokenizer = open_clip.get_tokenizer(args["clip_model"])

    target_features = get_target_feature(
        model, preprocess, text_tokenizer, device, target_images=target_images
    )

    return optimize_prompt_loop(
        model,
        bpe_tokenizer,
        embedding_layer,
        target_features,
        args,
        device,
        target_prompts,
    )
| |
|
|