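"""Gradio demo: classify Russian contract documents (.docx/.pdf/.doc/.rtf)
with a fine-tuned BERT model and explain the predictions with LIME and SHAP."""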
import re
import string

import docx2txt
import fitz  # PyMuPDF
import gradio as gr
import joblib
import matplotlib.pyplot as plt
import nltk
import seaborn as sns
import shap
import textract
import torch
from lime.lime_text import LimeTextExplainer
from striprtf.striprtf import rtf_to_text
from transformers import BertForSequenceClassification, BertTokenizer, pipeline
from preprocessing import TextCleaner

cleaner = TextCleaner()

# Scikit-learn pipeline used by the LIME explainers below
pipe = joblib.load('pipe_v1_natasha.joblib')

model_path = "finetunebert"
# Padding and truncation are set per call below, so they are not needed here
tokenizer = BertTokenizer.from_pretrained(model_path)
# tokenizer.init_kwargs["model_max_length"] = 512
model = BertForSequenceClassification.from_pretrained(model_path)
document_classifier = pipeline("text-classification",
                               model=model,
                               tokenizer=tokenizer,
                               return_all_scores=True)
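
# Note: newer transformers releases deprecate return_all_scores=True in favor
# of top_k=None; kept as-is here for compatibility with the original setup.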

# Contract classes (Russian): supply, services, work-and-labour (podryad),
# lease, and sale-and-purchase contracts
classes = [
    "Договоры поставки", "Договоры оказания услуг", "Договоры подряда",
    "Договоры аренды", "Договоры купли-продажи"
]
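
# NOTE (assumption): this order must match the label indices the model was
# fine-tuned with, since softmax output i is reported as classes[i].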


def old__pipeline(text):
    """Legacy scorer kept for reference; `classifier` below supersedes it."""
    clean_text = text_preprocessing(text)
    tokens = tokenizer.batch_encode_plus([clean_text],
                                         max_length=512,
                                         padding=True,
                                         truncation=True)
    item = {k: torch.tensor(v) for k, v in tokens.items()}
    preds = model(**item).logits.detach()
    preds = torch.softmax(preds, dim=1)[0]
    return [{'label': cls, 'score': score.item()}
            for cls, score in zip(classes, preds)]


def read_doc(file_obj):
    """Read a file and return its raw text.

    :param file_obj: file object
    :return: string
    """
    return read_file(file_obj)


def read_docv2(file_obj):
    """Read a file and return context snippets around the top LIME words.

    :param file_obj: file object
    :return: string
    """
    text = read_file(file_obj)
    explainer = LimeTextExplainer(class_names=classes)
    text = cleaner.execute(text)
    exp = explainer.explain_instance(text,
                                     pipe.predict_proba,
                                     num_features=10,
                                     labels=[0, 1, 2, 3, 4])
    scores = exp.as_list()
    scores_desc = sorted(scores, key=lambda t: t[1], reverse=True)
    selected_words = [word for word, _ in scores_desc]
    sent = text.split()
    indices = [i for i, word in enumerate(sent) if word in selected_words]
    # Collect each keyword with up to three words of context on each side
    neighbors = []
    for ind in indices:
        neighbors.append(" ".join(sent[max(0, ind - 3):ind + 4]))
    return "\n\n".join(neighbors)


def classifier(file_obj):
    """Classify a document with the fine-tuned BERT model.

    :param file_obj: file object
    :return: Dict[str, float] mapping class name to probability
    """
    text = read_file(file_obj)
    clean_text = text_preprocessing(text)
    tokens = tokenizer.batch_encode_plus([clean_text],
                                         max_length=512,
                                         padding=True,
                                         truncation=True)
    item = {k: torch.tensor(v) for k, v in tokens.items()}
    preds = model(**item).logits.detach()
    preds = torch.softmax(preds, dim=1)[0]
    return {cls: p.item() for cls, p in zip(classes, preds)}
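
# Illustrative output (hypothetical scores; they are softmax probabilities):
#   {"Договоры поставки": 0.91, "Договоры оказания услуг": 0.04, ...}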


def clean_text(text):
    """Make text lowercase; remove text in square brackets, links, HTML tags,
    punctuation, newlines, and words containing numbers."""
    text = text.lower()
    text = re.sub(r'\[.*?\]', '', text)
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    text = re.sub(r'\n', '', text)
    text = re.sub(r'\w*\d\w*', '', text)
    return text


def text_preprocessing(text):
    """Clean the text and normalize whitespace via word tokenization."""
    word_tokenizer = nltk.tokenize.RegexpTokenizer(r'\w+')
    nopunc = clean_text(text)
    tokenized_text = word_tokenizer.tokenize(nopunc)
    # remove_stopwords = [w for w in tokenized_text if w not in stopwords.words('english')]
    return ' '.join(tokenized_text)
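
# Illustrative example (hypothetical input):
#   text_preprocessing("Contract No. 5 <b>supply</b> see https://example.com")
#   -> "contract no supply see"
# (the URL, the tag, punctuation, and the digit token "5" are all dropped)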


def read_file(file_obj):
    """Read a file and fix its encoding.

    Supports .docx, .pdf, .doc, and .rtf; returns "" for anything else.

    :param file_obj: file object (Gradio passes a temp file whose .name is the path)
    :return: string
    """
    if isinstance(file_obj, list):
        file_obj = file_obj[0]
    filename = file_obj.name
    if filename.endswith("docx"):
        text = docx2txt.process(filename)
    elif filename.endswith("pdf"):
        doc = fitz.open(filename)
        text = " ".join(page.get_text() for page in doc)
    elif filename.endswith("doc"):
        text = reinterpret(textract.process(filename))
        text = remove_convert_info(text)
    elif filename.endswith("rtf"):
        with open(filename) as f:
            text = rtf_to_text(f.read())
    else:
        return ""
    return text


def reinterpret(raw: bytes) -> str:
    """Decode the raw bytes returned by textract as UTF-8."""
    return raw.decode('utf8')


def remove_convert_info(text: str) -> str:
    """Strip the conversion header that textract prepends: drop everything
    up to and including the first colon plus the five characters after it."""
    for i, s in enumerate(text):
        if s == ":":
            break
    return text[i + 6:]


def plot_weights(file_obj):
    """Plot the LIME weights of the top-10 words as a bar chart."""
    text = read_file(file_obj)
    explainer = LimeTextExplainer(class_names=classes)
    text = cleaner.execute(text)
    exp = explainer.explain_instance(text,
                                     pipe.predict_proba,
                                     num_features=10,
                                     labels=[0, 1, 2, 3, 4])
    scores = exp.as_list()
    scores_desc = sorted(scores, key=lambda t: t[1], reverse=True)
    plt.rcParams.update({'font.size': 35})
    fig = plt.figure(figsize=(20, 20))
    sns.barplot(x=[s[0] for s in scores_desc[:10]],
                y=[s[1] for s in scores_desc[:10]])
    plt.title("Interpreting text predictions with LIME")
    plt.ylabel("Weight")
    plt.xlabel("Word")
    plt.xticks(rotation=20)
    plt.tight_layout()
    return fig


def interpretation_function(file_obj):
    """Token-level SHAP attributions in the format expected by
    gr.components.Interpretation."""
    text = read_file(file_obj)
    clean_text = text_preprocessing(text)
    # Truncate to roughly 500 BERT tokens so the 512-token model limit holds
    input_ids = tokenizer(clean_text, truncation=True, max_length=500)["input_ids"]
    clean_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    explainer = shap.Explainer(document_classifier)
    shap_values = explainer([clean_text])
    # Dimensions are (batch size, text size, number of classes); index 1
    # (kept from the original sentiment example) selects the second class --
    # adjust it to the class of interest
    scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
    # scores contains (token, attribution) pairs
    return {"original": clean_text, "interpretation": scores}


def as_pyplot_figure(file_obj):
    """Return LIME's built-in matplotlib explanation figure for gr.Plot."""
    text = read_file(file_obj)
    explainer = LimeTextExplainer(class_names=classes)
    text = cleaner.execute(text)
    exp = explainer.explain_instance(text,
                                     pipe.predict_proba,
                                     num_features=10,
                                     labels=[0, 1, 2, 3, 4])
    plt.rcParams.update({'font.size': 10})
    fig = exp.as_pyplot_figure()
    fig.tight_layout()
    return fig


with gr.Blocks() as demo:
    gr.Markdown("""**Document classification**""")
    with gr.Row():
        with gr.Column():
            file = gr.File(label="Input File")
            with gr.Row():
                classify = gr.Button("Classify document")
                read = gr.Button("Get text")
                interpret_lime = gr.Button("Interpret LIME")
                interpret_shap = gr.Button("Interpret SHAP")
        with gr.Column():
            label = gr.Label(label="Predicted Document Class")
            plot = gr.Plot()
        with gr.Column():
            text = gr.Text(label="Selected keywords")
        with gr.Column():
            interpretation = gr.components.Interpretation(text)

    # Wire each button to its handler; every handler takes the uploaded file
    classify.click(classifier, file, label)
    read.click(read_docv2, file, [text])
    interpret_shap.click(interpretation_function, file, interpretation)
    interpret_lime.click(as_pyplot_figure, file, plot)


if __name__ == "__main__":
    demo.launch()