| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import logging |
| | import math |
| | import os |
| | import time |
| | import warnings |
| | from enum import Enum |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Tuple, Union |
| |
|
| | import kaldi_native_fbank as knf |
| | import numpy as np |
| | import sentencepiece as spm |
| | import soundfile as sf |
| | import yaml |
| | from onnxruntime import (GraphOptimizationLevel, InferenceSession, |
| | SessionOptions, get_available_providers, get_device) |
| | from rknnlite.api.rknn_lite import RKNNLite |
| |
|
# Sequence length (in frames) the RKNN-compiled encoder expects; shorter
# inputs are zero-padded up to this length before inference.
RKNN_INPUT_LEN = 171

# Scale factor applied to the speech features before they enter the encoder.
SPEECH_SCALE = 0.5
| |
|
class VadOrtInferRuntimeSession:
    """ONNX Runtime session wrapper for the VAD model.

    Uses the CUDA execution provider when requested and available, always
    appending the CPU provider as a fallback.
    """

    def __init__(self, config, root_dir: Path):
        """Create the inference session.

        Args:
            config: dict with keys ``model_path``, ``use_cuda`` and, when CUDA
                is requested, provider options under ``CUDAExecutionProvider``.
            root_dir: directory that ``model_path`` is resolved against.
        """
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4  # only fatal ORT messages
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        EP_list = []
        if (
            config["use_cuda"]
            and get_device() == "GPU"
            and cuda_ep in get_available_providers()
        ):
            EP_list = [(cuda_ep, config[cuda_ep])]
        EP_list.append((cpu_ep, cpu_provider_options))

        config["model_path"] = root_dir / str(config["model_path"])
        self._verify_model(config["model_path"])
        logging.info(f"Loading onnx model at {str(config['model_path'])}")
        self.session = InferenceSession(
            str(config["model_path"]), sess_options=sess_opt, providers=EP_list
        )

        # BUG FIX: the original called logging.warning(msg, RuntimeWarning);
        # logging treats extra positional args as %-format arguments, which
        # fails at log time since the message has no placeholders.  Use
        # warnings.warn, consistent with OrtInferRuntimeSession below.
        if config["use_cuda"] and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not available for current env, "
                f"the inference part is automatically shifted to be "
                f"executed under {cpu_ep}.\n "
                "Please ensure the installed onnxruntime-gpu version"
                " matches your cuda and cudnn version, "
                "you can check their relations from the official web site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(
        self, input_content: List[np.ndarray]
    ) -> np.ndarray:
        """Run the VAD model.

        A list is interpreted as [speech, in_cache0..in_cache3]; a bare
        array is fed as the ``speech`` input only.
        """
        if isinstance(input_content, list):
            input_dict = {
                "speech": input_content[0],
                "in_cache0": input_content[1],
                "in_cache1": input_content[2],
                "in_cache2": input_content[3],
                "in_cache3": input_content[4],
            }
        else:
            input_dict = {"speech": input_content}

        return self.session.run(None, input_dict)

    def get_input_names(self):
        """Return the model's input tensor names."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self):
        """Return the model's output tensor names."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        """Return metadata entry *key* split into lines (call have_key first)."""
        return self.meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Cache model metadata and report whether *key* is present."""
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in self.meta_dict

    @staticmethod
    def _verify_model(model_path):
        """Raise if *model_path* does not point at an existing regular file."""
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")
        if not model_path.is_file():
            # NOTE(review): FileExistsError is semantically odd for "not a
            # file", but kept for backward compatibility with callers.
            raise FileExistsError(f"{model_path} is not a file.")
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
# Module-level logging setup: timestamped records including source file/line.
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
| |
|
| |
|
class OrtInferRuntimeSession:
    """Generic ONNX Runtime session wrapper.

    Uses the CUDA execution provider when ``device_id >= 0`` and CUDA is
    available; the CPU provider is always appended as a fallback.
    """

    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        """Build the session.

        Args:
            model_file: path to the ONNX model.
            device_id: CUDA device id; -1 selects CPU-only execution.
            intra_op_num_threads: ORT intra-op thread pool size.
        """
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4  # only fatal ORT messages
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        EP_list = []
        if (
            device_id != "-1"
            and get_device() == "GPU"
            and cuda_ep in get_available_providers()
        ):
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))

        self._verify_model(model_file)

        self.session = InferenceSession(
            model_file, sess_options=sess_opt, providers=EP_list
        )

        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not available for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
                "you can check their relations from the official web site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(self, input_content) -> np.ndarray:
        """Run the model, mapping positional inputs onto the input names."""
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            # Chain the original error instead of printing it to stdout.
            raise RuntimeError("ONNXRuntime inference failed.") from e

    def get_input_names(self):
        """Return the model's input tensor names."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self):
        """Return the model's output tensor names."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        """Return metadata entry *key* split into lines (call have_key first)."""
        return self.meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Cache model metadata and report whether *key* is present."""
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in self.meta_dict

    @staticmethod
    def _verify_model(model_path):
        """Raise if *model_path* does not point at an existing regular file."""
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")
        if not model_path.is_file():
            # NOTE(review): FileExistsError kept for backward compatibility.
            raise FileExistsError(f"{model_path} is not a file.")
| |
|
| |
|
def log_softmax(x: np.ndarray) -> np.ndarray:
    """Compute ``log(softmax(x))`` along the last axis, numerically stably.

    The original implementation computed the softmax first and then took its
    log, which underflows to ``-inf`` for very small probabilities.  Using
    ``x - x_max - log(sum(exp(x - x_max)))`` avoids the intermediate division
    and the log-of-tiny-number underflow while producing the same values.

    Args:
        x: array of logits; the reduction runs over the last axis.

    Returns:
        Array of the same shape holding log-probabilities.
    """
    x_max = np.max(x, axis=-1, keepdims=True)
    shifted = x - x_max
    log_sum = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return shifted - log_sum
| |
|
| |
|
class SenseVoiceInferenceSession:
    """SenseVoice ASR inference using an RKNN-compiled encoder.

    Combines a pre-computed embedding table (query vectors prepended to the
    speech features), an RKNN encoder, and a SentencePiece decoder for
    CTC-style greedy decoding.
    """

    def __init__(
        self,
        embedding_model_file,
        encoder_model_file,
        bpe_model_file,
        device_id=-1,
        intra_op_num_threads=4,
    ):
        """Load the embedding table, RKNN encoder, and BPE model.

        Args:
            embedding_model_file: .npy file with the query-embedding matrix.
            encoder_model_file: .rknn encoder model path.
            bpe_model_file: SentencePiece model path.
            device_id: accepted for interface compatibility; unused by the
                RKNN runtime path below.
            intra_op_num_threads: accepted for interface compatibility; unused.
        """
        logging.info(f"Loading model from {embedding_model_file}")

        self.embedding = np.load(embedding_model_file)
        logging.info(f"Loading model {encoder_model_file}")
        start = time.time()
        self.encoder = RKNNLite(verbose=False)
        self.encoder.load_rknn(encoder_model_file)
        self.encoder.init_runtime()

        logging.info(
            f"Loading {encoder_model_file} takes {time.time() - start:.2f} seconds"
        )
        self.blank_id = 0  # CTC blank token id, dropped during decoding
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(bpe_model_file)

    def __call__(self, speech, language: int, use_itn: bool) -> str:
        """Transcribe one feature matrix.

        Args:
            speech: feature array of shape (1, T, dim), concatenated after
                the query embeddings along the time axis.
            language: row index into the embedding table selecting the
                language query vector.
            use_itn: selects embedding row 14 (with ITN) or 15 (without) —
                presumably inverse-text-normalization on/off; confirm against
                the embedding table layout.

        Returns:
            The decoded text string.
        """
        language_query = self.embedding[[[language]]]

        text_norm_query = self.embedding[[[14 if use_itn else 15]]]
        # Rows 1 and 2 — presumably event and emotion query vectors.
        event_emo_query = self.embedding[[[1, 2]]]

        speech = speech * SPEECH_SCALE

        # Prepend the 4 query vectors to the speech features on the time axis.
        input_content = np.concatenate(
            [
                language_query,
                event_emo_query,
                text_norm_query,
                speech,
            ],
            axis=1,
        ).astype(np.float32)
        print(input_content.shape)

        # Zero-pad the time axis to the fixed length the RKNN model expects.
        input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
        print("padded shape:", input_content.shape)
        start_time = time.time()
        encoder_out = self.encoder.inference(inputs=[input_content])[0]
        end_time = time.time()
        print(f"encoder inference time: {end_time - start_time:.2f} seconds")

        def unique_consecutive(arr):
            # CTC collapse: merge repeated ids, then drop blanks.
            if len(arr) == 0:
                return arr

            mask = np.append([True], arr[1:] != arr[:-1])
            out = arr[mask]
            out = out[out != self.blank_id]
            return out.tolist()

        # Greedy argmax over the vocabulary axis, then CTC collapse + decode.
        hypos = unique_consecutive(encoder_out[0].argmax(axis=0))
        text = self.sp.DecodeIds(hypos)
        return text
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
class WavFrontend:
    """Conventional frontend structure for ASR.

    Pipeline: waveform -> Kaldi-style fbank -> LFR (low-frame-rate stacking)
    -> CMVN normalization, producing the feature matrix fed to the encoder.
    """

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 7,
        lfr_n: int = 6,
        dither: float = 0,
        **kwargs,
    ) -> None:
        """Configure fbank extraction and optional CMVN.

        Args:
            cmvn_file: Kaldi-format CMVN stats file; when None, lfr_cmvn
                skips normalization.
            fs: sample rate in Hz.
            window: analysis window type.
            n_mels: number of mel bins.
            frame_length: frame size in ms.
            frame_shift: frame hop in ms.
            lfr_m: number of frames stacked per LFR output frame.
            lfr_n: LFR decimation factor (hop in input frames).
            dither: dithering amount; 0 disables.
        """
        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
        # fbank_fn / fbank_beg_idx are (re)created by reset_status below.
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def reset_status(self):
        """Create a fresh online-fbank computer and rewind the frame index."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank features for a whole waveform.

        Args:
            waveform: 1-D float waveform; scaled by 1<<15 to the 16-bit PCM
                range expected by the Kaldi-compatible fbank implementation.

        Returns:
            (features [frames, n_mels] float32, frame count as int32 array).
        """
        waveform = waveform * (1 << 15)
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply LFR stacking and (if configured) CMVN to fbank features."""
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    def load_audio(self, filename: str) -> Tuple[np.ndarray, int]:
        """Read an audio file as mono float32 samples; 16 kHz only.

        Raises:
            AssertionError: if the file is not sampled at 16000 Hz.
        """
        data, sample_rate = sf.read(
            filename,
            always_2d=True,
            dtype="float32",
        )
        assert (
            sample_rate == 16000
        ), f"Only 16000 Hz is supported, but got {sample_rate}Hz"
        self.sample_rate = sample_rate
        # Keep only the first channel.
        data = data[:, 0]
        samples = np.ascontiguousarray(data)

        return samples, sample_rate

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        """Stack lfr_m consecutive frames every lfr_n frames (LFR).

        The first frame is replicated (lfr_m - 1) // 2 times as left padding;
        the final frame is replicated to fill incomplete windows at the end.

        Returns:
            Array of shape (ceil(T / lfr_n), lfr_m * dim), float32.
        """
        LFR_inputs = []

        T = inputs.shape[0]
        T_lfr = int(np.ceil(T / lfr_n))
        left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
        inputs = np.vstack((left_padding, inputs))
        T = T + (lfr_m - 1) // 2
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                # Full window available: flatten lfr_m frames into one row.
                LFR_inputs.append(
                    (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
                )
            else:
                # Tail window: pad by repeating the last frame.
                num_padding = lfr_m - (T - i * lfr_n)
                frame = inputs[i * lfr_n :].reshape(-1)
                for _ in range(num_padding):
                    frame = np.hstack((frame, inputs[-1]))

                LFR_inputs.append(frame)
        LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
        return LFR_outputs

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data

        The stats file stores an additive shift row and a multiplicative
        rescale row (see load_cmvn), hence (x + means) * vars.
        """
        frame, dim = inputs.shape
        means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
        inputs = (inputs + means) * vars
        return inputs

    def get_features(self, inputs: Union[str, np.ndarray]) -> Tuple[np.ndarray, int]:
        """Full pipeline: (optional load) -> fbank -> LFR -> CMVN.

        NOTE(review): this always calls apply_cmvn, so self.cmvn must exist —
        constructing with cmvn_file=None and calling get_features would raise
        AttributeError; confirm callers always pass a cmvn_file.
        """
        if isinstance(inputs, str):
            inputs, _ = self.load_audio(inputs)

        fbank, _ = self.fbank(inputs)
        feats = self.apply_cmvn(self.apply_lfr(fbank, self.lfr_m, self.lfr_n))
        return feats

    def load_cmvn(
        self,
    ) -> np.ndarray:
        """Parse a Kaldi-nnet style CMVN file.

        Reads the <AddShift>/<LearnRateCoef> line as the means row and the
        <Rescale>/<LearnRateCoef> line as the vars row.

        Returns:
            Array of shape (2, dim): row 0 = additive shift, row 1 = rescale.
        """
        with open(self.cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        means_list = []
        vars_list = []
        for i in range(len(lines)):
            line_item = lines[i].split()
            if line_item[0] == "<AddShift>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    # Values sit between the tag tokens and a trailing bracket.
                    add_shift_line = line_item[3 : (len(line_item) - 1)]
                    means_list = list(add_shift_line)
                    continue
            elif line_item[0] == "<Rescale>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    rescale_line = line_item[3 : (len(line_item) - 1)]
                    vars_list = list(rescale_line)
                    continue

        means = np.array(means_list).astype(np.float64)
        vars = np.array(vars_list).astype(np.float64)
        cmvn = np.array([means, vars])
        return cmvn
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| |
|
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    """Load a YAML file into a dict.

    Args:
        yaml_path: path to the YAML file.

    Raises:
        FileNotFoundError: if *yaml_path* does not exist.  (The original
        raised FileExistsError, which is the wrong exception for a missing
        file.)
    """
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f"The {yaml_path} does not exist.")

    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # trusted configuration files here.
    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
| |
|
| |
|
class VadStateMachine(Enum):
    """Overall detection state: before, inside, or after a speech segment."""

    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3
| |
|
| |
|
class FrameState(Enum):
    """Classification of a single frame: speech, silence, or invalid."""

    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
| |
|
| |
|
| | |
class AudioChangeState(Enum):
    """Frame-to-frame transition event reported by WindowDetector."""

    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5
| |
|
| |
|
class VadDetectMode(Enum):
    """Whether to detect a single utterance or multiple utterances."""

    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1
| |
|
| |
|
class VADXOptions:
    """Configuration options for the E2E VAD algorithm.

    All *_time / *_ms values are in milliseconds unless noted otherwise.

    Fixes over the original:
    - ``sil_pdf_ids`` used a mutable default argument (``[0]``), shared
      across all instances; it now defaults to None and is materialized
      per-instance.
    - ``snr_thres`` and ``decibel_thres`` were annotated ``int`` despite
      float defaults; annotations corrected.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = None,
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # Materialize the default per instance to avoid shared mutable state.
        self.sil_pdf_ids = [0] if sil_pdf_ids is None else sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
| |
|
| |
|
class E2EVadSpeechBufWithDoa(object):
    """One detected speech-segment buffer with DOA (direction of arrival).

    ``start_ms``/``end_ms`` delimit the segment; the two ``contain_*`` flags
    record whether the segment's start and end points have been confirmed.
    """

    def __init__(self):
        # Delegate to reset() so the initial state and the reset state can
        # never drift apart (the original duplicated the assignments).
        self.reset()

    def reset(self):
        """Return the buffer to its pristine, empty state."""
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0
| |
|
| |
|
class E2EVadFrameProb(object):
    """Per-frame VAD probability record (noise/speech log-probs and state)."""

    def __init__(self):
        # Neutral zero state; fields are filled in by the VAD model when
        # per-frame probability output is enabled.
        self.speech_prob = 0.0
        self.noise_prob = 0.0
        self.score = 0.0
        self.frm_state = 0
        self.frame_id = 0
| |
|
| |
|
class WindowDetector(object):
    """Sliding-window smoother turning per-frame speech/sil decisions into
    sil->speech / speech->sil state-change events.

    A ring buffer holds the last ``win_size_frame`` binary frame states;
    their running sum is compared against frame-count thresholds to decide
    when the tracked state flips.
    """

    def __init__(
        self,
        window_size_ms: int,
        sil_to_speech_time: int,
        speech_to_sil_time: int,
        frame_size_ms: int,
    ):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        # NOTE(review): this attribute shadows the frame_size_ms() method
        # defined below, making that method unreachable on instances.
        self.frame_size_ms = frame_size_ms

        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0  # number of speech frames currently in the window
        self.win_state = [0] * self.win_size_frame  # ring buffer of 0/1 states

        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def reset(self) -> None:
        """Clear the window and return to the silence state."""
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def get_win_size(self) -> int:
        """Return the window length in frames."""
        return int(self.win_size_frame)

    def detect_one_frame(
        self, frameState: FrameState, frame_count: int
    ) -> AudioChangeState:
        """Push one frame state into the window and report the transition.

        Args:
            frameState: this frame's speech/sil classification.
            frame_count: frame index (currently unused by the logic).

        Returns:
            The resulting AudioChangeState transition.
        """
        # Map the enum to 0/1 for the ring-buffer sum.
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
        elif frameState == FrameState.kFrameStateSil:
            cur_frame_state = 0
        else:
            return AudioChangeState.kChangeStateInvalid
        # Replace the oldest entry and update the running sum.
        self.win_sum -= self.win_state[self.cur_win_pos]
        self.win_sum += cur_frame_state
        self.win_state[self.cur_win_pos] = cur_frame_state
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame

        if (
            self.pre_frame_state == FrameState.kFrameStateSil
            and self.win_sum >= self.sil_to_speech_frmcnt_thres
        ):
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech

        if (
            self.pre_frame_state == FrameState.kFrameStateSpeech
            and self.win_sum <= self.speech_to_sil_frmcnt_thres
        ):
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil

        # No threshold crossed: stay in the current state.
        if self.pre_frame_state == FrameState.kFrameStateSil:
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def frame_size_ms(self) -> int:
        # NOTE(review): unreachable on instances — shadowed by the
        # self.frame_size_ms attribute assigned in __init__.
        return int(self.frame_size_ms)
| |
|
| |
|
| | class E2EVadModel: |
    def __init__(self, config, vad_post_args: Dict[str, Any], root_dir: Path):
        """Build the E2E VAD: options, sliding-window detector, ONNX model.

        Args:
            config: session config forwarded to VadOrtInferRuntimeSession.
            vad_post_args: keyword arguments for VADXOptions.
            root_dir: directory used to resolve the VAD model path.
        """
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(
            self.vad_opts.window_size_ms,
            self.vad_opts.sil_to_speech_time_thres,
            self.vad_opts.speech_to_sil_time_thres,
            self.vad_opts.frame_in_ms,
        )
        self.model = VadOrtInferRuntimeSession(config, root_dir)
        # Initialize all mutable detection state.
        self.all_reset_detection()
| |
|
    def all_reset_detection(self):
        """Reset ALL detector state: frame counters, score/decibel caches,
        output buffers, and the per-segment state (via reset_detection)."""
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0  # total frames scored so far
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True

        self.output_data_buf = []  # list of E2EVadSpeechBufWithDoa segments
        self.output_data_buf_offset = 0  # first not-yet-emitted segment index
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = (
            self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        )
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
        self.scores = None  # latest model score block
        self.scores_offset = 0  # absolute frame index past the score block
        self.max_time_out = False
        self.decibel = []  # per-frame energies in dB (see compute_decibel)
        self.decibel_offset = 0
        self.data_buf_size = 0
        self.data_buf_all_size = 0
        self.waveform = None
        self.reset_detection()
| |
|
    def reset_detection(self):
        """Reset only the per-segment detection state (kept between chunks),
        returning the state machine to 'start point not detected'."""
        self.continous_silence_frame_count = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.windows_detector.reset()
        self.sil_frame = 0
        self.frame_probs = []
| |
|
    def compute_decibel(self) -> None:
        """Append per-frame log-energies (dB) of self.waveform to self.decibel.

        Frames are frame_length_ms long, advancing by frame_in_ms.  Only the
        first row of the waveform array is used (self.waveform[0]).  The
        first call also initializes the data-buffer bookkeeping sizes.
        """
        frame_sample_length = int(
            self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
        )
        frame_shift_length = int(
            self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
        )
        if self.data_buf_all_size == 0:
            # First chunk: the whole buffer is unconsumed.
            self.data_buf_all_size = len(self.waveform[0])
            self.data_buf_size = self.data_buf_all_size
        else:
            self.data_buf_all_size += len(self.waveform[0])

        for offset in range(
            0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
        ):
            # 10*log10(sum of squares); +1e-6 avoids log10(0) on pure silence.
            self.decibel.append(
                10
                * np.log10(
                    np.square(
                        self.waveform[0][offset : offset + frame_sample_length]
                    ).sum()
                    + 1e-6
                )
            )
| |
|
    def compute_scores(self, feats: np.ndarray) -> List[np.ndarray]:
        """Run the VAD model on feats and cache the frame scores.

        Updates self.scores / self.scores_offset and advances frm_cnt by the
        number of scored frames.  Returns the model's remaining outputs
        (scores[1:]) — the streaming caches.  (The original annotation said
        ``-> None`` despite this return value.)
        """
        scores = self.model(feats)
        # The model output decides the evaluation block size; mirror it.
        self.vad_opts.nn_eval_block_size = scores[0].shape[1]
        self.frm_cnt += scores[0].shape[1]
        if isinstance(feats, list):
            # Streaming path passes [feats, *caches]; compare against feats.
            feats = feats[0]

        assert (
            scores[0].shape[1] == feats.shape[1]
        ), "The shape between feats and scores does not match"

        self.scores = scores[0]
        self.scores_offset += self.scores.shape[1]

        return scores[1:]
| |
|
    def pop_data_buf_till_frame(self, frame_idx: int) -> None:
        """Advance the data-buffer start to frame_idx, shrinking its size.

        NOTE(review): if fewer than one frame of samples remains while
        data_buf_start_frame < frame_idx, this loop makes no progress —
        presumably upstream guarantees enough buffered data; confirm.
        """
        while self.data_buf_start_frame < frame_idx:
            if self.data_buf_size >= int(
                self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
            ):
                self.data_buf_start_frame += 1
                self.data_buf_size = (
                    self.data_buf_all_size
                    - self.data_buf_start_frame
                    * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
                )
| |
|
    def pop_data_to_output_buf(
        self,
        start_frm: int,
        frm_cnt: int,
        first_frm_is_start_point: bool,
        last_frm_is_end_point: bool,
        end_point_is_sent_end: bool,
    ) -> None:
        """Move frm_cnt frames starting at start_frm into the output buffer.

        Extends (or opens) the last E2EVadSpeechBufWithDoa segment, updating
        its start/end timestamps and start/end-point flags.

        Args:
            start_frm: first frame index to move.
            frm_cnt: number of frames to move.
            first_frm_is_start_point: the first frame opens a new segment.
            last_frm_is_end_point: the last frame closes the segment.
            end_point_is_sent_end: the end point is also the end of input.
        """
        self.pop_data_buf_till_frame(start_frm)
        expected_sample_number = int(
            frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
        )
        if last_frm_is_end_point:
            # Account for the analysis-window overhang past the last hop.
            extra_sample = max(
                0,
                int(
                    self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
                    - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
                ),
            )
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, self.data_buf_size)
        if self.data_buf_size < expected_sample_number:
            logging.error("error in calling pop data_buf\n")

        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
            # Open a fresh segment starting at this frame.
            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
            self.output_data_buf[-1].reset()
            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
            self.output_data_buf[-1].doa = 0
        cur_seg = self.output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            logging.error("warning\n")
        out_pos = len(cur_seg.buffer)
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(
                frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
            )
        if data_to_pop > self.data_buf_size:
            logging.error("VAD data_to_pop is bigger than self.data_buf.size()!!!\n")
            data_to_pop = self.data_buf_size
            expected_sample_number = self.data_buf_size

        cur_seg.doa = 0
        # The sample-copy bodies are elided in this port — only the output
        # position counter advances; sample data is not actually copied.
        for sample_cpy_out in range(0, data_to_pop):
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            logging.error("Something wrong with the VAD algorithm\n")
        self.data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True
| |
|
    def on_silence_detected(self, valid_frame: int):
        """Record a confirmed silence frame; while no speech segment has
        started yet, also discard the buffered audio up to that frame."""
        self.lastest_confirmed_silence_frame = valid_frame
        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.pop_data_buf_till_frame(valid_frame)
| |
|
    def on_voice_detected(self, valid_frame: int) -> None:
        """Record a confirmed speech frame and push it to the output buffer."""
        self.latest_confirmed_speech_frame = valid_frame
        self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
| |
|
    def on_voice_start(self, start_frame: int, fake_result: bool = False) -> None:
        """Confirm a speech start point at start_frame.

        Unless fake_result is set, the start frame also opens a new output
        segment when no start point had been detected yet.
        """
        if self.vad_opts.do_start_point_detection:
            pass
        if self.confirmed_start_frame != -1:
            logging.error("not reset vad properly\n")
        else:
            self.confirmed_start_frame = start_frame

        if (
            not fake_result
            and self.vad_state_machine
            == VadStateMachine.kVadInStateStartPointNotDetected
        ):
            self.pop_data_to_output_buf(
                self.confirmed_start_frame, 1, True, False, False
            )
| |
|
    def on_voice_end(
        self, end_frame: int, fake_result: bool, is_last_frame: bool
    ) -> None:
        """Confirm a speech end point at end_frame.

        First flushes all speech frames up to end_frame into the output
        buffer, then (unless fake_result) closes the current segment.
        """
        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
            self.on_voice_detected(t)
        if self.vad_opts.do_end_point_detection:
            pass
        if self.confirmed_end_frame != -1:
            logging.error("not reset vad properly\n")
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
            self.sil_frame = 0
            self.pop_data_to_output_buf(
                self.confirmed_end_frame, 1, False, True, is_last_frame
            )
        self.number_end_time_detected += 1
| |
|
    def maybe_on_voice_end_last_frame(
        self, is_final_frame: bool, cur_frm_idx: int
    ) -> None:
        """On the final frame, force-close the current speech segment and
        move the state machine to 'end point detected'."""
        if is_final_frame:
            self.on_voice_end(cur_frm_idx, False, True)
            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
| |
|
    def get_latency(self) -> int:
        """Return the start-point detection latency in milliseconds."""
        return int(self.latency_frm_num_at_start_point() * self.vad_opts.frame_in_ms)
| |
|
    def latency_frm_num_at_start_point(self) -> int:
        """Return the start-point latency in frames: the detector window,
        plus the lookback extension when do_extend is enabled."""
        vad_latency = self.windows_detector.get_win_size()
        if self.vad_opts.do_extend:
            vad_latency += int(
                self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms
            )
        return vad_latency
| |
|
    def get_frame_state(self, t: int) -> FrameState:
        """Classify frame t as speech or silence.

        Combines the frame's energy (decibel threshold + SNR vs the running
        noise average) with the model's silence-pdf scores.  Also updates
        the running noise-decibel estimate on silence frames.

        Args:
            t: absolute frame index; offset against decibel/scores caches.

        Returns:
            FrameState.kFrameStateSpeech or FrameState.kFrameStateSil.
        """
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t - self.decibel_offset]
        cur_snr = cur_decibel - self.noise_average_decibel
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            # detect_one_frame here is E2EVadModel's own method (defined
            # outside this view), not WindowDetector.detect_one_frame.
            self.detect_one_frame(frame_state, t, False)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(self.sil_pdf_ids) > 0:
            assert len(self.scores) == 1  # single batch expected
            sil_pdf_scores = [
                self.scores[0][t - self.scores_offset][sil_pdf_id]
                for sil_pdf_id in self.sil_pdf_ids
            ]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            # Speech probability is the complement of the silence pdfs.
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            self.frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
            # Model says speech; require the energy checks to agree.
            if (
                cur_snr >= self.vad_opts.snr_thres
                and cur_decibel >= self.vad_opts.decibel_thres
            ):
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            # Update the running noise level with an exponential-style
            # average over noise_frame_num_used_for_snr frames.
            if self.noise_average_decibel < -99.9:
                self.noise_average_decibel = cur_decibel
            else:
                self.noise_average_decibel = (
                    cur_decibel
                    + self.noise_average_decibel
                    * (self.vad_opts.noise_frame_num_used_for_snr - 1)
                ) / self.vad_opts.noise_frame_num_used_for_snr

        return frame_state
| |
|
    def infer_offline(
        self,
        feats: np.ndarray,
        waveform: np.ndarray,
        in_cache: Dict[str, np.ndarray] = dict(),
        is_final: bool = False,
    ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
        """Run offline VAD over one feature/waveform chunk.

        Args:
            feats: feature array, batch-first.
            waveform: raw audio used for the energy computation.
            in_cache: passed through unchanged (offline mode keeps no cache).
                NOTE(review): mutable default argument — harmless here since
                it is never mutated, but fragile.
            is_final: when True, flush remaining frames and reset all state.

        Returns:
            (segments, in_cache) where segments is a per-batch list of
            [start_ms, end_ms] pairs for each completed speech segment.
        """
        self.waveform = waveform
        self.compute_decibel()

        self.compute_scores(feats)
        if not is_final:
            self.detect_common_frames()
        else:
            self.detect_last_frames()
        segments = []
        for batch_num in range(0, feats.shape[0]):
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    # Only emit segments whose start AND end are confirmed.
                    if (
                        not self.output_data_buf[i].contain_seg_start_point
                        or not self.output_data_buf[i].contain_seg_end_point
                    ):
                        continue
                    segment = [
                        self.output_data_buf[i].start_ms,
                        self.output_data_buf[i].end_ms,
                    ]
                    segment_batch.append(segment)
                    self.output_data_buf_offset += 1
            if segment_batch:
                segments.append(segment_batch)

        if is_final:
            # End of stream: drop all detection state for the next utterance.
            self.all_reset_detection()
        return segments, in_cache
| |
|
| | def infer_online( |
| | self, |
| | feats: np.ndarray, |
| | waveform: np.ndarray, |
| | in_cache: list = None, |
| | is_final: bool = False, |
| | max_end_sil: int = 800, |
| | ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]: |
| | feats = [feats] |
| | if in_cache is None: |
| | in_cache = [] |
| |
|
| | self.max_end_sil_frame_cnt_thresh = ( |
| | max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | ) |
| | self.waveform = waveform |
| | feats.extend(in_cache) |
| | in_cache = self.compute_scores(feats) |
| | self.compute_decibel() |
| |
|
| | if is_final: |
| | self.detect_last_frames() |
| | else: |
| | self.detect_common_frames() |
| |
|
| | segments = [] |
| | |
| | for batch_num in range(0, feats[0].shape[0]): |
| | if len(self.output_data_buf) > 0: |
| | for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| | if not self.output_data_buf[i].contain_seg_start_point: |
| | continue |
| | if ( |
| | not self.next_seg |
| | and not self.output_data_buf[i].contain_seg_end_point |
| | ): |
| | continue |
| | start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1 |
| | if self.output_data_buf[i].contain_seg_end_point: |
| | end_ms = self.output_data_buf[i].end_ms |
| | self.next_seg = True |
| | self.output_data_buf_offset += 1 |
| | else: |
| | end_ms = -1 |
| | self.next_seg = False |
| | segments.append([start_ms, end_ms]) |
| |
|
| | return segments, in_cache |
| |
|
| | def get_frames_state( |
| | self, |
| | feats: np.ndarray, |
| | waveform: np.ndarray, |
| | in_cache: list = None, |
| | is_final: bool = False, |
| | max_end_sil: int = 800, |
| | ): |
| | feats = [feats] |
| | states = [] |
| | if in_cache is None: |
| | in_cache = [] |
| |
|
| | self.max_end_sil_frame_cnt_thresh = ( |
| | max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | ) |
| | self.waveform = waveform |
| | feats.extend(in_cache) |
| | in_cache = self.compute_scores(feats) |
| | self.compute_decibel() |
| |
|
| | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | return states |
| |
|
| | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | frame_state = FrameState.kFrameStateInvalid |
| | frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| | states.append(frame_state) |
| | if i == 0 and is_final: |
| | logging.info("last frame detected") |
| | self.detect_one_frame(frame_state, self.frm_cnt - 1, True) |
| | else: |
| | self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
| |
|
| | return states |
| |
|
| | def detect_common_frames(self) -> int: |
| | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | return 0 |
| | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | frame_state = FrameState.kFrameStateInvalid |
| | frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| | |
| | self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
| |
|
| | self.decibel = self.decibel[self.vad_opts.nn_eval_block_size - 1 :] |
| | self.decibel_offset = self.frm_cnt - 1 - i |
| | return 0 |
| |
|
| | def detect_last_frames(self) -> int: |
| | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | return 0 |
| | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | frame_state = FrameState.kFrameStateInvalid |
| | frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| | if i != 0: |
| | self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
| | else: |
| | self.detect_one_frame(frame_state, self.frm_cnt - 1, True) |
| |
|
| | return 0 |
| |
|
    def detect_one_frame(
        self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
    ) -> None:
        """Advance the VAD state machine by one frame.

        Maps the frame's raw classification through the sliding-window
        detector into a sil/speech change event, then updates segment
        start/end bookkeeping according to the current machine state.

        Args:
            cur_frm_state: classified state of this frame (speech or silence).
            cur_frm_idx: absolute frame index within the stream.
            is_final_frame: True when this is the last frame of the stream.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # NOTE(review): math.fabs(1.0) is a constant, so this condition is
            # always true for any fe_prior_thres < 1.0 — it looks like a
            # frame speech-probability term was replaced by a placeholder.
            # Confirm against the upstream implementation before changing.
            if math.fabs(1.0) > float(self.vad_opts.fe_prior_thres):
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = self.windows_detector.detect_one_frame(
            tmp_cur_frm_state, cur_frm_idx
        )
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        # --- silence -> speech transition -------------------------------
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            self.continous_silence_frame_count = 0
            self.pre_end_silence_detected = False

            if (
                self.vad_state_machine
                == VadStateMachine.kVadInStateStartPointNotDetected
            ):
                # Back-date the start to compensate for detector latency.
                start_frame = max(
                    self.data_buf_start_frame,
                    cur_frm_idx - self.latency_frm_num_at_start_point(),
                )
                self.on_voice_start(start_frame)
                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.on_voice_detected(t)
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # Fill in any frames skipped since the last confirmation.
                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.on_voice_detected(t)
                if (
                    cur_frm_idx - self.confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    # Force-close segments that exceed the maximum length.
                    self.on_voice_end(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.on_voice_detected(cur_frm_idx)
                else:
                    self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
            else:
                pass
        # --- speech -> silence transition -------------------------------
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            self.continous_silence_frame_count = 0
            if (
                self.vad_state_machine
                == VadStateMachine.kVadInStateStartPointNotDetected
            ):
                pass
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (
                    cur_frm_idx - self.confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.on_voice_end(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.on_voice_detected(cur_frm_idx)
                else:
                    self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
            else:
                pass
        # --- sustained speech -------------------------------------------
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (
                    cur_frm_idx - self.confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.max_time_out = True
                    self.on_voice_end(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.on_voice_detected(cur_frm_idx)
                else:
                    self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
            else:
                pass
        # --- sustained silence ------------------------------------------
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            self.continous_silence_frame_count += 1
            if (
                self.vad_state_machine
                == VadStateMachine.kVadInStateStartPointNotDetected
            ):

                # In single-utterance mode, give up after too much leading
                # silence (or on a final frame with no end detected yet) and
                # emit an empty degenerate segment at frame 0.
                if (
                    (
                        self.vad_opts.detect_mode
                        == VadDetectMode.kVadSingleUtteranceDetectMode.value
                    )
                    and (
                        self.continous_silence_frame_count * frm_shift_in_ms
                        > self.vad_opts.max_start_silence_time
                    )
                ) or (is_final_frame and self.number_end_time_detected == 0):
                    for t in range(
                        self.lastest_confirmed_silence_frame + 1, cur_frm_idx
                    ):
                        self.on_silence_detected(t)
                    self.on_voice_start(0, True)
                    self.on_voice_end(0, True, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.latency_frm_num_at_start_point():
                        self.on_silence_detected(
                            cur_frm_idx - self.latency_frm_num_at_start_point()
                        )
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if (
                    self.continous_silence_frame_count * frm_shift_in_ms
                    >= self.max_end_sil_frame_cnt_thresh
                ):
                    # Enough trailing silence: close the segment, rolling the
                    # end point back to where the silence actually began.
                    lookback_frame = int(
                        self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms
                    )
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
                        )
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.on_voice_end(cur_frm_idx - lookback_frame, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif (
                    cur_frm_idx - self.confirmed_start_frame + 1
                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
                ):
                    self.on_voice_end(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    # Within the lookahead window, keep extending the segment.
                    if self.continous_silence_frame_count <= int(
                        self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
                    ):
                        self.on_voice_detected(cur_frm_idx)
                else:
                    self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
            else:
                pass

        # In multi-utterance mode, re-arm the detector after each end point.
        if (
            self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
            and self.vad_opts.detect_mode
            == VadDetectMode.kVadMutipleUtteranceDetectMode.value
        ):
            self.reset_detection()
| |
|
| |
|
class FSMNVad(object):
    """Offline FSMN-based voice activity detector.

    Loads frontend config, CMVN stats, and the fsmnvad-offline model from
    ``config_dir`` and exposes :meth:`segments_offline` to obtain speech
    segments of a 16 kHz mono waveform.
    """

    def __init__(self, config_dir: str):
        config_dir = Path(config_dir)
        self.config = read_yaml(config_dir / "fsmn-config.yaml")
        self.frontend = WavFrontend(
            cmvn_file=config_dir / "fsmn-am.mvn",
            **self.config["WavFrontend"]["frontend_conf"],
        )
        self.config["FSMN"]["model_path"] = config_dir / "fsmnvad-offline.onnx"

        self.vad = E2EVadModel(
            self.config["FSMN"], self.config["vadPostArgs"], config_dir
        )

    def set_parameters(self, mode):
        # Placeholder for runtime parameter tuning; intentionally a no-op.
        pass

    def extract_feature(self, waveform):
        """Compute fbank + LFR/CMVN features for a mono waveform."""
        fbank, _ = self.frontend.fbank(waveform)
        feats, feats_len = self.frontend.lfr_cmvn(fbank)
        return feats.astype(np.float32), feats_len

    def is_speech(self, buf, sample_rate=16000):
        assert sample_rate == 16000, "only support 16k sample rate"

    def segments_offline(self, waveform_path: Union[str, Path, np.ndarray]):
        """Get speech segments ([start_ms, end_ms] pairs) of an audio input.

        Accepts either an in-memory waveform (np.ndarray, assumed 16 kHz)
        or a path to a 16 kHz audio file.
        """
        if isinstance(waveform_path, np.ndarray):
            waveform = waveform_path
        else:
            if not os.path.exists(waveform_path):
                # NOTE: FileExistsError kept for backward compatibility with
                # existing callers; FileNotFoundError would be conventional.
                raise FileExistsError(f"{waveform_path} is not exist.")
            if os.path.isfile(waveform_path):
                logging.info(f"load audio {waveform_path}")
                waveform, _sample_rate = sf.read(
                    waveform_path,
                    dtype="float32",
                )
            else:
                # Bug fix: the original raised FileNotFoundError(str(Path)),
                # passing the Path *class*; report the actual offending path.
                raise FileNotFoundError(str(waveform_path))
            assert (
                _sample_rate == 16000
            ), f"only support 16k sample rate, current sample rate is {_sample_rate}"

        feats, feats_len = self.extract_feature(waveform)
        waveform = waveform[None, ...]
        segments_part, in_cache = self.vad.infer_offline(
            feats[None, ...], waveform, is_final=True
        )
        return segments_part[0]
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# SenseVoice language-ID token indices; "auto" lets the model detect the language.
languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
# Log format: timestamp, level, source file:line, message.
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
| |
|
def main():
    """CLI entry point: per-channel VAD, then ASR on each speech segment."""
    arg_parser = argparse.ArgumentParser(description="Sense Voice")
    arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
    download_model_path = os.path.dirname(__file__)
    arg_parser.add_argument(
        "-dp",
        "--download_path",
        default=download_model_path,
        type=str,
        help="dir path of resource downloaded",
    )
    arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
    arg_parser.add_argument(
        "-n", "--num_threads", default=4, type=int, help="Num threads"
    )
    arg_parser.add_argument(
        "-l",
        "--language",
        choices=languages.keys(),
        default="auto",
        type=str,
        help="Language",
    )
    arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
    args = arg_parser.parse_args()

    # Bug fix: honor --download_path. It was parsed but never used; all
    # resources were always loaded from the script directory.
    resource_dir = args.download_path

    front = WavFrontend(os.path.join(resource_dir, "am.mvn"))

    model = SenseVoiceInferenceSession(
        os.path.join(resource_dir, "embedding.npy"),
        os.path.join(
            resource_dir,
            "sense-voice-encoder.rknn",
        ),
        os.path.join(resource_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
        args.device,
        args.num_threads,
    )
    waveform, _sample_rate = sf.read(
        args.audio_file,
        dtype="float32",
        always_2d=True
    )

    logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")

    start = time.time()
    vad = FSMNVad(resource_dir)
    for channel_id, channel_data in enumerate(waveform.T):
        segments = vad.segments_offline(channel_data)
        for part in segments:
            # Segment times are in ms; 16 samples per ms at 16 kHz.
            audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
            asr_result = model(
                audio_feats[None, ...],
                language=languages[args.language],
                use_itn=args.use_itn,
            )
            logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
        vad.vad.all_reset_detection()
    decoding_time = time.time() - start
    logging.info(f"Decoder audio takes {decoding_time} seconds")
    logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
| |
|
| |
|
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
| |
|
| |
|