from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

from config import (
    LLM_MODEL,
    LLM_TEMPERATURE,
    EMBEDDING_MODEL,
    RETRIEVAL_TOP_K,
    CHROMA_PERSIST_DIR,
    OPENAI_API_KEY,
    OPENAI_BASE_URL,
)


class RAGEngine:
    """RAG knowledge base question-answering engine."""

    def __init__(self):
        # Embedding model used to vectorize both documents and queries.
        self.embeddings = OpenAIEmbeddings(
            model=EMBEDDING_MODEL,
            openai_api_key=OPENAI_API_KEY,
            openai_api_base=OPENAI_BASE_URL,
        )

        # Chat model that generates the final answer.
        self.llm = ChatOpenAI(
            model=LLM_MODEL,
            temperature=LLM_TEMPERATURE,
            openai_api_key=OPENAI_API_KEY,
            openai_api_base=OPENAI_BASE_URL,
        )

        # Persistent Chroma vector store backing the knowledge base.
        self.vectorstore = Chroma(
            persist_directory=CHROMA_PERSIST_DIR,
            embedding_function=self.embeddings,
        )
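
        # Prompt that restricts the model to the retrieved context.
        self.prompt = PromptTemplate(
            template="""You are a professional knowledge base assistant. Answer the user's question based on the reference material below.

Rules:
1. Answer only from the provided reference material; do not make up information.
2. If the reference material contains no relevant information, state explicitly: "Based on the available material, I cannot answer this question."
3. Cite the source (document name) of the information you use.
4. Keep the answer concise, accurate, and well organized.

Reference material:
{context}

User question: {question}

Answer:""",
            input_variables=["context", "question"],
        )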

        # "stuff" chain: all retrieved chunks are placed into a single prompt.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": RETRIEVAL_TOP_K}
            ),
            chain_type_kwargs={"prompt": self.prompt},
            return_source_documents=True,
        )
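
    def add_documents(self, chunks):
        """Add document chunks to the vector database."""
        self.vectorstore.add_documents(chunks)
        # Newer Chroma releases persist automatically; this explicit call
        # is kept for older versions that still require it.
        self.vectorstore.persist()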

    def query(self, question: str) -> dict:
        """Run a query and return the answer together with its sources."""
        result = self.qa_chain.invoke({"query": question})

        # Collect unique source names, preserving retrieval order.
        sources = []
        for doc in result["source_documents"]:
            source = doc.metadata.get("source", "unknown")
            if source not in sources:
                sources.append(source)

        return {
            "answer": result["result"],
            "sources": sources,
        }
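
    def get_document_count(self) -> int:
        """Return the number of document chunks in the knowledge base."""
        # Note: _collection is a private attribute of the Chroma wrapper.
        collection = self.vectorstore._collection
        return collection.count()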
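
    def clear(self):
        """Remove all entries from the knowledge base."""
        ids = self.vectorstore._collection.get()["ids"]
        # Skip the call when the store is already empty; some Chroma
        # versions may reject a delete request with an empty id list.
        if ids:
            self.vectorstore._collection.delete(ids=ids)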