konieshadow committed
Commit 642af4d · Parent: 924aa01

Fix LLM invocation issue
examples/simple_llm.py CHANGED
```diff
@@ -15,17 +15,15 @@ if __name__ == "__main__":
     try:
         # model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
         model_name = "google/gemma-3-4b-it"
-        use_4bit_quantization = False
-        device = "mps"
+        device = "cuda"
 
         # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
         # Alternatively, if you have a smaller, faster model, try e.g. "mlx-community/gemma-2b-it-8bit"
         if model_name.startswith("mlx-community"):
             gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
         else:
-            # If the device is mps, use float32 for better stability
-            dtype_to_use = torch.float32 if device == "mps" else torch.float16
-            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device, torch_dtype=dtype_to_use)
+            # If the device is mps or cuda, use float32 for better stability
+            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, device=device)
 
         print("\n--- Example 1: simple user query ---")
         messages_example1 = [
```
 
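The example now hardcodes `device = "cuda"`, which will fail on machines without an NVIDIA GPU. A minimal sketch of selecting the device at runtime instead; the `pick_device` helper is hypothetical and not part of this repo:

```python
import torch

def pick_device() -> str:
    """Hypothetical helper: prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()  # instead of hardcoding device = "cuda"
```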
src/podcast_transcribe/llm/llm_base.py CHANGED
```diff
@@ -174,45 +174,16 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
     def __init__(
         self,
         model_name: str,
-        use_4bit_quantization: bool = False,
         device_map: Optional[str] = None,
         device: Optional[str] = None,
-        trust_remote_code: bool = True,
-        torch_dtype: Optional[torch.dtype] = None
     ):
         super().__init__(model_name)
-        self.use_4bit_quantization = use_4bit_quantization
         self.device_map = device_map
-        self.trust_remote_code = trust_remote_code
-        self.torch_dtype = torch_dtype or torch.float16
         self.device = device
 
         # Load the model and tokenizer
         self._load_model_and_tokenizer()
 
-    def _get_quantization_config(self):
-        """Get the quantization config"""
-        if not self.use_4bit_quantization:
-            return None
-
-        if self.device and self.device.type == "mps":
-            print("Warning: MPS devices do not support 4-bit quantization; disabling quantization")
-            self.use_4bit_quantization = False
-            return None
-
-        # Import the quantization config
-        try:
-            from transformers import BitsAndBytesConfig
-        except ImportError:
-            raise ImportError("Please install the bitsandbytes library first: pip install bitsandbytes")
-
-        return BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=self.torch_dtype,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True,
-        )
-
     def _load_tokenizer(self):
         """Load the tokenizer"""
         try:
@@ -222,7 +193,7 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model_name,
-            trust_remote_code=self.trust_remote_code
+            trust_remote_code=True
         )
 
         # Set pad_token if it does not exist
@@ -237,22 +208,14 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             raise ImportError("Please install the transformers library first: pip install transformers")
 
         print(f"Loading model: {self.model_name}")
-        print(f"4-bit quantization: {'enabled' if self.use_4bit_quantization else 'disabled'}")
         print(f"Target device: {self.device}")
         print(f"Device map: {self.device_map}")
 
         # Configure model loading arguments
         model_kwargs = {
-            "trust_remote_code": self.trust_remote_code,
-            "torch_dtype": self.torch_dtype,
+            "trust_remote_code": True,
         }
 
-        # Handle the quantization config
-        quantization_config = self._get_quantization_config()
-        if quantization_config:
-            model_kwargs["quantization_config"] = quantization_config
-            print(f"Using 4-bit quantization config")
-
         # Handle device mapping
         if self.device_map is not None:
             if self.device and self.device.type == "mps":
@@ -267,10 +230,9 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         )
 
         # MPS or manual device management
-        if self.device_map is None or (self.device and self.device.type == "mps"):
-            if not self.use_4bit_quantization:
-                print(f"Manually moving model to device: {self.device}")
-                self.model = self.model.to(self.device)
+        if self.device_map is None:
+            print(f"Manually moving model to device: {self.device}")
+            self.model = self.model.to(self.device)
 
         print(f"Model {self.model_name} loaded successfully")
 
@@ -287,11 +249,8 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
     def _print_error_hints(self):
         """Print error hint messages"""
         print("Please make sure the model name is correct and accessible.")
-        if self.use_4bit_quantization:
-            print("If using quantization, make sure the bitsandbytes library is installed: pip install bitsandbytes")
-        if self.device and self.device.type == "mps":
+        if self.device and self.device == "mps":
             print("Notes for MPS devices:")
-            print("- 4-bit quantization is not supported")
             print("- device_map is not supported")
             print("- Make sure your PyTorch build supports MPS")
 
@@ -352,12 +311,10 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         """Get model info"""
         model_info = {
             "model_name": self.model_name,
-            "use_4bit_quantization": self.use_4bit_quantization,
             "device": str(self.device),
             "device_type": self.device.type,
             "device_map": self.device_map,
             "model_type": "transformers",
-            "torch_dtype": str(self.torch_dtype),
             "mps_available": torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False,
             "cuda_available": torch.cuda.is_available(),
         }
```
 
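Stripped of the class plumbing, the post-commit loading path reduces to roughly the following sketch: `trust_remote_code` is now always `True`, no dtype or quantization config is passed, and the model is moved to the target device manually whenever `device_map` is unset. The `AutoModelForCausalLM` class and the device choice here are assumptions; the diff does not show which model class the repo uses:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-3-4b-it"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trust_remote_code is now hardcoded to True for both tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# No torch_dtype or quantization_config: transformers falls back to its defaults.
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
# With device_map unset, the model is moved to the target device manually.
model = model.to(device)
```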
src/podcast_transcribe/llm/llm_gemma_transfomers.py CHANGED
```diff
@@ -10,20 +10,14 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
     def __init__(
         self,
         model_name: str = "google/gemma-3-4b-it",
-        use_4bit_quantization: bool = False,
         device_map: Optional[str] = None,
         device: Optional[str] = None,
-        trust_remote_code: bool = True,
-        torch_dtype: Optional[torch.dtype] = None
     ):
         # Gemma uses float16 as its default dtype
         super().__init__(
             model_name=model_name,
-            use_4bit_quantization=use_4bit_quantization,
             device_map=device_map,
             device=device,
-            trust_remote_code=trust_remote_code,
-            torch_dtype=torch_dtype if torch_dtype is not None else torch.float16
         )
 
     def _print_error_hints(self):
@@ -38,7 +32,6 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
 # For backward compatibility, a simplified factory function is also provided
 def create_gemma_transformers_client(
     model_name: str = "google/gemma-3-4b-it",
-    use_4bit_quantization: bool = False,
     device: Optional[str] = None,
     **kwargs
 ) -> GemmaTransformersChatCompletion:
@@ -47,7 +40,6 @@ def create_gemma_transformers_client(
 
     Args:
         model_name: model name
-        use_4bit_quantization: whether to use 4-bit quantization
         device: target device ("cpu", "cuda", "mps", etc.)
         **kwargs: other arguments passed through to the constructor
 
@@ -56,7 +48,6 @@ def create_gemma_transformers_client(
     """
     return GemmaTransformersChatCompletion(
         model_name=model_name,
-        use_4bit_quantization=use_4bit_quantization,
         device=device,
         **kwargs
     )
```
 
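For reference, constructing a client through the slimmed-down factory now takes only the remaining parameters (a usage sketch; the argument values are illustrative):

```python
# Only model_name, device, and pass-through **kwargs remain after this commit;
# use_4bit_quantization is no longer accepted.
client = create_gemma_transformers_client(
    model_name="google/gemma-3-4b-it",
    device="cuda",
)
```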
src/podcast_transcribe/llm/llm_router.py CHANGED
```diff
@@ -38,8 +38,7 @@ class LLMRouter:
             "class_name": "GemmaTransformersChatCompletion",
             "default_model": "google/gemma-3-4b-it",
             "supported_params": [
-                "model_name", "use_4bit_quantization", "device_map",
-                "device", "trust_remote_code", "torch_dtype"
+                "model_name", "device_map",
             ],
             "description": "Gemma chat completion implementation based on the Transformers library"
         }
@@ -191,7 +190,7 @@ class LLMRouter:
             max_tokens: maximum number of tokens to generate
             top_p: nucleus sampling parameter
             model: optional model name; overrides the default model_name if provided
-            **kwargs: other parameters, such as device and use_4bit_quantization
+            **kwargs: other parameters, such as device
 
         Returns:
             Chat completion response dict
@@ -207,12 +206,6 @@ class LLMRouter:
         if model is not None:
             kwargs["model_name"] = model
 
-        # If the device is mps and this is a transformers provider, force float32
-        current_device = kwargs.get("device")
-        if current_device == "mps":
-            if provider == "gemma-transformers":
-                kwargs["torch_dtype"] = torch.float32
-
         # Get or create the LLM instance
         llm_instance = self._get_or_create_instance(provider, **kwargs)
 
@@ -271,12 +264,6 @@ class LLMRouter:
         if model is not None:
             kwargs["model_name"] = model
 
-        # If the device is mps and this is a transformers provider, force float32
-        current_device = kwargs.get("device")
-        if current_device == "mps":
-            if provider == "gemma-transformers":
-                kwargs["torch_dtype"] = torch.float32
-
         # Get or create the LLM instance
         llm_instance = self._get_or_create_instance(provider, **kwargs)
 
@@ -378,9 +365,7 @@ def chat_completion(
     top_p: float = 1.0,
     model: Optional[str] = None,
     device: Optional[str] = None,
-    use_4bit_quantization: bool = False,
     device_map: Optional[str] = None,
-    trust_remote_code: bool = True,
     **kwargs
 ) -> Dict[str, Any]:
     """
@@ -396,9 +381,7 @@ def chat_completion(
         top_p: nucleus sampling parameter (0.0-1.0)
         model: model name; the default model is used if not specified
         device: inference device, 'cpu', 'cuda', or 'mps' (transformers providers only)
-        use_4bit_quantization: whether to use 4-bit quantization (transformers providers only)
         device_map: device map configuration (transformers providers only)
-        trust_remote_code: whether to trust remote code (transformers providers only)
         **kwargs: other parameters
 
     Returns:
@@ -417,7 +400,6 @@ def chat_completion(
         provider="gemma-transformers",
         model="google/gemma-3-4b-it",
         device="cuda",
-        use_4bit_quantization=True
     )
 
     # Custom parameters
@@ -437,12 +419,8 @@ def chat_completion(
         params["model_name"] = model
     if device is not None:
         params["device"] = device
-    if use_4bit_quantization:
-        params["use_4bit_quantization"] = use_4bit_quantization
     if device_map:
         params["device_map"] = device_map
-    if not trust_remote_code:
-        params["trust_remote_code"] = trust_remote_code
 
     return _router.chat_completion(
         messages=messages,
@@ -463,9 +441,7 @@ def reasoning_completion(
     top_p: float = 0.9,
     model: Optional[str] = None,
     device: Optional[str] = None,
-    use_4bit_quantization: bool = False,
     device_map: Optional[str] = None,
-    trust_remote_code: bool = True,
     extract_reasoning_steps: bool = True,
     **kwargs
 ) -> Dict[str, Any]:
@@ -480,9 +456,7 @@ def reasoning_completion(
        top_p: nucleus sampling parameter
        model: model name; the default model is used if not specified
        device: inference device
-        use_4bit_quantization: whether to use 4-bit quantization
        device_map: device map configuration
-        trust_remote_code: whether to trust remote code
        extract_reasoning_steps: whether to extract reasoning steps
        **kwargs: other parameters
 
@@ -510,12 +484,8 @@ def reasoning_completion(
         params["model_name"] = model
     if device is not None:
         params["device"] = device
-    if use_4bit_quantization:
-        params["use_4bit_quantization"] = use_4bit_quantization
     if device_map:
         params["device_map"] = device_map
-    if not trust_remote_code:
-        params["trust_remote_code"] = trust_remote_code
 
     return _router.reasoning_completion(
         messages=messages,
```
 
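Putting it together, a post-commit call through the module-level helper matches the updated docstring example. A sketch, assuming OpenAI-style message dicts; the exact message format is not shown in this diff:

```python
from podcast_transcribe.llm.llm_router import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],  # assumed message shape
    provider="gemma-transformers",
    model="google/gemma-3-4b-it",
    device="cuda",  # use_4bit_quantization / trust_remote_code are no longer accepted
)
print(response)
```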