Spaces:
Running
on
A100
Running
on
A100
| import math | |
| import os | |
| import re | |
| import tempfile | |
| import logging | |
| from dataclasses import dataclass | |
| import torch | |
| from torchaudio.models import wav2vec2_model | |
| logger = logging.getLogger(__name__) | |
# Language codes for which uroman has specialized romanization rules.
# get_uroman_tokens only forwards a language code to uroman when it is
# listed here; the list is built by splitting the literal on commas and
# trimming the surrounding whitespace of every entry.
special_isos_uroman = [
    code.strip()
    for code in (
        "ara, bel, bul, deu, ell, eng, fas, grc, ell, eng, heb, kaz, kir, lav, lit, mkd, mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid"
    ).split(",")
]
def normalize_uroman(text):
    """Reduce romanized text to lowercase letters, apostrophes and single spaces.

    The text is lowercased, every character other than ``a``-``z``, an
    apostrophe or a space is replaced by a space, runs of spaces are
    collapsed to one, and leading/trailing whitespace is stripped.
    """
    # Lowercase first so that uppercase letters survive the character filter.
    filtered = re.sub("([^a-z' ])", " ", text.lower())
    collapsed = re.sub(" +", " ", filtered)
    return collapsed.strip()
def get_uroman_tokens(norm_transcripts, uroman, iso=None):
    """Romanize transcripts with uroman and return them as spaced-out characters.

    Args:
        norm_transcripts: Sized collection of transcript strings. Each one is
            written on its own line of a temporary input file.
        uroman: Object exposing ``romanize_file(input_filename,
            output_filename, lcode)``; it is expected to write one output
            line per input line.
        iso: Optional language code. It is passed to uroman as ``lcode`` only
            when it appears in ``special_isos_uroman``; otherwise ``None`` is
            passed.

    Returns:
        A list with one string per transcript: the romanized line with a
        space inserted between every character, whitespace collapsed, and
        then cleaned with ``normalize_uroman``.

    Raises:
        AssertionError: If the number of romanized lines differs from the
            number of transcripts.
    """
    lcode = iso if iso in special_isos_uroman else None

    # The context manager guarantees both temporary files are closed (and
    # therefore removed) even when romanization fails, instead of relying on
    # garbage collection of the file objects.
    # NOTE(review): the files are re-opened by name while still held open,
    # exactly as before; the tempfile docs say this is not possible on
    # Windows with the default delete-on-close behaviour.
    # NOTE(review): files are opened with the platform default encoding, as
    # before -- confirm which encoding uroman.romanize_file uses before
    # pinning one here.
    with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
        with open(tf.name, "w") as f:
            for t in norm_transcripts:
                f.write(t + "\n")
        uroman.romanize_file(
            input_filename=tf.name,
            output_filename=tf2.name,
            lcode=lcode,
        )
        outtexts = []
        with open(tf2.name) as f:
            for line in f:
                # Joining a string with " " puts a space between every
                # character, turning each letter into its own token.
                line = " ".join(line.strip())
                line = re.sub(r"\s+", " ", line).strip()
                outtexts.append(line)

    assert len(outtexts) == len(norm_transcripts)
    return [normalize_uroman(ot) for ot in outtexts]
@dataclass
class Segment:
    """A labelled span delimited by integer ``start`` and ``end`` positions.

    The ``@dataclass`` decorator generates the ``__init__`` that lets callers
    build instances as ``Segment(label, start, end)``; without it the class
    only carries annotations and cannot be constructed with arguments.
    """

    # Token (or "<blank>") this span is labelled with.
    label: str
    # Position where the span begins.
    start: int
    # Position where the span ends. NOTE(review): merge_repeats stores the
    # index of the last repeated element here, while __repr__ prints a
    # half-open "[start, end)" range -- confirm the intended convention.
    end: int

    def __repr__(self):
        """Return ``label: [start, end)`` with both numbers right-aligned to width 5."""
        return f"{self.label}: [{self.start:5d}, {self.end:5d})"

    def length(self):
        """Return ``end - start``."""
        return self.end - self.start
def merge_repeats(path, idx_to_token_map):
    """Collapse consecutive equal entries of ``path`` into ``Segment`` objects.

    Args:
        path: Indexable sequence of token indices.
        idx_to_token_map: Mapping from a token index to its label.

    Returns:
        A list with one ``Segment`` per run of equal entries. ``start`` is the
        index of the first entry of the run and ``end`` is the index of its
        last entry (inclusive).
    """
    merged = []
    run_start = 0
    while run_start < len(path):
        # Advance run_end to the first position whose value differs.
        run_end = run_start
        while run_end < len(path) and path[run_start] == path[run_end]:
            run_end += 1
        label = idx_to_token_map[path[run_start]]
        merged.append(Segment(label, run_start, run_end - 1))
        run_start = run_end
    return merged
def time_to_frame(time):
    """Convert ``time`` to a frame index, assuming one frame every 20 ms.

    ``time`` is multiplied by the number of frames per second (1000 / 20)
    and the result is truncated toward zero with ``int``.
    """
    frame_stride_ms = 20
    frame_rate = 1000 / frame_stride_ms
    frame_position = time * frame_rate
    return int(frame_position)
def load_model_dict():
    """Load the CTC alignment model and its token dictionary from disk.

    Both files are looked up in the directory named by the ``MODELS_DIR``
    environment variable (default ``/home/user/app/models``):
    ``ctc_alignment_mling_uroman_model.pt`` holds the weights and
    ``ctc_alignment_mling_uroman_model_dict.txt`` holds one token per line.

    Returns:
        tuple: ``(model, dictionary)`` where ``model`` is the wav2vec2 model
        with the weights loaded and switched to eval mode, and ``dictionary``
        maps each stripped line of the dictionary file to its zero-based line
        number.

    Raises:
        FileNotFoundError: If the model file or the dictionary file is
            missing.
    """
    models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")
    model_path_name = os.path.join(models_dir, "ctc_alignment_mling_uroman_model.pt")
    dict_path_name = os.path.join(
        models_dir, "ctc_alignment_mling_uroman_model_dict.txt"
    )

    logger.info("Loading model from models directory...")
    if not os.path.exists(model_path_name):
        raise FileNotFoundError(f"Model file not found at {model_path_name}")
    logger.info("Model found at: %s", model_path_name)

    # Validate the dictionary before the expensive model build so that a
    # missing file is reported immediately rather than after the weights
    # have been read and loaded.
    if not os.path.exists(dict_path_name):
        raise FileNotFoundError(f"Dictionary file not found at {dict_path_name}")
    logger.info("Dictionary found at: %s", dict_path_name)

    # NOTE(review): torch.load unpickles the checkpoint; only point MODELS_DIR
    # at trusted files, and consider weights_only=True if the installed torch
    # version supports it.
    state_dict = torch.load(model_path_name, map_location="cpu")
    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=[
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model.load_state_dict(state_dict)
    model.eval()

    # Each line of the dictionary file is one token; its line number is its id.
    with open(dict_path_name) as f:
        dictionary = {line.strip(): index for index, line in enumerate(f)}
    return model, dictionary
def get_spans(tokens, segments):
    """Group aligned segments into one span per token.

    Args:
        tokens: Sequence of strings, each a space-separated list of labels.
            An empty string is allowed after a non-empty token and yields a
            one-segment span at the segment that closed the previous token.
        segments: Sequence of objects with ``label``, ``start`` and ``end``
            attributes; segments labelled ``"<blank>"`` are treated as
            silence and skipped while matching labels.

    Returns:
        A list with one list of segments per token. When the segment just
        before (after) a span is a blank, an extra blank ``Segment`` is
        prepended (appended): for the first (last) span it reaches to the
        blank's own start (end), otherwise only to the blank's midpoint.

    Raises:
        AssertionError: If a non-blank segment label does not match the next
            expected label, or if anything other than a single final blank
            segment remains once all tokens are consumed.
    """
    blank = "<blank>"

    # Pass 1: walk the segments and record, for every token, the indices of
    # the first and last segment that spell it out.
    intervals = []
    token_pos = 0
    letter_pos = 0
    first_seg = 0
    for seg_pos, segment in enumerate(segments):
        if token_pos == len(tokens):
            # All tokens consumed: only one trailing blank may remain.
            assert seg_pos == len(segments) - 1
            assert segment.label == "<blank>"
            continue
        letters = tokens[token_pos].split(" ")
        expected = letters[letter_pos]
        if segment.label == "<blank>":
            continue
        assert segment.label == expected
        if letter_pos == 0:
            first_seg = seg_pos
        if letter_pos != len(letters) - 1:
            # More labels of the current token are still to come.
            letter_pos += 1
            continue
        # Last label of the token: close its interval and move on.
        intervals.append((first_seg, seg_pos))
        letter_pos = 0
        token_pos += 1
        # Empty tokens have no labels to match; pin them to this segment.
        while token_pos < len(tokens) and len(tokens[token_pos]) == 0:
            intervals.append((seg_pos, seg_pos))
            token_pos += 1

    # Pass 2: slice out each interval and pad it with neighbouring silence.
    spans = []
    for idx, (first, last) in enumerate(intervals):
        span = segments[first : last + 1]
        if first > 0:
            before = segments[first - 1]
            if before.label == blank:
                if idx == 0:
                    pad_start = before.start
                else:
                    pad_start = int((before.start + before.end) / 2)
                span = [Segment(blank, pad_start, span[0].start)] + span
        if last + 1 < len(segments):
            after = segments[last + 1]
            if after.label == blank:
                if idx == len(intervals) - 1:
                    pad_end = after.end
                else:
                    pad_end = math.floor((after.start + after.end) / 2)
                span = span + [Segment(blank, span[-1].end, pad_end)]
        spans.append(span)
    return spans