| |
|
| | import argparse
|
| | import binascii
|
| | import os
|
| | import os.path as osp
|
| |
|
| | import imageio
|
| | import torch
|
| | import torchvision
|
| | from sys import argv
|
| |
|
# Public API of this module; the argument helpers and rand_name are internal.
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
| |
|
| |
|
def get_arguments(args=None):
    """Parse command-line arguments and resolve the distributed local rank.

    Args:
        args: Optional list of argument strings. Defaults to ``sys.argv[1:]``,
            read at call time (the previous ``args=argv[1:]`` default was
            evaluated once at import time, freezing whatever argv held then).

    Returns:
        argparse.Namespace with at least ``config_file``, ``local_rank``,
        and ``no_cuda`` (always forced to False here).
    """
    if args is None:
        args = argv[1:]
    parser = get_argument_parser()
    args = parser.parse_args(args)

    # Fall back to launcher-provided environment variables when --local_rank
    # was not passed explicitly (torchrun sets LOCAL_RANK, SLURM sets
    # SLURM_LOCALID).
    if getattr(args, "local_rank", -1) == -1:
        env_lr = os.environ.get("LOCAL_RANK") or os.environ.get("SLURM_LOCALID")
        if env_lr is not None:
            try:
                # Only the conversion itself can raise; keep the try minimal.
                args.local_rank = int(env_lr)
            except ValueError:
                pass  # malformed env value: stay at -1 (non-distributed)

    args.no_cuda = False

    # Bind this process to its GPU. Best-effort: the modulo guards against a
    # rank exceeding the visible device count, and any CUDA error is ignored.
    if torch.cuda.is_available() and getattr(args, "local_rank", -1) >= 0:
        try:
            torch.cuda.set_device(args.local_rank % torch.cuda.device_count())
        except Exception:
            pass

    return args
|
| |
|
| |
|
def get_argument_parser():
    """Construct the ArgumentParser used for inference configuration.

    Returns:
        argparse.ArgumentParser exposing ``--config-file`` and ``--local_rank``.
    """
    parser = argparse.ArgumentParser()
    option_specs = [
        ("--config-file", {
            "type": str,
            "default": "ovi/configs/inference/inference_fusion.yaml",
        }),
        ("--local_rank", {
            "type": int,
            "default": -1,
            "help": "local_rank for distributed training on gpus",
        }),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|
| |
|
| |
|
def rand_name(length=8, suffix=''):
    """Generate a random hex file name.

    Args:
        length: Number of random bytes; the stem is ``2 * length`` hex chars.
        suffix: Optional extension; a leading dot is added if missing.

    Returns:
        str: Random stem plus (possibly dot-prefixed) suffix.
    """
    stem = os.urandom(length).hex()
    if suffix and not suffix.startswith('.'):
        suffix = '.' + suffix
    return stem + suffix
|
| |
|
| |
|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to disk with imageio, retrying on failure.

    Args:
        tensor: Video tensor; frames are unbound along dim 2 and each frame is
            tiled into a grid (assumes a (B, C, T, H, W)-style layout — TODO
            confirm against callers).
        save_file: Target path; when None a random name under /tmp is used.
        fps: Frames per second passed to the encoder.
        suffix: Extension for the auto-generated temp file name.
        nrow: Grid images per row (torchvision.utils.make_grid).
        normalize: Whether make_grid rescales values into [0, 1].
        value_range: (min, max) used both for clamping and normalization.
        retry: Number of encode attempts before giving up.

    Returns:
        str | None: The written file path, or None if every attempt failed.
    """
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    error = None
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))
            # One grid image per frame, then permute to (T, H, W, C) uint8,
            # which is the layout imageio expects.
            tensor = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in tensor.unbind(2)
            ],
                                 dim=1).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in tensor.numpy():
                    writer.append_data(frame)
            finally:
                # Always release the writer: previously an exception raised
                # while appending frames leaked the open file handle.
                writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue

    print(f'cache_video failed, error: {error}', flush=True)
    return None
|
| |
|
| |
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor grid to ``save_file``, retrying on failure.

    Args:
        tensor: Image tensor (or mini-batch) accepted by
            torchvision.utils.save_image.
        save_file: Destination path; its extension determines the format.
        nrow: Grid images per row.
        normalize: Whether save_image rescales values into [0, 1].
        value_range: (min, max) used both for clamping and normalization.
        retry: Number of save attempts before giving up.

    Returns:
        str | None: ``save_file`` on success, None if every attempt failed.
    """
    # NOTE(review): `suffix` is computed but never used afterwards — the file
    # is always written to `save_file` as-is. Kept for behavior parity; worth
    # confirming whether the fallback extension was meant to be applied.
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        suffix = '.png'

    error = None
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue

    # Previously exhausted retries fell through silently (implicit None);
    # report the failure like cache_video does.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
|
| |
|
| |
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert (a bool is passed through unchanged).

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    truthy = frozenset({'yes', 'true', 't', 'y', '1'})
    falsy = frozenset({'no', 'false', 'f', 'n', '0'})
    key = v.lower()
    if key in truthy:
        return True
    if key in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
| |
|