ianshank Claude committed on
Commit cabd409 · 1 Parent(s): bb930ab

fix: add missing feature_extractor.py module


CRITICAL: Feature extractor was missing from Space deployment

- Caused ModuleNotFoundError on import
- App now has all required modules
- Should start successfully

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

src/agents/meta_controller/feature_extractor.py ADDED
@@ -0,0 +1,206 @@
+"""
+Feature Extractor for Meta-Controller.
+
+Replaces simple heuristic-based feature engineering with semantic embeddings.
+Uses sentence-transformers for local embedding generation or OpenAI if configured.
+"""
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+from sentence_transformers import SentenceTransformer, util
+
+from src.agents.meta_controller.base import MetaControllerFeatures
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FeatureExtractorConfig:
+    """Configuration for FeatureExtractor."""
+    model_name: str = "all-MiniLM-L6-v2"
+    device: str = "cpu"
+
+    @classmethod
+    def from_env(cls) -> "FeatureExtractorConfig":
+        """Load configuration from environment variables."""
+        return cls(
+            model_name=os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2"),
+            device=os.getenv("DEVICE", "cpu"),
+        )
+
+
+class FeatureExtractor:
+    """
+    Extracts semantic features from queries using embeddings.
+
+    Uses a pre-trained embedding model to map queries to a vector space,
+    then calculates similarity scores against agent prototypes to estimate
+    routing confidence.
+    """
+
+    # Agent prototypes - descriptions of what each agent is good at
+    AGENT_PROTOTYPES = {
+        "hrm": [
+            "complex problem decomposition",
+            "hierarchical reasoning",
+            "breaking down multiple questions",
+            "multi-step planning",
+            "structured analysis",
+        ],
+        "trm": [
+            "iterative refinement",
+            "improving an answer",
+            "comparison and contrast",
+            "fixing code or text",
+            "polishing content",
+        ],
+        "mcts": [
+            "optimization problem",
+            "strategic search",
+            "finding the best path",
+            "exploring alternatives",
+            "decision making under uncertainty",
+        ],
+    }
+
+    def __init__(self, config: FeatureExtractorConfig | None = None):
+        """
+        Initialize the feature extractor.
+
+        Args:
+            config: Configuration object
+        """
+        if config is None:
+            config = FeatureExtractorConfig()
+
+        self.config = config
+
+        try:
+            logger.info(f"Loading embedding model: {config.model_name}")
+            self.model = SentenceTransformer(config.model_name, device=config.device)
+            self.embedding_dim = self.model.get_sentence_embedding_dimension()
+
+            # Pre-compute prototype embeddings
+            self.prototype_embeddings = {}
+            for agent, descriptions in self.AGENT_PROTOTYPES.items():
+                self.prototype_embeddings[agent] = self.model.encode(descriptions)
+
+            logger.info("FeatureExtractor initialized successfully")
+        except Exception as e:
+            logger.error(f"Failed to initialize FeatureExtractor: {e}")
+            # Fall back to simple heuristics in extract_features
+            self.model = None
+
+    def extract_features(self, query: str, iteration: int = 0, last_agent: str = "none") -> MetaControllerFeatures:
+        """
+        Extract features from a query using semantic analysis.
+
+        Args:
+            query: The input query text
+            iteration: Current iteration number
+            last_agent: Name of the last agent used
+
+        Returns:
+            MetaControllerFeatures object populated with semantic scores
+        """
+        query_length = len(query)
+
+        if self.model is None:
+            # Fallback to heuristics if the model failed to load
+            return self._heuristic_fallback(query, iteration, last_agent)
+
+        try:
+            # Generate query embedding
+            query_embedding = self.model.encode(query)
+
+            # Calculate similarity to each agent's prototypes
+            scores = {}
+            for agent, proto_embeddings in self.prototype_embeddings.items():
+                # Cosine similarity between the query and all prototypes for this agent
+                similarities = util.cos_sim(query_embedding, proto_embeddings)[0]
+                # Take the maximum similarity as the score for this agent
+                scores[agent] = float(similarities.max())
+
+            # Clamp negative cosine similarities to zero so scores lie in [0, 1]
+            hrm_conf = max(0.0, scores.get("hrm", 0.0))
+            trm_conf = max(0.0, scores.get("trm", 0.0))
+            mcts_conf = max(0.0, scores.get("mcts", 0.0))
+
+            # Normalize the three scores so they sum to 1; use a roughly
+            # uniform distribution if every score is zero
+            confs = np.array([hrm_conf, trm_conf, mcts_conf])
+            if confs.sum() > 0:
+                confs = confs / confs.sum()
+            else:
+                confs = np.array([0.33, 0.33, 0.33])
+
+            hrm_confidence = float(confs[0])
+            trm_confidence = float(confs[1])
+            mcts_value = float(confs[2])
+
+            # Consensus: ratio of the weakest to the strongest confidence
+            max_conf = max(hrm_confidence, trm_confidence, mcts_value)
+            min_conf = min(hrm_confidence, trm_confidence, mcts_value)
+            consensus_score = min_conf / max_conf if max_conf > 0 else 0.0
+
+            # Additional features
+            has_technical = any(w in query.lower() for w in ["code", "function", "api", "error", "bug"])
+
+            return MetaControllerFeatures(
+                hrm_confidence=hrm_confidence,
+                trm_confidence=trm_confidence,
+                mcts_value=mcts_value,
+                consensus_score=consensus_score,
+                last_agent=last_agent,
+                iteration=iteration,
+                query_length=query_length,
+                has_rag_context=query_length > 50,  # Simple proxy
+                rag_relevance_score=0.0,  # Placeholder
+                is_technical_query=has_technical,
+            )
+
+        except Exception as e:
+            logger.error(f"Error extracting features: {e}")
+            return self._heuristic_fallback(query, iteration, last_agent)
+
+    def _heuristic_fallback(self, query: str, iteration: int, last_agent: str) -> MetaControllerFeatures:
+        """Fallback to simple string heuristics if embedding fails."""
+        # Simple heuristics (copied/adapted from the original app.py)
+        has_multiple_questions = "?" in query and query.count("?") > 1
+        has_comparison = any(word in query.lower() for word in ["vs", "versus", "compare", "difference"])
+        has_optimization = any(word in query.lower() for word in ["optimize", "best", "improve", "maximize"])
+        has_technical = any(word in query.lower() for word in ["algorithm", "code", "implement", "technical"])
+
+        hrm_confidence = 0.5 + (0.3 if has_multiple_questions else 0) + (0.1 if has_technical else 0)
+        trm_confidence = 0.5 + (0.3 if has_comparison else 0) + (0.1 if len(query) > 100 else 0)
+        mcts_confidence = 0.5 + (0.3 if has_optimization else 0) + (0.1 if has_technical else 0)
+
+        total = hrm_confidence + trm_confidence + mcts_confidence
+        if total == 0:
+            hrm_confidence = trm_confidence = mcts_confidence = 1.0 / 3.0
+        else:
+            hrm_confidence /= total
+            trm_confidence /= total
+            mcts_confidence /= total
+
+        max_conf = max(hrm_confidence, trm_confidence, mcts_confidence)
+        consensus_score = min(hrm_confidence, trm_confidence, mcts_confidence) / max_conf if max_conf > 0 else 0.0
+
+        return MetaControllerFeatures(
+            hrm_confidence=hrm_confidence,
+            trm_confidence=trm_confidence,
+            mcts_value=mcts_confidence,
+            consensus_score=consensus_score,
+            last_agent=last_agent,
+            iteration=iteration,
+            query_length=len(query),
+            has_rag_context=len(query) > 50,
+            rag_relevance_score=0.0,
+            is_technical_query=has_technical,
+        )