DunasAnastasiia committed on
Commit 7c2e31a · 0 Parent(s)

Initial commit (Xet)
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,96 @@
+ file
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheelhouse/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Type checkers
+ .mypy_cache/
+ .pyre/
+ .pytype/
+
+ # Linting / formatting
+ .ruff_cache/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Django / Flask / FastAPI (common)
+ *.log
+ local_settings.py
+ instance/
+ .webassets-cache
+
+ # Sphinx documentation
+ docs/_build/
+
+ # Virtual environments
+ .venv/
+ venv/
+ ENV/
+ env/
+ env.bak/
+ venv.bak/
+
+ # Environment variables / secrets
+ .env
+ .env.*
+ *.env
+
+ # Editors / IDEs
+ .vscode/
+ .idea/
+ *.sublime-project
+ *.sublime-workspace
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # PyCharm
+ *.iml
README.md ADDED
@@ -0,0 +1,83 @@
+ ---
+ title: RAG QA (BM25 + Dense + Reranker)
+ emoji: 🔎
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: "6.1.0"
+ python_version: "3.10"
+ app_file: app.py
+ pinned: false
+ short_description: RAG Question Answering demo with BM25 + Dense retrieval and optional reranker.
+ ---
+
+ # RAG QA (BM25 + Dense + OpenAI-compatible providers)
+
+ This is an educational **Retrieval-Augmented Generation (RAG)** project for the **Question Answering** task: the system finds relevant passages in a document corpus and answers questions based only on the retrieved context.
+
+ ## Demo
+ - Hugging Face Space: *(link will be added after deployment)*
+ - GitHub: https://github.com/DunasAnastasiia/llm-chat-project
+
+ ---
+
+ ## How the system works (architecture)
+
+ **Data source (Dataset)**
+ - Hugging Face dataset: `rag-datasets/rag-mini-wikipedia`
+
+ **Chunking**
+ - Simple character-based chunking with overlap: `chunk_chars=900`, `overlap_chars=150` (see the sketch below)
+
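+ A small illustration of how the overlap works, using `chunk_text` from `rag/chunking.py` (the 2000-character string is just a toy input, not the real corpus):
+
+ ```python
+ from rag.chunking import chunk_text
+
+ parts = chunk_text("x" * 2000, chunk_chars=900, overlap_chars=150)
+ # The window advances by step = 900 - 150 = 750 characters, so chunks
+ # start at offsets 0, 750 and 1500: len(parts) == 3, and consecutive
+ # chunks share 150 characters.
+ ```
+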
+ **Retriever**
+ - **BM25** (lexical keyword search)
+ - **Dense retrieval** (Sentence-Transformers embeddings)
+ - BM25 and Dense can be **toggled on/off** independently (to compare quality); see the usage sketch below.
+
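+ A minimal usage sketch (it assumes the index in `artifacts/` has already been built via `python -m rag.index`):
+
+ ```python
+ from rag.retrieve import Retriever
+
+ retriever = Retriever()  # loads artifacts/chunks.jsonl and artifacts/embeddings.npy
+ records = retriever.retrieve("Who assassinated Lincoln?", use_bm25=True, use_dense=True, use_rerank=False)
+ for rec in records:
+     print(rec.why, round(rec.score, 3), rec.source_id, rec.text[:60])
+ ```
+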
+ **Reranker (optional)**
+ - Cross-encoder `cross-encoder/ms-marco-MiniLM-L-6-v2` for reranking the candidates (optional).
+
+ **LLM (Generation)**
+ - Uses the **OpenAI-compatible Chat Completions API** via the `openai` library with the `base_url` parameter; a call sketch follows below.
+ - Works with providers such as **Groq** / **OpenRouter** (the key is entered in the UI).
+
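+ A minimal call sketch (the base URL and model below are the Groq defaults from `app.py`; `YOUR_KEY` is a placeholder):
+
+ ```python
+ from rag.llm import answer_with_provider
+
+ answer = answer_with_provider(
+     api_key="YOUR_KEY",  # placeholder: paste a real provider key here
+     base_url="https://api.groq.com/openai/v1",
+     model="llama-3.1-8b-instant",
+     question="Who assassinated Lincoln?",
+     chunks=[{"chunk_id": 0, "source_id": "42", "text": "Lincoln was assassinated by John Wilkes Booth."}],
+ )
+ print(answer)
+ ```
+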
+ **Citations**
+ - The prompt asks the LLM to cite chunks as `[1] [2] ...`, and the UI shows the list of retrieved chunks with their `source_id` / `chunk_id`.
+
+ **UI**
+ - Gradio web app with `Use BM25 / Use Dense / Use Reranker` toggles and an API key field.
+
+ ## Example queries: where BM25 is better and where Dense is better
+
+ ### Queries where BM25 does better
+
+ **Q:** Was Abraham Lincoln the first President of the United States?
+ **Expected answer:** No
+ **Why BM25 is better:** the query contains very specific keywords (Abraham Lincoln, first President, United States). BM25 works well when the answer sits in a chunk that uses the same words and phrasing.
+
+ **Q:** Who was the general in charge at the Battle of Antietam?
+ **Expected answer:** General McClellan
+ **Why BM25 is better:** the query has exact "anchors" (Battle of Antietam, general in charge). BM25 usually pulls up the passage where these terms appear literally.
+
+ ---
+
+ ### Queries where the Dense retriever returns better results
+
+ **Q:** Who assassinated Lincoln?
+ **Expected answer:** John Wilkes Booth
+ **Why dense is better:** a typical "semantic" query; the answer may sit in a passage that does not repeat the exact wording of the query (e.g., "Lincoln was assassinated by ..."). Dense retrieval is better at catching paraphrases and "event ↔ participant" links.
+
+ **Q:** What caused Calvin Jr.'s death?
+ **Expected answer:** heart attack
+ **Why dense is better:** the question may not match the text verbatim (e.g., the corpus says "died of a heart attack"). Dense retrieval often finds such passages even without exact word overlap.
+
+ ---
+
+ ## Quick start (local)
+
+ ### 1) Installation
+ ```bash
+ python -m venv .venv
+ source .venv/bin/activate  # Windows: .venv\Scripts\activate
+ pip install -r requirements.txt
+ ```
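+
+ ### 2) Build the index and run
+ The same commands are shown in the `app.py` header:
+ ```bash
+ python -m rag.index   # builds artifacts/chunks.jsonl and artifacts/embeddings.npy
+ python app.py         # starts the Gradio UI on port 7860 (or $PORT)
+ ```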
app.py ADDED
@@ -0,0 +1,153 @@
+ from __future__ import annotations
+
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import gradio as gr
+
+ from rag.retrieve import Retriever
+ from rag.llm import answer_with_provider
+
+
+ def ensure_retriever(state):
+     if state is None:
+         state = Retriever()
+     return state
+
+
+ def defaults_for_provider(provider_name: str) -> tuple[str, str]:
+     """
+     Returns (base_url, default_model) for a given provider.
+     """
+     if provider_name.startswith("Groq"):
+         return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
+     if provider_name.startswith("OpenRouter"):
+         return "https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free"
+     # fallback
+     return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
+
+
+ def on_provider_change(provider_name: str):
+     base_url, model = defaults_for_provider(provider_name)
+     return base_url, model
+
+
+ def run_qa(
+     provider: str,
+     base_url: str,
+     api_key: str,
+     model: str,
+     question: str,
+     use_bm25: bool,
+     use_dense: bool,
+     use_rerank: bool,
+     state,
+ ):
+     state = ensure_retriever(state)
+
+     if not question or not question.strip():
+         return "Write a question 🙂", "", state
+
+     # Retrieval toggles
+     chunks = state.retrieve(
+         question,
+         use_bm25=use_bm25,
+         use_dense=use_dense,
+         use_rerank=use_rerank,
+     )
+
+     # Show retrieved context
+     ctx = []
+     for i, c in enumerate(chunks, start=1):
+         ctx.append(
+             f"[{i}] ({c.why}, score={c.score:.4f}) source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
+         )
+     ctx_text = "\n\n---\n\n".join(ctx) if ctx else "(nothing retrieved)"
+
+     # If both retrievers are off => "no retrieval" mode
+     if not use_bm25 and not use_dense:
+         ctx_text = "(retrieval is OFF: the model will answer without any context)"
+         chunks_for_llm = []
+     else:
+         chunks_for_llm = [{"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text} for c in chunks]
+
+     if not api_key or not api_key.strip():
+         return f"Paste your {provider} API key first.", ctx_text, state
+
+     # Provider call (OpenAI-compatible Chat Completions)
+     try:
+         ans = answer_with_provider(
+             api_key=api_key.strip(),
+             base_url=(base_url or "").strip(),
+             model=(model or "").strip(),
+             question=question,
+             chunks=chunks_for_llm,
+         )
+     except Exception as e:
+         return f"LLM error: {type(e).__name__}: {e}", ctx_text, state
+
+     return ans, ctx_text, state
+
+
+ with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
+     gr.Markdown(
+         "# RAG QA (HF dataset + BM25 + Dense)\n"
+         "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
+         "1) Build index: `python -m rag.index`\n"
+         "2) Run UI: `python app.py`\n"
+     )
+
+     state = gr.State(None)
+
+     provider = gr.Dropdown(
+         ["Groq (free tier)", "OpenRouter (free models)"],
+         value="Groq (free tier)",
+         label="Provider",
+     )
+
+     base_url = gr.Textbox(
+         label="Base URL",
+         value="https://api.groq.com/openai/v1",
+         placeholder="https://api.groq.com/openai/v1",
+     )
+
+     api_key = gr.Textbox(
+         label="API key",
+         type="password",
+         placeholder="paste provider key here",
+     )
+
+     model = gr.Textbox(
+         label="Model",
+         value="llama-3.1-8b-instant",
+     )
+
+     provider.change(
+         fn=on_provider_change,
+         inputs=[provider],
+         outputs=[base_url, model],
+     )
+
+     question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)
+
+     with gr.Row():
+         use_bm25 = gr.Checkbox(value=True, label="Use BM25")
+         use_dense = gr.Checkbox(value=True, label="Use Dense")
+         use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")
+
+     btn = gr.Button("Answer")
+
+     answer = gr.Textbox(label="Answer", lines=8)
+     context = gr.Textbox(label="Retrieved chunks", lines=12)
+
+     btn.click(
+         fn=run_qa,
+         inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
+         outputs=[answer, context, state],
+     )
+
+
+ if __name__ == "__main__":
+     # os is already imported at the top of the module
+     port = int(os.getenv("PORT", "7860"))
+     demo.launch(server_name="0.0.0.0", server_port=port, share=False)
artifacts/chunks.jsonl ADDED
The diff for this file is too large to render.
 
artifacts/embeddings.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07ccc5bacea16e81d6a77b1971bbc03110547b7b74e05cf1ff74388dc6b9d940
+ size 5689472
rag/__init__.py ADDED
File without changes
rag/chunking.py ADDED
@@ -0,0 +1,33 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class Chunk:
+     chunk_id: int
+     source_id: str
+     text: str
+
+
+ def chunk_text(text: str, chunk_chars: int, overlap_chars: int) -> list[str]:
+     """
+     Simple character-based chunking with overlap.
+     Works for any text without requiring tokenizers.
+     """
+     text = (text or "").strip()
+     if not text:
+         return []
+     if chunk_chars <= 0:
+         return [text]
+
+     out: list[str] = []
+     i = 0
+     n = len(text)
+     step = max(1, chunk_chars - max(0, overlap_chars))
+     while i < n:
+         chunk = text[i : i + chunk_chars].strip()
+         if chunk:
+             out.append(chunk)
+         i += step
+     return out
rag/config.py ADDED
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class Settings:
+     # Hugging Face dataset
+     dataset_name: str = "rag-datasets/rag-mini-wikipedia"
+     corpus_config: str = "text-corpus"
+     qa_config: str = "question-answer"
+
+     # Chunking
+     chunk_chars: int = 900
+     overlap_chars: int = 150
+
+     # Retrieval
+     top_k_bm25: int = 8
+     top_k_dense: int = 8
+     top_k_final: int = 6
+
+     # Dense model
+     embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"
+
+     # Optional reranker
+     rerank_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+     rerank_top_n: int = 20  # candidates to rerank
+
+     # OpenAI
+     default_openai_model: str = "gpt-4o-mini"
+
+     # Artifacts
+     artifacts_dir: str = "artifacts"
+     chunks_jsonl: str = "chunks.jsonl"
+     embeddings_npy: str = "embeddings.npy"
+
+
+ SETTINGS = Settings()
rag/data.py ADDED
@@ -0,0 +1,26 @@
+ from __future__ import annotations
+
+ from typing import Iterable, Any
+
+
+ def pick_text_column(columns: list[str]) -> str:
+     """
+     Try to robustly choose the text field from various corpus schemas.
+     """
+     candidates = ["text", "content", "document", "passage", "passages", "contents", "wiki_text"]
+     for c in candidates:
+         if c in columns:
+             return c
+     return columns[0]
+
+
+ def pick_id_column(columns: list[str]) -> str | None:
+     for c in ["id", "doc_id", "document_id", "passage_id", "pid"]:
+         if c in columns:
+             return c
+     return None
+
+
+ def iter_corpus_rows(ds) -> Iterable[dict[str, Any]]:
+     for row in ds:
+         yield row
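+
+
+ # Illustrative examples (comments only, assuming a corpus with these columns):
+ #   pick_text_column(["id", "passage"]) -> "passage"
+ #   pick_id_column(["id", "passage"])   -> "id"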
rag/index.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ from tqdm import tqdm
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+
+ from rag.config import SETTINGS
+ from rag.chunking import chunk_text
+ from rag.data import pick_text_column, pick_id_column
+
+
+ def _load_corpus_split():
+     ds = load_dataset(SETTINGS.dataset_name, SETTINGS.corpus_config)
+
+     if hasattr(ds, "keys"):
+         if "train" in ds:
+             return ds["train"]
+         if "passages" in ds:
+             return ds["passages"]
+         if "validation" in ds:
+             return ds["validation"]
+         if "test" in ds:
+             return ds["test"]
+         return ds[list(ds.keys())[0]]
+
+     return ds
+
+
+ def build_index(limit_docs: int | None = None) -> None:
+     art = Path(SETTINGS.artifacts_dir)
+     art.mkdir(parents=True, exist_ok=True)
+     chunks_path = art / SETTINGS.chunks_jsonl
+     emb_path = art / SETTINGS.embeddings_npy
+
+     # Load corpus (no trust_remote_code; modern datasets removed/limited it)
+     corpus = _load_corpus_split()
+
+     cols = list(corpus.column_names)
+     text_col = pick_text_column(cols)
+     id_col = pick_id_column(cols)
+
+     chunks: list[dict] = []
+     chunk_id = 0
+
+     n_docs = len(corpus) if limit_docs is None else min(limit_docs, len(corpus))
+
+     for idx in tqdm(range(n_docs), desc="Chunking corpus"):
+         row = corpus[int(idx)]
+         raw = row.get(text_col, "")
+         source_id = str(row.get(id_col, idx)) if id_col else str(idx)
+
+         for part in chunk_text(raw, SETTINGS.chunk_chars, SETTINGS.overlap_chars):
+             chunks.append({"chunk_id": chunk_id, "source_id": source_id, "text": part})
+             chunk_id += 1
+
+     # Save chunks
+     with chunks_path.open("w", encoding="utf-8") as f:
+         for ch in chunks:
+             f.write(json.dumps(ch, ensure_ascii=False) + "\n")
+
+     # Compute embeddings
+     model = SentenceTransformer(SETTINGS.embed_model)
+     texts = [c["text"] for c in chunks]
+     emb = model.encode(
+         texts,
+         batch_size=64,
+         show_progress_bar=True,
+         normalize_embeddings=True,
+     )
+     emb = np.asarray(emb, dtype=np.float32)
+     np.save(emb_path, emb)
+
+     print(f"Saved {len(chunks)} chunks -> {chunks_path}")
+     print(f"Saved embeddings shape={emb.shape} -> {emb_path}")
+
+
+ if __name__ == "__main__":
+     build_index()
rag/llm.py ADDED
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ from openai import OpenAI
+
+
+ def build_prompt(question: str, chunks: list[dict]) -> tuple[str, str]:
+     instructions = (
+         "You are a QA assistant. Answer ONLY using the provided context.\n"
+         "If the answer is not in the context, say you don't know based on the context.\n"
+         "When you use information from a chunk, cite it like [1], [2] matching the chunk numbers.\n"
+         "Be concise."
+     )
+
+     ctx_lines = []
+     for i, ch in enumerate(chunks, start=1):
+         ctx_lines.append(
+             f"[{i}] source_id={ch['source_id']} chunk_id={ch['chunk_id']}\n{ch['text']}\n"
+         )
+
+     input_text = (
+         "CONTEXT:\n"
+         + "\n".join(ctx_lines)
+         + "\nQUESTION:\n"
+         + question.strip()
+         + "\n\nANSWER:"
+     )
+     return instructions, input_text
+
+
+ def answer_with_provider(
+     api_key: str,
+     base_url: str,
+     model: str,
+     question: str,
+     chunks: list[dict],
+ ) -> str:
+     """
+     Works with OpenAI-compatible providers (Groq, OpenRouter, Together, etc.)
+     via the Chat Completions API.
+     """
+     client = OpenAI(api_key=api_key, base_url=base_url)
+     instructions, input_text = build_prompt(question, chunks)
+
+     resp = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": instructions},
+             {"role": "user", "content": input_text},
+         ],
+         temperature=0.2,
+     )
+
+     msg = resp.choices[0].message.content
+     return msg or ""
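+
+
+ # Illustrative example of the prompt layout produced by build_prompt
+ # (comments only; "Who won?" and the chunk are made-up inputs):
+ #   build_prompt("Who won?", [{"chunk_id": 0, "source_id": "7", "text": "A won."}])
+ #   -> user message:
+ #      CONTEXT:
+ #      [1] source_id=7 chunk_id=0
+ #      A won.
+ #
+ #      QUESTION:
+ #      Who won?
+ #
+ #      ANSWER: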
rag/retrieve.py ADDED
@@ -0,0 +1,129 @@
+ from __future__ import annotations
+
+ import json
+ import re
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+
+ from rag.config import SETTINGS
+
+ _WORD = re.compile(r"[A-Za-z0-9']+")
+
+
+ def tokenize(text: str) -> list[str]:
+     return _WORD.findall((text or "").lower())
+
+
+ @dataclass
+ class ChunkRec:
+     chunk_id: int
+     source_id: str
+     text: str
+     score: float
+     why: str  # "bm25", "dense", "rerank"
+
+
+ class Retriever:
+     def __init__(self) -> None:
+         art = Path(SETTINGS.artifacts_dir)
+         self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
+         self.emb = np.load(art / SETTINGS.embeddings_npy)
+
+         # BM25
+         tokenized = [tokenize(c["text"]) for c in self.chunks]
+         self.bm25 = BM25Okapi(tokenized)
+
+         # Dense encoder
+         self.embedder = SentenceTransformer(SETTINGS.embed_model)
+
+         # Reranker (lazy)
+         self._reranker: CrossEncoder | None = None
+
+     @staticmethod
+     def _load_chunks(path: Path) -> list[dict]:
+         out = []
+         with path.open("r", encoding="utf-8") as f:
+             for line in f:
+                 out.append(json.loads(line))
+         return out
+
+     def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
+         scores = self.bm25.get_scores(tokenize(query))
+         idx = np.argsort(scores)[::-1][:k]
+         out: list[ChunkRec] = []
+         for i in idx:
+             c = self.chunks[int(i)]
+             out.append(
+                 ChunkRec(
+                     c["chunk_id"],
+                     c["source_id"],
+                     c["text"],
+                     float(scores[int(i)]),
+                     "bm25",
+                 )
+             )
+         return out
+
+     def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
+         q = self.embedder.encode([query], normalize_embeddings=True)
+         q = np.asarray(q, dtype=np.float32)[0]
+         # Cosine similarity: the embeddings are L2-normalized
+         scores = self.emb @ q
+         idx = np.argsort(scores)[::-1][:k]
+         out: list[ChunkRec] = []
+         for i in idx:
+             c = self.chunks[int(i)]
+             out.append(
+                 ChunkRec(
+                     c["chunk_id"],
+                     c["source_id"],
+                     c["text"],
+                     float(scores[int(i)]),
+                     "dense",
+                 )
+             )
+         return out
+
+     def _get_reranker(self) -> CrossEncoder:
+         if self._reranker is None:
+             self._reranker = CrossEncoder(SETTINGS.rerank_model)
+         return self._reranker
+
+     def retrieve(
+         self,
+         query: str,
+         use_bm25: bool = True,
+         use_dense: bool = True,
+         use_rerank: bool = False,
+     ) -> list[ChunkRec]:
+         cands: list[ChunkRec] = []
+         if use_bm25:
+             cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
+         if use_dense:
+             cands.extend(self._dense_search(query, SETTINGS.top_k_dense))
+
+         # De-dup by chunk_id, keeping the best score per chunk.
+         # NOTE: BM25 and cosine scores live on different scales, so this merge
+         # gives only a rough ordering; the optional cross-encoder rerank below
+         # re-scores candidates on a single scale.
+         best: dict[int, ChunkRec] = {}
+         for r in cands:
+             prev = best.get(r.chunk_id)
+             if prev is None or r.score > prev.score:
+                 best[r.chunk_id] = r
+         merged = list(best.values())
+         merged.sort(key=lambda x: x.score, reverse=True)
+
+         if use_rerank and merged:
+             reranker = self._get_reranker()
+             top = merged[: SETTINGS.rerank_top_n]
+             pairs = [(query, r.text) for r in top]
+             rr_scores = reranker.predict(pairs)
+             for r, s in zip(top, rr_scores):
+                 r.score = float(s)
+                 r.why = "rerank"
+             top.sort(key=lambda x: x.score, reverse=True)
+             return top[: SETTINGS.top_k_final]
+
+         return merged[: SETTINGS.top_k_final]
requirements.txt ADDED
@@ -0,0 +1,87 @@
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.2
+ aiosignal==1.4.0
+ annotated-doc==0.0.4
+ annotated-types==0.7.0
+ anyio==4.12.0
+ attrs==25.4.0
+ audioop-lts==0.2.2
+ brotli==1.2.0
+ certifi==2025.11.12
+ charset-normalizer==3.4.4
+ click==8.3.1
+ datasets==4.4.1
+ dill==0.4.0
+ distro==1.9.0
+ fastapi==0.124.4
+ ffmpy==1.0.0
+ filelock==3.20.0
+ frozenlist==1.8.0
+ fsspec==2025.10.0
+ gradio==6.1.0
+ gradio_client==2.0.1
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.2.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.36.0
+ idna==3.11
+ Jinja2==3.1.6
+ jiter==0.12.0
+ joblib==1.5.2
+ markdown-it-py==4.0.0
+ MarkupSafe==3.0.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.7.0
+ multiprocess==0.70.18
+ networkx==3.6.1
+ numpy==2.3.5
+ openai==2.11.0
+ orjson==3.11.5
+ packaging==25.0
+ pandas==2.3.3
+ pillow==12.0.0
+ propcache==0.4.1
+ pyarrow==22.0.0
+ pydantic==2.12.4
+ pydantic_core==2.41.5
+ pydub==0.25.1
+ Pygments==2.19.2
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.3
+ rank-bm25==0.2.2
+ regex==2025.11.3
+ requests==2.32.5
+ rich==14.2.0
+ safehttpx==0.1.7
+ safetensors==0.7.0
+ scikit-learn==1.8.0
+ scipy==1.16.3
+ semantic-version==2.10.0
+ sentence-transformers==5.2.0
+ setuptools==80.9.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.50.0
+ sympy==1.14.0
+ threadpoolctl==3.6.0
+ tokenizers==0.22.1
+ tomlkit==0.13.3
+ torch==2.9.1
+ tqdm==4.67.1
+ transformers==4.57.3
+ typer==0.20.0
+ typer-slim==0.20.0
+ typing-inspection==0.4.2
+ typing_extensions==4.15.0
+ tzdata==2025.3
+ urllib3==2.6.2
+ uvicorn==0.38.0
+ xxhash==3.6.0
+ yarl==1.22.0