Building a Nutrition Chatbot with RAG (FastAPI + Pinecone)

In health tech, users constantly ask nutrition-related questions. This article shows how to build an intelligent nutrition chatbot with Retrieval-Augmented Generation (RAG) that answers user questions from authoritative nutrition references.

What is RAG?

RAG (Retrieval-Augmented Generation) is an AI technique that combines information retrieval with text generation:

| Component | Role |
| --- | --- |
| Document loading | Converts nutrition knowledge into a searchable format |
| Vector embedding | Converts text into numerical vector representations |
| Vector store | Stores and retrieves similar content efficiently |
| Retrieval | Finds the knowledge most relevant to the question |
| Generation | Generates an answer grounded in the retrieved knowledge |

Advantages:

  • Grounded in real data, which reduces hallucinations
  • Answer sources are traceable
  • The knowledge base is easy to update
  • Meets the accuracy requirements of the medical domain
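End to end, the pipeline in the table above can be sketched with a deliberately tiny stand-in for the real embedding model (bag-of-words vectors instead of OpenAI embeddings; all data and names here are illustrative):

```python
import math
from collections import Counter

def embed(text: str) -> Counter:
    """Bag-of-words 'embedding' standing in for a real embedding model."""
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

# 1) "Index" the documents, 2) embed the query, 3) retrieve the most
# similar document, 4) hand it to the generator as grounded context.
docs = [
    "salmon is rich in omega-3 fatty acids",
    "spinach is a good source of iron",
]
query = "which food contains omega-3 fatty acids"
best = max(docs, key=lambda d: cosine(embed(query), embed(d)))
prompt = f"Answer using only this context: {best}\nQuestion: {query}"
```

A real system swaps `embed` for an embedding API and the final `prompt` for an LLM call, but the retrieval-then-generate shape is the same.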

Project Architecture

```
nutrition-chatbot/
├── app/
│   ├── main.py                      # FastAPI application
│   ├── config.py                    # Configuration management
│   ├── api/
│   │   └── chat.py                  # Chat API routes
│   ├── services/
│   │   ├── rag_service.py           # RAG service
│   │   ├── vector_service.py        # Vector database service
│   │   └── llm_service.py           # LLM service
│   ├── models/
│   │   └── schemas.py               # Pydantic models
│   └── utils/
│       ├── document_loader.py       # Document loading
│       └── text_splitter.py         # Text splitting
├── knowledge/
│   ├── nutrition_guidelines/        # Nutrition guidelines
│   ├── food_database/               # Food database
│   └── research_papers/             # Research papers
├── tests/
├── requirements.txt
└── .env.example
```

1. Project Dependencies

requirements.txt

```
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
python-dotenv==1.0.0

# LangChain
langchain==0.1.0
langchain-openai==0.0.2
langchain-community==0.0.10

# Vector database
pinecone-client==3.0.0

# Document processing
pypdf==3.17.4
docx2txt==0.8
beautifulsoup4==4.12.2
markdown==3.5.1

# Misc
httpx==0.25.1
numpy==1.26.2
tiktoken==0.5.2
```

.env.example

```
# OpenAI API
OPENAI_API_KEY=your_openai_api_key
OPENAI_MODEL=gpt-4-turbo-preview
EMBEDDING_MODEL=text-embedding-3-small

# Pinecone
PINECONE_API_KEY=your_pinecone_api_key
PINECONE_ENVIRONMENT=your_environment
PINECONE_INDEX_NAME=nutrition-bot

# Application settings
APP_NAME=Nutrition Chatbot
MAX_CONTEXT_LENGTH=4000
TOP_K_RESULTS=5
CHUNK_SIZE=1000
CHUNK_OVERLAP=200
```
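The services below import `settings` from `app/config.py`, which the project tree lists but the tutorial never shows. A minimal sketch of what it might contain, reading the variables defined in `.env.example` with matching defaults (the real project could equally use `pydantic-settings`; this file is hypothetical):

```python
# app/config.py -- hypothetical sketch; not shown in the original tutorial.
import os

try:
    # python-dotenv is pinned in requirements.txt; optional at import time
    from dotenv import load_dotenv
    load_dotenv()  # read variables from a local .env file if present
except ImportError:
    pass

class Settings:
    """Settings object matching the keys in .env.example."""
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "gpt-4-turbo-preview")
    EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    PINECONE_API_KEY: str = os.getenv("PINECONE_API_KEY", "")
    PINECONE_INDEX_NAME: str = os.getenv("PINECONE_INDEX_NAME", "nutrition-bot")
    APP_NAME: str = os.getenv("APP_NAME", "Nutrition Chatbot")
    MAX_CONTEXT_LENGTH: int = int(os.getenv("MAX_CONTEXT_LENGTH", "4000"))
    TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5"))
    CHUNK_SIZE: int = int(os.getenv("CHUNK_SIZE", "1000"))
    CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "200"))

settings = Settings()
```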

2. Data Models

app/models/schemas.py

```python
from pydantic import BaseModel, Field
from typing import List, Optional
from datetime import datetime

class ChatMessage(BaseModel):
    """A chat message."""
    role: str = Field(..., description="Role: user/assistant/system")
    content: str = Field(..., description="Message content")
    timestamp: datetime = Field(default_factory=datetime.now)

class ChatRequest(BaseModel):
    """A chat request."""
    question: str = Field(..., min_length=1, description="User question")
    session_id: Optional[str] = Field(None, description="Session ID")
    context: Optional[List[ChatMessage]] = Field(None, description="Conversation history")

class SourceDocument(BaseModel):
    """A source document."""
    content: str
    source: str
    page: Optional[int] = None
    score: float
    metadata: dict = Field(default_factory=dict)

class ChatResponse(BaseModel):
    """A chat response."""
    answer: str
    sources: List[SourceDocument]
    session_id: str
    timestamp: datetime = Field(default_factory=datetime.now)

class NutritionQuery(BaseModel):
    """A nutrition query."""
    food_item: Optional[str] = None
    nutrient: Optional[str] = None
    health_condition: Optional[str] = None
    dietary_preference: Optional[str] = None

class DocumentIngest(BaseModel):
    """A document ingestion request."""
    file_path: str
    document_type: str
    metadata: dict = Field(default_factory=dict)
```
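A quick check of how these models validate input, using a pared-down `ChatRequest` (pydantic v2, as pinned in requirements.txt; the two-field model here is a simplified copy for illustration):

```python
from typing import Optional
from pydantic import BaseModel, Field, ValidationError

class ChatRequest(BaseModel):
    """Simplified copy of the request model above."""
    question: str = Field(..., min_length=1)
    session_id: Optional[str] = None

# A valid request parses; optional fields default to None
ok = ChatRequest(question="What nutrients are in spinach?")

# min_length=1 rejects an empty question before it reaches the RAG service
try:
    ChatRequest(question="")
    empty_rejected = False
except ValidationError:
    empty_rejected = True
```

FastAPI performs exactly this validation automatically when the model is used as a request body, returning a 422 response on failure.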

3. Vector Database Service

app/services/vector_service.py

```python
from typing import List, Dict, Any, Optional
from pinecone import Pinecone, ServerlessSpec
from langchain_openai import OpenAIEmbeddings
from app.config import settings

class VectorService:
    """Vector database service."""

    def __init__(self):
        self.pc = Pinecone(api_key=settings.PINECONE_API_KEY)
        self.embeddings = OpenAIEmbeddings(
            model=settings.EMBEDDING_MODEL,
            openai_api_key=settings.OPENAI_API_KEY
        )
        self.index_name = settings.PINECONE_INDEX_NAME
        self._initialize_index()

    def _initialize_index(self):
        """Initialize the Pinecone index."""
        existing_indexes = [index.name for index in self.pc.list_indexes()]

        if self.index_name not in existing_indexes:
            self.pc.create_index(
                name=self.index_name,
                dimension=1536,  # dimension of text-embedding-3-small
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1"
                )
            )
            print(f"Created new index: {self.index_name}")

        self.index = self.pc.Index(self.index_name)

    async def add_documents(
        self,
        texts: List[str],
        metadatas: List[Dict[str, Any]]
    ) -> List[str]:
        """Add documents to the vector store."""
        # Generate embedding vectors
        embeddings = self.embeddings.embed_documents(texts)

        # Build the vector records
        vectors = []
        for i, (text, embedding, metadata) in enumerate(zip(texts, embeddings, metadatas)):
            vector_id = f"{metadata.get('source', 'doc')}_{i}_{hash(text) % 10000}"
            vectors.append({
                "id": vector_id,
                "values": embedding,
                "metadata": {
                    **metadata,
                    "text": text[:1000]  # keep a text excerpt for previews
                }
            })

        # Bulk upload
        self.index.upsert(vectors=vectors)

        return [v["id"] for v in vectors]

    async def similarity_search(
        self,
        query: str,
        top_k: int = 5,
        filter: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Similarity search."""
        # Embed the query
        query_embedding = self.embeddings.embed_query(query)

        # Search
        results = self.index.query(
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True,
            filter=filter
        )

        # Format the results
        documents = []
        for match in results.matches:
            documents.append({
                "content": match.metadata.get("text", ""),
                "source": match.metadata.get("source", ""),
                "page": match.metadata.get("page"),
                "score": match.score,
                "metadata": match.metadata
            })

        return documents

    async def delete_documents(self, filter: Dict[str, Any]):
        """Delete documents."""
        self.index.delete(filter=filter)

    async def get_stats(self) -> Dict[str, Any]:
        """Return index statistics."""
        return self.index.describe_index_stats()
```
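One practical caveat: Pinecone caps the size of a single upsert request, so large document sets are usually uploaded in batches (commonly around 100 vectors per request). A small helper, hypothetical but compatible with the `vectors` list built in `add_documents`:

```python
from typing import Iterator, List

def batched(items: List, size: int = 100) -> Iterator[List]:
    """Yield successive slices of `items` with at most `size` elements."""
    for i in range(0, len(items), size):
        yield items[i:i + size]

# Usage inside add_documents (sketch):
#   for batch in batched(vectors):
#       self.index.upsert(vectors=batch)
```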

4. RAG Service

app/services/rag_service.py

```python
from typing import List, Dict, Any, Optional

from app.services.vector_service import VectorService
from app.services.llm_service import LLMService
from app.models.schemas import ChatRequest, ChatResponse, SourceDocument, NutritionQuery
from app.config import settings

class RAGService:
    """RAG service -- retrieval-augmented generation."""

    def __init__(self):
        self.vector_service = VectorService()
        self.llm_service = LLMService()
        self.chat_history = {}  # in-memory session history

    async def chat(self, request: ChatRequest) -> ChatResponse:
        """Handle a chat request."""
        # 1. Retrieve relevant documents
        sources = await self.vector_service.similarity_search(
            query=request.question,
            top_k=settings.TOP_K_RESULTS
        )

        if not sources:
            return ChatResponse(
                answer="Sorry, I could not find nutrition information related to your question. Please try a more specific question or consult a dietitian.",
                sources=[],
                session_id=request.session_id or "new"
            )

        # 2. Build the context
        context = self._build_context(sources)

        # 3. Load the session history
        memory = self._load_memory(request.session_id)

        # 4. Generate the answer
        answer = await self.llm_service.generate_answer(
            question=request.question,
            context=context,
            chat_history=memory
        )

        # 5. Save the session history
        if request.session_id:
            self._save_history(request.session_id, request.question, answer)

        # 6. Build the response
        return ChatResponse(
            answer=answer,
            sources=[SourceDocument(**source) for source in sources],
            session_id=request.session_id or "new"
        )

    def _build_context(self, sources: List[Dict[str, Any]]) -> str:
        """Build the retrieval context."""
        context_parts = []
        for i, source in enumerate(sources, 1):
            page_info = f" (page {source['page']})" if source.get('page') else ""
            context_parts.append(
                f"[Source {i}] {source['content']}\n"
                f"Source: {source['source']}{page_info}"
            )
        return "\n\n".join(context_parts)

    def _load_memory(self, session_id: Optional[str]) -> List[Dict[str, str]]:
        """Load the session memory."""
        if not session_id or session_id not in self.chat_history:
            return []
        return self.chat_history[session_id][-5:]  # keep the 5 most recent messages

    def _save_history(self, session_id: str, question: str, answer: str):
        """Save the conversation history."""
        if session_id not in self.chat_history:
            self.chat_history[session_id] = []
        self.chat_history[session_id].extend([
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ])

    async def nutrition_query(self, query: NutritionQuery) -> Dict[str, Any]:
        """Nutrition data query."""
        # Build the search query
        query_parts = []
        if query.food_item:
            query_parts.append(f"nutritional composition of {query.food_item}")
        if query.nutrient:
            query_parts.append(f"role and recommended intake of {query.nutrient}")
        if query.health_condition:
            query_parts.append(f"dietary advice for {query.health_condition}")

        search_query = " ".join(query_parts) if query_parts else "general nutrition advice"

        # Retrieve relevant information
        sources = await self.vector_service.similarity_search(
            query=search_query,
            top_k=5
        )

        # Generate a structured answer
        context = self._build_context(sources)
        answer = await self.llm_service.generate_nutrition_answer(
            query=search_query,
            context=context
        )

        return {
            "query": query.model_dump(),  # pydantic v2: model_dump() replaces dict()
            "answer": answer,
            "sources": sources
        }
```
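The in-process session store keeps only a sliding window of messages, so memory stays bounded per session. Its behavior can be exercised standalone (the functions below are illustrative copies of the private methods above):

```python
chat_history: dict = {}

def save_history(session_id: str, question: str, answer: str) -> None:
    """Append a user/assistant message pair to the session history."""
    chat_history.setdefault(session_id, []).extend([
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ])

def load_memory(session_id: str) -> list:
    """Return only the 5 most recent messages for the session."""
    return chat_history.get(session_id, [])[-5:]

# Four question/answer turns store 8 messages...
for i in range(4):
    save_history("s1", f"q{i}", f"a{i}")
# ...but only the 5 most recent are ever fed back to the model.
```

Note the store is per-process; with multiple uvicorn workers or restarts, sessions would need external storage such as Redis.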

5. LLM Service

app/services/llm_service.py

```python
from typing import List, Dict, Any, Optional
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from app.config import settings

class LLMService:
    """LLM service."""

    def __init__(self):
        self.llm = ChatOpenAI(
            model=settings.OPENAI_MODEL,
            temperature=0.7,  # consider a lower value for more deterministic, factual answers
            openai_api_key=settings.OPENAI_API_KEY
        )
        self._setup_prompts()

    def _setup_prompts(self):
        """Set up the prompt templates."""
        # System prompt
        self.system_prompt = """You are a professional dietitian assistant who answers user questions based on authoritative nutrition guidelines and scientific research.

Follow these principles:
1. Answer only from the retrieved content provided; never fabricate information
2. If the retrieved content is insufficient to answer the question, tell the user explicitly
3. Cite source numbers when referencing information
4. For advice touching on medical matters, remind the user to consult a qualified physician
5. Explain technical concepts in clear, accessible language
6. Provide practical dietary advice and nutrition knowledge"""

        # Chat prompt template
        self.chat_template = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(self.system_prompt),
            ("human", """Context:
{context}

Conversation history:
{chat_history}

User question: {question}

Answer the user's question based on the information above.""")
        ])

    async def generate_answer(
        self,
        question: str,
        context: str,
        chat_history: Optional[List[Dict[str, str]]] = None
    ) -> str:
        """Generate an answer."""
        # Format the conversation history
        history_text = ""
        if chat_history:
            for msg in chat_history[-3:]:  # 3 most recent messages
                role = "User" if msg["role"] == "user" else "Assistant"
                history_text += f"{role}: {msg['content']}\n"

        # Build the prompt
        prompt = self.chat_template.format_messages(
            context=context,
            chat_history=history_text or "none",
            question=question
        )

        # Call the LLM
        response = await self.llm.ainvoke(prompt)
        return response.content

    async def generate_nutrition_answer(
        self,
        query: str,
        context: str
    ) -> Dict[str, Any]:
        """Generate a nutrition data answer."""
        prompt = f"""Answer the user's query based on the nutrition material below:

Query: {query}

Material:
{context}

Please provide:
1. A direct answer to the user's question
2. Related nutrition advice
3. Precautions or contraindications (if any)
4. A reminder to consult a professional"""

        response = await self.llm.ainvoke(prompt)

        return {
            "answer": response.content,
            "query_type": "nutrition_query"
        }
```
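The history formatting inside `generate_answer` can be checked in isolation; it flattens the message dicts into the plain-text block that fills the `{chat_history}` slot of the template (sample data is illustrative):

```python
chat_history = [
    {"role": "user", "content": "Is salmon good for heart health?"},
    {"role": "assistant", "content": "Yes; it is rich in omega-3 fatty acids."},
    {"role": "user", "content": "How often should I eat it per week?"},
    {"role": "assistant", "content": "Two servings per week is a common guideline."},
]

# Same loop as in generate_answer: only the 3 most recent messages,
# each rendered as a "Role: content" line.
history_text = ""
for msg in chat_history[-3:]:
    role = "User" if msg["role"] == "user" else "Assistant"
    history_text += f"{role}: {msg['content']}\n"
```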

6. Document Processing Utilities

app/utils/document_loader.py

```python
from typing import List, Dict, Any
from pathlib import Path
import pypdf
import docx2txt
from bs4 import BeautifulSoup
import markdown

class DocumentLoader:
    """Document loader."""

    def load_pdf(self, file_path: str) -> List[Dict[str, Any]]:
        """Load a PDF document."""
        documents = []
        with open(file_path, 'rb') as file:
            reader = pypdf.PdfReader(file)
            for page_num, page in enumerate(reader.pages, 1):
                text = page.extract_text()
                if text.strip():
                    documents.append({
                        "content": text,
                        "source": Path(file_path).name,
                        "page": page_num,
                        "type": "pdf"
                    })
        return documents

    def load_docx(self, file_path: str) -> List[Dict[str, Any]]:
        """Load a Word document."""
        text = docx2txt.process(file_path)
        return [{
            "content": text,
            "source": Path(file_path).name,
            "type": "docx"
        }]

    def load_html(self, file_path: str) -> List[Dict[str, Any]]:
        """Load an HTML document."""
        with open(file_path, 'r', encoding='utf-8') as file:
            soup = BeautifulSoup(file.read(), 'html.parser')
            text = soup.get_text(separator='\n', strip=True)
        return [{
            "content": text,
            "source": Path(file_path).name,
            "type": "html"
        }]

    def load_markdown(self, file_path: str) -> List[Dict[str, Any]]:
        """Load a Markdown document."""
        with open(file_path, 'r', encoding='utf-8') as file:
            md_content = file.read()
            html = markdown.markdown(md_content)
            soup = BeautifulSoup(html, 'html.parser')
            text = soup.get_text(separator='\n', strip=True)
        return [{
            "content": text,
            "source": Path(file_path).name,
            "type": "markdown"
        }]

    def load_text(self, file_path: str) -> List[Dict[str, Any]]:
        """Load a plain-text document."""
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        return [{
            "content": text,
            "source": Path(file_path).name,
            "type": "text"
        }]

    def load_directory(self, directory: str) -> List[Dict[str, Any]]:
        """Load every document in a directory tree."""
        documents = []
        path = Path(directory)

        for file_path in path.rglob('*'):
            if file_path.is_file():
                ext = file_path.suffix.lower()
                try:
                    if ext == '.pdf':
                        documents.extend(self.load_pdf(str(file_path)))
                    elif ext == '.docx':  # docx2txt cannot read legacy .doc files
                        documents.extend(self.load_docx(str(file_path)))
                    elif ext in ['.html', '.htm']:
                        documents.extend(self.load_html(str(file_path)))
                    elif ext == '.md':
                        documents.extend(self.load_markdown(str(file_path)))
                    elif ext == '.txt':
                        documents.extend(self.load_text(str(file_path)))
                except Exception as e:
                    print(f"Failed to load file {file_path}: {e}")

        return documents
```
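A quick round-trip of the plain-text branch, using a temporary file (`load_text` reproduced standalone so the snippet runs on its own):

```python
import tempfile
from pathlib import Path

def load_text(file_path: str) -> list:
    """Standalone copy of DocumentLoader.load_text above."""
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    return [{
        "content": text,
        "source": Path(file_path).name,
        "type": "text",
    }]

# Write a sample knowledge snippet to a temp file, then load it back
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False,
                                 encoding="utf-8") as tmp:
    tmp.write("Vitamin C supports immune function.")

docs = load_text(tmp.name)
```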

app/utils/text_splitter.py

```python
from typing import List
import re

class TextSplitter:
    """Text splitter."""

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        # Split on blank lines into paragraphs
        paragraphs = re.split(r'\n\s*\n', text)

        chunks = []
        current_chunk = ""

        for paragraph in paragraphs:
            paragraph = paragraph.strip()
            if not paragraph:
                continue

            # If adding the paragraph would exceed the limit, close the current chunk
            if len(current_chunk) + len(paragraph) > self.chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())

                # Start a new chunk, carrying over the overlap
                if self.chunk_overlap > 0 and chunks:
                    last_chunk = chunks[-1]
                    overlap_start = max(0, len(last_chunk) - self.chunk_overlap)
                    current_chunk = last_chunk[overlap_start:] + "\n\n" + paragraph
                else:
                    # Note: a single paragraph longer than chunk_size is kept whole
                    current_chunk = paragraph
            else:
                if current_chunk:
                    current_chunk += "\n\n" + paragraph
                else:
                    current_chunk = paragraph

        # Append the final chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def split_documents(self, documents: List[dict]) -> List[dict]:
        """Split a list of documents."""
        split_docs = []

        for doc in documents:
            chunks = self.split_text(doc['content'])
            for i, chunk in enumerate(chunks):
                split_docs.append({
                    'content': chunk,
                    'source': doc['source'],
                    'page': doc.get('page'),
                    'chunk_index': i,
                    'type': doc.get('type', 'text')
                })

        return split_docs
```
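The splitter's first step breaks the text on blank lines; a quick check of that regex on its own:

```python
import re

text = "Paragraph one.\n\nParagraph two.\n\n\nParagraph three."

# \n\s*\n matches one blank line or several, so runs of empty lines
# do not produce empty paragraphs (beyond whitespace-only fragments,
# which the filter below drops).
paragraphs = [p for p in re.split(r"\n\s*\n", text) if p.strip()]
```

Everything after this step is bookkeeping: paragraphs are packed into chunks of at most `chunk_size` characters, with the last `chunk_overlap` characters of each chunk repeated at the start of the next so that context spanning a boundary is not lost.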

7. API Routes

app/api/chat.py

```python
from fastapi import APIRouter, HTTPException

from app.models.schemas import ChatRequest, ChatResponse, NutritionQuery
from app.services.rag_service import RAGService

router = APIRouter(prefix="/api/chat", tags=["chat"])
rag_service = RAGService()

@router.post("/ask", response_model=ChatResponse)
async def ask_question(request: ChatRequest):
    """
    Ask a nutrition-related question.

    - **question**: the user's question
    - **session_id**: session ID (optional, for multi-turn conversations)
    """
    try:
        response = await rag_service.chat(request)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/nutrition", response_model=dict)
async def query_nutrition(query: NutritionQuery):
    """
    Query nutrition data.

    - **food_item**: food name
    - **nutrient**: nutrient name
    - **health_condition**: health condition
    - **dietary_preference**: dietary preference
    """
    try:
        result = await rag_service.nutrition_query(query)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/session/end")
async def end_session(session_id: str):
    """End a session."""
    # Clear the session history
    if session_id in rag_service.chat_history:
        del rag_service.chat_history[session_id]
    return {"message": "Session ended"}
```

8. Main Application

app/main.py

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.chat import router as chat_router
from app.config import settings

app = FastAPI(
    title="Nutrition Chatbot API",
    description="RAG-based intelligent nutrition consultation system",
    version="1.0.0"
)

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict to known origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register routes
app.include_router(chat_router)

@app.get("/health")
async def health_check():
    """Health check."""
    return {
        "status": "healthy",
        "service": "nutrition-chatbot",
        "model": settings.OPENAI_MODEL
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
```

API Usage Examples

Ask a question

```
POST /api/chat/ask
Content-Type: application/json

{
  "question": "How should diabetic patients manage their diet?",
  "session_id": "user_123"
}
```

Nutrition query

```
POST /api/chat/nutrition
Content-Type: application/json

{
  "food_item": "salmon",
  "nutrient": "Omega-3",
  "health_condition": "cardiovascular disease"
}
```

With this tutorial you now have the core techniques for building a RAG nutrition chatbot with FastAPI, LangChain, and Pinecone. The same system extends to other health consultation scenarios, delivering accurate, traceable health information.


Tags

FastAPI
RAG
Pinecone
LangChain
Nutrition consulting
AI chatbot
