Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| """ | |
| 鸟类知识智能科普系统 | |
| """ | |
| import streamlit as st | |
| from PIL import Image | |
| import tempfile | |
| from transformers import pipeline, AutoConfig | |
| import torch | |
| # ========== 模型配置 ========== | |
| MODEL_CONFIG = { | |
| "image_to_text": { | |
| "model": "chriamue/bird-species-classifier", | |
| "config": {"use_fast": True} # 强制启用快速处理器 | |
| }, | |
| "text_generation": { | |
| "model": "Qwen/Qwen-7B-Chat", | |
| "config": AutoConfig.from_pretrained("Qwen/Qwen-7B-Chat", revision="main") | |
| }, | |
| "text_to_speech": { | |
| "model": "facebook/mms-tts-eng", | |
| "config": {"speaker_id": 6} # 儿童音色 | |
| } | |
| } | |
| # ========== 模型初始化 ========== | |
| def init_pipelines(): | |
| """缓存模型加载结果避免重复初始化""" | |
| try: | |
| img_pipeline = pipeline( | |
| "image-classification", | |
| model=MODEL_CONFIG["image_to_text"]["model"], | |
| **MODEL_CONFIG["image_to_text"]["config"] | |
| ) | |
| text_pipeline = pipeline( | |
| "text-generation", | |
| model=MODEL_CONFIG["text_generation"]["model"], | |
| config=MODEL_CONFIG["text_generation"]["config"], | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| tts_pipeline = pipeline( | |
| "text-to-speech", | |
| model=MODEL_CONFIG["text_to_speech"]["model"], | |
| **MODEL_CONFIG["text_to_speech"]["config"] | |
| ) | |
| return img_pipeline, text_pipeline, tts_pipeline | |
| except Exception as e: | |
| st.error(f"模型加载失败: {str(e)}") | |
| st.stop() | |
| # ========== 核心功能 ========== | |
| def generate_description(_pipe, bird_name): | |
| """生成儿童友好型描述""" | |
| prompt = f"用6-12岁儿童能理解的语言描述{bird_name},使用比喻和趣味知识:" | |
| return _pipe(prompt, max_new_tokens=120)[0]['generated_text'].split(":")[-1] | |
| # ========== 界面设计 ========== | |
| st.set_page_config(page_title="鸟类知识百科", page_icon="🐦") | |
| st.title("🐦 智能鸟类科普系统") | |
| st.markdown("上传鸟类图片,获取趣味知识讲解") | |
| # 主流程 | |
| def main(): | |
| img_pipe, text_pipe, tts_pipe = init_pipelines() | |
| uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "png", "jpeg"]) | |
| if uploaded_file: | |
| with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp_file: | |
| # 保存临时文件 | |
| tmp_file.write(uploaded_file.getvalue()) | |
| with st.spinner("识别中..."): | |
| # 识别鸟类 | |
| result = img_pipe(Image.open(tmp_file.name)) | |
| bird_name = result[0]['label'] | |
| st.success(f"识别结果:{bird_name}") | |
| # 生成描述 | |
| desc = generate_description(text_pipe, bird_name) | |
| st.subheader("趣味知识") | |
| st.write(desc) | |
| # 语音合成 | |
| audio = tts_pipe(desc[:1000]) # 限制文本长度 | |
| st.audio(audio["audio"], sample_rate=audio["sampling_rate"]) | |
| if __name__ == "__main__": | |
| main() | |