Spaces:
Running
Running
# https://huggingface.co/deepkyu/ml-talking-face
"""Gradio demo for multilingual text-to-speech talking-face generation.

Importing this module performs environment setup, in this order:

1. Pins ``httpx`` to a fixed version. This must happen before ``gradio`` is
   imported further down, so that the pinned version is the one loaded.
2. Reads the REST backend location and the translation credential settings
   from required environment variables. A missing variable raises KeyError.
3. Downloads the credentials file to ``GOOGLE_APPLICATION_CREDENTIALS``.
"""
import argparse
import os
import subprocess
import sys
import threading
from pathlib import Path

import pkg_resources

library = 'httpx'
desired_version = '0.25.0'
try:
    installed_version = pkg_resources.get_distribution(library).version
except pkg_resources.DistributionNotFound:
    # Not installed at all: fall through and install the pinned version.
    installed_version = None
if installed_version != desired_version:
    # Use the running interpreter's pip (argument list, no shell) so the package
    # lands in the environment this process actually imports from.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '--force-reinstall', f'{library}=={desired_version}'],
        check=False,
    )
    print(f"Reinstalled {library}=={desired_version} (previously: {installed_version})")

REST_IP = os.environ['REST_IP']
SERVICE_PORT = int(os.environ['SERVICE_PORT'])
TRANSLATION_APIKEY_URL = os.environ['TRANSLATION_APIKEY_URL']
GOOGLE_APPLICATION_CREDENTIALS = os.environ['GOOGLE_APPLICATION_CREDENTIALS']

# Fetch the credentials file. Passing an argument list (not a shell string)
# keeps URLs containing '&', '?' or spaces intact.
# NOTE(review): --no-check-certificate disables TLS verification for a
# credentials download; consider removing it once the host's cert is trusted.
try:
    _download = subprocess.run(
        ['wget', '--no-check-certificate', '-O', GOOGLE_APPLICATION_CREDENTIALS, TRANSLATION_APIKEY_URL],
        check=False,
    )
    if _download.returncode != 0:
        print(f"Warning: credential download failed (wget exit code {_download.returncode})")
except OSError as _err:  # e.g. wget not installed; keep the original best-effort behavior
    print(f"Warning: could not run wget to download credentials: {_err}")

# Inputs scoring above this Perspective API probability are rejected.
TOXICITY_THRESHOLD = float(os.getenv('TOXICITY_THRESHOLD', 0.7))

# Imported only after the httpx pin above.
import gradio as gr
from toxicity_estimator import PerspectiveAPI
from translator import Translator
from client_rest import RestAPIApplication
from utils import get_snippet_from_url
class GradioApplication:
    """Gradio front-end that turns text into a talking-face video.

    The UI collects text, language, speech duration rate, an action and a
    background. ``infer`` does four things in order:

    1. Screens the text for toxicity.
    2. Translates it.
    3. Requests a video from the REST backend.
    4. Saves the video under ``output_file/`` and returns the results.

    ``__init__`` ends by calling ``demo.queue().launch(...)``, so constructing
    an instance also starts the web server.
    """

    def __init__(self, rest_ip, rest_port, max_seed, server_port=7860, share=False):
        """Build the UI and launch it.

        Args:
            rest_ip: Host passed to ``RestAPIApplication`` (the video backend).
            rest_port: Port passed to ``RestAPIApplication``.
            max_seed: Number of rotating output-file slots; file names wrap
                modulo this value.
            server_port: Port for the Gradio server.
            share: Forwarded to ``launch(share=...)``.
        """
        # Translator language code -> locale code sent to the REST backend
        # (presumably the codes returned by Translator.get_translation; verify).
        self.lang_list = {
            'ko': 'ko_KR',
            'en': 'en_US',
            'ja': 'ja_JP',
            'zh': 'zh_CN',
            'zh-CN': 'zh_CN'
        }
        # Indexed by the background Radio's selection index (type='index'):
        # None, CVPR, Black, River, Sky.
        self.background_list = [None,
                                "background_image/cvpr.png",
                                "background_image/black.png",
                                "background_image/river.mp4",
                                "background_image/sky.mp4"]
        self.perspective_api = PerspectiveAPI()
        self.translator = Translator()
        self.rest_application = RestAPIApplication(rest_ip, rest_port)
        self.output_dir = Path("output_file")
        # infer() writes videos here; make sure the directory exists.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_seed = max_seed
        self._file_seed = 0
        self.lock = threading.Lock()

        with gr.Blocks(
            theme="deepkyu/compact-theme",
            css=get_snippet_from_url("https://huggingface.co/spaces/deepkyu/compact-theme/raw/main/main.css")
        ) as demo:
            with gr.Row(equal_height=True):
                with gr.Column(scale=8):
                    gr.Markdown(Path("docs/title.md").read_text(), sanitize_html=False)
                with gr.Column(scale=1):
                    toggle_dark = gr.Button(value="Dark", variant='stop')
                    # Client-side only: flips the page's 'dark' CSS class.
                    toggle_dark.click(
                        None,
                        js="""
                        () => {
                            document.body.classList.toggle('dark');
                        }
                        """,
                    )
            gr.Markdown(Path("docs/description.md").read_text(), sanitize_html=False)
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    text_input, lang_input, duration_rate_input, action_input, background_input = prepare_input()
                    submit_button = gr.Button(value="Run", variant="primary")
                with gr.Column(scale=1):
                    toxicity_output, translation_result_output, video_output = prepare_output()
            submit_button.click(
                fn=self.infer,
                inputs=[text_input, lang_input, duration_rate_input, action_input, background_input],
                outputs=[toxicity_output, translation_result_output, video_output],
            )
            gr.Markdown(Path("docs/article.md").read_text(), sanitize_html=False)
        demo.queue().launch(share=share, server_port=server_port)

    def _get_file_seed(self, seed=None):
        """Return the two-digit output slot for ``seed``, wrapped to ``max_seed`` slots.

        Args:
            seed: Counter value to format; defaults to the current counter.
        """
        if seed is None:
            seed = self._file_seed
        return f"{seed % self.max_seed:02d}"

    def _reset_file_seed(self):
        """Reset the output-file counter to zero."""
        with self.lock:
            self._file_seed = 0

    def _counter_file_seed(self):
        """Increment the output-file counter under the lock and return the new value.

        Returning the value from inside the lock lets each request keep its
        own seed even when requests run concurrently.
        """
        with self.lock:
            self._file_seed += 1
            return self._file_seed

    def get_lang_code(self, lang):
        """Map a translator language code to the backend locale code (KeyError if unknown)."""
        return self.lang_list[lang]

    def get_background_data(self, background_index):
        """Load the selected background.

        Args:
            background_index: Index into ``background_list``.

        Returns:
            ``(bytes_or_None, is_video)``. The result is ``(None, False)`` when
            no background is selected. ``is_video`` is True for ``.mp4`` files.
        """
        data_path = self.background_list[background_index]
        if data_path is not None:
            with open(data_path, 'rb') as rf:
                background_data = rf.read()
            is_video_background = str(data_path).endswith(".mp4")
        else:
            background_data = None
            is_video_background = False
        return background_data, is_video_background

    @staticmethod
    def return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=""):
        """Pack results into the (Label, Textbox, Video) output triple.

        This is a staticmethod so that the existing ``self.return_format(...)``
        calls bind their arguments correctly.
        """
        return (
            {'Toxicity': toxicity_prob},
            f"Language: {lang_dest}\nText: {target_text}\n-\nDetails: {detail}",
            str(video_filename),
        )

    def infer(self, text, lang, duration_rate, action, background_index):
        """Run the text -> video pipeline for one request.

        Args:
            text: Input text.
            lang: Language choice from the UI Radio; passed to the translator.
            duration_rate: Speech duration factor; passed to the backend.
            action: Action name from the UI; lower-cased before sending.
            background_index: Index into ``background_list``.

        Returns:
            The ``return_format`` triple. On a handled failure (toxic input,
            translation error, length check) the triple carries an
            ``Error: ...`` detail and the placeholder video name.
        """
        file_seed = self._counter_file_seed()
        print(f"File Seed: {file_seed}")
        toxicity_prob = 0.0
        target_text = ""
        lang_dest = ""
        video_filename = "vacant.mp4"

        # Toxicity estimation is best effort: if the Perspective API call
        # fails, the text is scored 0.0 and processing continues.
        try:
            toxicity_prob = self.perspective_api.get_score(text)
        except Exception as e:  # when Perspective API doesn't work
            print(f"Perspective API unavailable; skipping toxicity check: {e}")
        if toxicity_prob > TOXICITY_THRESHOLD:
            detail = "Sorry, it seems that the input text is too toxic."
            return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")

        # Google Translate API: report failures to the user instead of crashing.
        try:
            target_text, lang_dest = self.translator.get_translation(text, lang)
        except Exception as e:
            target_text = ""
            lang_dest = ""
            detail = f"Error from language translation: ({e})"
            return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")

        try:
            self.translator.length_check(lang_dest, target_text)  # raises AssertionError on failure
        except AssertionError as e:
            return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {str(e)}")

        lang_rpc_code = self.get_lang_code(lang_dest)

        # Video inference
        background_data, is_video_background = self.get_background_data(background_index)
        video_data = self.rest_application.get_video(target_text, lang_rpc_code, duration_rate, action.lower(),
                                                     background_data, is_video_background)
        print(f"Video data size: {len(video_data)}")

        # Rotate among max_seed files so output_file/ does not grow without bound.
        video_filename = self.output_dir / f"{self._get_file_seed(file_seed)}.mkv"
        with open(video_filename, "wb") as video_file:
            video_file.write(video_data)

        return self.return_format(toxicity_prob, target_text, lang_dest, video_filename)
def prepare_input():
    """Create the input widgets of the demo form, in display order.

    Returns:
        tuple: ``(text, language, duration_rate, action, background)`` Gradio
        components. The background Radio uses ``type='index'``, so it yields
        the selected position rather than the label.
    """
    sample_sentence = ("Hello, this is demonstration for talking face generation "
                       "with multilingual text-to-speech.")
    text_box = gr.Textbox(
        lines=2,
        placeholder="Type your text with English, Chinese, Korean, and Japanese.",
        value=sample_sentence,
        label="Text",
    )

    languages = ['Korean', 'English', 'Japanese', 'Chinese']
    language_radio = gr.Radio(languages, type='value', value=languages[0], label="Language")

    speed_slider = gr.Slider(
        minimum=0.8,
        maximum=1.2,
        step=0.01,
        value=1.0,
        label="Duration (The bigger the value, the slower the speech)",
    )

    actions = ['Default', 'Hand', 'BothHand', 'HandDown', 'Sorry']
    action_radio = gr.Radio(actions, type='value', value=actions[0], label="Select an action ...")

    backgrounds = ['None', 'CVPR', 'Black', 'River', 'Sky']
    background_radio = gr.Radio(backgrounds, type='index', value=backgrounds[0],
                                label="Select a background image/video ...")

    return text_box, language_radio, speed_slider, action_radio, background_radio
def prepare_output():
    """Create the output widgets, in display order.

    Returns:
        tuple: ``(toxicity_label, translation_textbox, video)`` Gradio components.
    """
    widgets = (
        gr.Label(num_top_classes=1, label="Toxicity (from Perspective API)"),
        gr.Textbox(type="text", label="Translation Result"),
        gr.Video(format='mp4'),
    )
    return widgets
def parse_args():
    """Parse the demo's command-line options.

    The REST host/port defaults come from the module-level ``REST_IP`` and
    ``SERVICE_PORT`` values, which are read from the environment at import time.

    Returns:
        argparse.Namespace with ``gradio_port``, ``rest_ip``, ``rest_port``,
        ``max_seed`` and ``share``.
    """
    parser = argparse.ArgumentParser(
        description='GRADIO DEMO for talking face generation submitted to CVPR2022')
    option_specs = [
        (('-p', '--port'), {'dest': 'gradio_port', 'type': int, 'default': 7860, 'help': "Port for gradio"}),
        (('--rest_ip',), {'type': str, 'default': REST_IP, 'help': "IP for REST API"}),
        (('--rest_port',), {'type': int, 'default': SERVICE_PORT, 'help': "Port for REST API"}),
        (('--max_seed',), {'type': int, 'default': 20, 'help': "Max seed for saving video"}),
        (('--share',), {'action': 'store_true', 'help': 'get publicly sharable link'}),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # Constructing the application builds the UI and launches the server.
    gradio_application = GradioApplication(
        args.rest_ip,
        args.rest_port,
        args.max_seed,
        server_port=args.gradio_port,
        share=args.share,
    )