Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import torch | |
| try: | |
| from torch.utils.model_zoo import download_url_to_file | |
| except ImportError: | |
| from torch.hub import download_url_to_file | |
| import errno | |
| import sys | |
| import os | |
| import warnings | |
| import re | |
| from urllib.parse import urlparse | |
| import time | |
| ENV_TORCH_HOME = "TORCH_HOME" | |
| ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" | |
| DEFAULT_CACHE_DIR = "~/.cache" | |
| # matches bfd8deac from resnet18-bfd8deac.pth | |
| HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") | |
| # a context manager to measure time | |
| class Timer: | |
| def __init__(self, name): | |
| self.name = name | |
| def __enter__(self): | |
| self.start = time.time() | |
| def __exit__(self, *args): | |
| print(f'{self.name} took {time.time() - self.start:.2f}s') | |
| def load_frames_rgb(file,max_frames=-1,cvt_color=True): | |
| cap = cv2.VideoCapture(file) | |
| frames = [] | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| if cvt_color: | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frames.append(frame) | |
| if max_frames > 0 and len(frames) >= max_frames: | |
| break | |
| cap.release() | |
| return frames | |
| def drawLandmark_multiple(img, bbox=None, landmark=None, color=(0, 255, 0)): | |
| """ | |
| Input: | |
| - img: gray or RGB | |
| - bbox: type of BBox | |
| - landmark: reproject landmark of (5L, 2L) | |
| Output: | |
| - img marked with landmark and bbox | |
| """ | |
| img = cv2.UMat(img).get() | |
| if bbox is not None: | |
| x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32) | |
| cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | |
| if landmark is not None: | |
| for x, y in np.array(landmark).astype(np.int32): | |
| cv2.circle(img, (int(x), int(y)), 2, color, -1) | |
| return img | |
| draw_landmarks=drawLandmark_multiple | |
| pretrained_urls = { | |
| "MobileNet": "https://github.com/elliottzheng/fast-alignment/releases/download/weights_v1/mobilenet_224_model_best_gdconv_external.pth", | |
| "PFLD": "https://github.com/elliottzheng/fast-alignment/releases/download/weights_v1/pfld_model_best.pth", | |
| "PFLD_onnx": "https://github.com/elliottzheng/fast-alignment/releases/download/weights_v1/PFLD.onnx", | |
| } | |
| def load_weights(file, backbone): | |
| if file is None: | |
| assert backbone in pretrained_urls | |
| url = pretrained_urls[backbone] | |
| return torch.utils.model_zoo.load_url(url) | |
| else: | |
| return torch.load(file, map_location="cpu") | |
| def _get_torch_home(): | |
| torch_home = os.path.expanduser( | |
| os.getenv( | |
| ENV_TORCH_HOME, | |
| os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), | |
| ) | |
| ) | |
| return torch_home | |
| def auto_download_from_url(url, model_dir=None, map_location=None, progress=True): | |
| r"""Loads the Torch serialized object at the given URL. | |
| If the object is already present in `model_dir`, it's deserialized and | |
| returned. The filename part of the URL should follow the naming convention | |
| ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |
| digits of the SHA256 hash of the contents of the file. The hash is used to | |
| ensure unique names and to verify the contents of the file. | |
| The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where | |
| environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |
| ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux | |
| filesytem layout, with a default value ``~/.cache`` if not set. | |
| Args: | |
| url (string): URL of the object to download | |
| model_dir (string, optional): directory in which to save the object | |
| map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |
| progress (bool, optional): whether or not to display a progress bar to stderr | |
| Example: | |
| >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |
| """ | |
| # Issue warning to move data if old env is set | |
| if os.getenv("TORCH_MODEL_ZOO"): | |
| warnings.warn( | |
| "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" | |
| ) | |
| if model_dir is None: | |
| torch_home = _get_torch_home() | |
| model_dir = os.path.join(torch_home, "checkpoints") | |
| try: | |
| os.makedirs(model_dir) | |
| except OSError as e: | |
| if e.errno == errno.EEXIST: | |
| # Directory already exists, ignore. | |
| pass | |
| else: | |
| # Unexpected OSError, re-raise. | |
| raise | |
| parts = urlparse(url) | |
| filename = os.path.basename(parts.path) | |
| cached_file = os.path.join(model_dir, filename) | |
| if not os.path.exists(cached_file): | |
| sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |
| download_url_to_file(url, cached_file, None, progress=progress) | |
| return cached_file | |
| def get_default_onnx_file(backbone): | |
| key = backbone + "_onnx" | |
| if key not in pretrained_urls: | |
| raise "default checkpoint for %s is not available" % backbone | |
| return auto_download_from_url(pretrained_urls[key]) | |
| def is_image(x): | |
| if isinstance(x, np.ndarray) and len(x.shape) == 3 and x.shape[-1] == 3: | |
| return True | |
| else: | |
| return False | |
| def is_box(x): | |
| try: | |
| x = np.array(x) | |
| assert len(x) == 4 | |
| assert (x[2:] - x[:2]).min() > 0 | |
| return True | |
| except: | |
| return False | |
| def is_face(x): | |
| try: | |
| assert is_box(x[0]) | |
| return True | |
| except: | |
| return False | |
| def detection_adapter(all_faces, batch=False): | |
| if not batch: | |
| if is_face(all_faces): # 单个检测结果 | |
| return all_faces[0] | |
| else: | |
| return [face[0] for face in all_faces] # 是单层列表 | |
| else: | |
| return [[face[0] for face in faces] for faces in all_faces] # 双层列表 | |
| def to_numpy(tensor): | |
| return ( | |
| tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() | |
| ) | |
| def bbox_from_pts(ldm_new): | |
| (x1, y1), (x2, y2) = ldm_new.min(0), ldm_new.max(0) | |
| box_new = np.array([x1, y1, x2, y2]) | |
| box_new[:2] -= 10 | |
| box_new[2:] += 10 | |
| return box_new | |
| class Aligner: | |
| def __init__(self, standard_points, size) -> None: | |
| self.standard_points = standard_points # ndarray of N,2 | |
| self.size = size | |
| def __call__(self, img, landmarks): | |
| # ndarray image, landmarks N,2 | |
| from skimage import transform | |
| trans = transform.SimilarityTransform() | |
| res = trans.estimate(landmarks, self.standard_points) | |
| M = trans.params | |
| new_img = cv2.warpAffine(img, M[:2, :], dsize=(self.size, self.size)) | |
| return new_img | |