Spaces:
Build error
Build error
| """ | |
| Module which builds embeddings for issues and pull requests | |
| The module is designed to be run from the command line and takes the following arguments: | |
| --input_filename: The name of the file containing the issues and pull requests | |
| --model_id: The name of the sentence transformer model to use | |
| issue_type (positional argument): The type of issue to embed (either "issue" or "pull") | |
| --n_issues: The number of issues to embed | |
| --update: Whether to update the existing embeddings | |
| The module saves the embeddings to a file called <issue_type>_embeddings.npy and the index to a file called | |
| embedding_index_to_<issue_type>.json | |
| The index provides a mapping from the index of the embedding to the issue or pull request number. | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
# Module-level logger, named after this module so its records can be traced
# back to the embedding script.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO so the per-issue progress messages emitted
# below are shown.
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Instantiate the sentence-transformer model named by ``model_id``.

    Args:
        model_id: Identifier passed unchanged to the ``SentenceTransformer``
            constructor.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that buffers embeddings and writes them out on exit.

    Entering the context yields a plain list; the caller appends one embedding
    per item to it. On exit the buffered embeddings are stacked into an array
    and saved with ``np.save``, and the embedding-index mapping supplied at
    construction time is dumped as JSON.

    Notes:
        * Nothing is written when the buffer is empty.
        * ``__exit__`` does not inspect the exception arguments, so whatever
          has been buffered is also saved when the ``with`` body raised; the
          exception itself is not suppressed.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        """Store the output locations and start with an empty buffer.

        Args:
            output_embedding_filename: Path the embedding array is saved to.
            output_index_filename: Path the JSON index is written to.
            update: When truthy, embeddings already stored at
                ``output_embedding_filename`` are kept and the new ones are
                appended after them.
            embedding_to_issue_index: Mapping dumped to the index file on
                exit. It is kept by reference, so entries the caller adds
                while the context is open are included.
        """
        self.update = update
        self.embedding_to_issue_index = embedding_to_issue_index
        self.embeddings = []
        self.output_index_filename = output_index_filename
        self.output_embedding_filename = output_embedding_filename

    def __enter__(self):
        """Return the list that new embeddings should be appended to."""
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Persist the buffered embeddings and the index, if there are any."""
        if not self.embeddings:
            # Nothing new was produced: leave any existing files untouched.
            return None

        embedding_path = self.output_embedding_filename
        stacked = np.array(self.embeddings)

        append_to_existing = self.update and os.path.exists(embedding_path)
        if append_to_existing:
            # Previously stored rows come first so existing indices stay valid.
            previous = np.load(embedding_path)
            stacked = np.concatenate([previous, stacked])

        logger.info(f"Saving embeddings to {embedding_path}")
        np.save(embedding_path, stacked)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)
def embed_issues(
        input_filename: str,
        model_id: str,
        issue_type: str,
        n_issues: int = -1,
        update: bool = False
):
    """Embed issues or pull requests from a JSON file and save the result.

    The input file is a JSON object mapping an issue id to an issue record.
    For every selected record the title and body are joined with a newline and
    encoded with the sentence-transformer model. Embeddings are written to
    ``<issue_type>_embeddings.npy`` and the mapping from embedding row to
    issue id is written to ``embedding_index_to_<issue_type>.json`` (both in
    the current working directory).

    Args:
        input_filename: Path of the JSON file holding the issue records.
        model_id: Name of the sentence-transformer model to load.
        issue_type: ``"pull"`` keeps only records that have a
            ``"pull_request"`` key; ``"issue"`` keeps only records without it.
        n_issues: Maximum number of records to embed in this run. Any value
            that is not positive means "no limit".
        update: When true and an index file already exists, ids found in that
            index are skipped and new embeddings are appended after the
            existing ones.
    """
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r", encoding="utf-8") as f:
            embedding_to_issue_index = json.load(f)
        # Continue numbering after the rows that are already stored.
        embedding_index = len(embedding_to_issue_index)
    else:
        embedding_to_issue_index = {}
        embedding_index = 0

    # Ids embedded by a previous run. Built once so the membership test in the
    # loop is O(1) instead of scanning every value of the index per issue.
    # Ids added during this run need not be tracked: they are keys of
    # ``issues`` and therefore unique.
    already_embedded = set(embedding_to_issue_index.values()) if update else set()

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
            output_embedding_filename=output_embedding_filename,
            output_index_filename=output_index_filename,
            update=update,
            embedding_to_issue_index=embedding_to_issue_index
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break
            if issue_id in already_embedded:
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue
            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue
            if issue_type == "pull" and "pull_request" not in issue:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue
            elif issue_type == "issue" and "pull_request" in issue:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            # A missing or null title/body is treated as empty text rather
            # than aborting the whole run.
            title = issue.get("title") or ""
            body = issue.get("body") or ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)

            # NOTE: keys loaded from an existing index are strings, while keys
            # added here are ints; JSON serialises both as strings on save.
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1
def build_arg_parser():
    """Build the command-line parser for the embedding script.

    ``issue_type`` is a positional argument. It is declared with ``nargs='?'``
    so that it may be omitted, in which case it falls back to ``'issue'``;
    without ``nargs='?'`` argparse treats a positional as required and the
    declared default is never used.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed attribute names match the
        keyword arguments of ``embed_issues``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('issue_type', nargs='?', choices=['issue', 'pull'], default='issue')
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # The namespace attributes map one-to-one onto embed_issues' parameters.
    embed_issues(**vars(args))