refactored project to use poetry
rag_system/__init__.py (Normal file, 0 lines added)

rag_system/app/__init__.py (Normal file, 0 lines added)

rag_system/app/rag_chain.py (Normal file, 44 lines added)
@@ -0,0 +1,44 @@
from rag_system.llm.ollama import load_llm
from rag_system.vectordb.azure_search import retrieve
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Define the prompt template for the LLM
prompt = PromptTemplate(
    template="""You are an assistant for question-answering tasks.
    Use the following context to answer the question.
    If you don't know the answer, just say that you don't know.
    Use three sentences maximum and keep the answer concise:
    Question: {question}
    Context: {context}
    Answer:
    """,
    input_variables=["question", "context"],
)


def get_rag_response(query):
    print("⌄⌄⌄⌄ Retrieving ⌄⌄⌄⌄")
    retrieved_docs = retrieve(query, 10)
    print("Query found %d documents." % len(retrieved_docs))
    print("⌃⌃⌃⌃ Retrieving ⌃⌃⌃⌃")

    print("⌄⌄⌄⌄ Augmented Prompt ⌄⌄⌄⌄")
    llm = load_llm()
    # Create a chain combining the prompt template and LLM
    rag_chain = prompt | llm | StrOutputParser()
    # Concatenate the retrieved documents into a single context string
    context = (
        " ".join(doc.page_content for doc in retrieved_docs)
        if retrieved_docs
        else "No relevant documents found."
    )
    print("⌃⌃⌃⌃ Augmented Prompt ⌃⌃⌃⌃")

    print("⌄⌄⌄⌄ Generation ⌄⌄⌄⌄")
    response = rag_chain.invoke({"question": query, "context": context})
    print(response)
    print("⌃⌃⌃⌃ Generation ⌃⌃⌃⌃")

    return response
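For reference, a minimal sketch of calling this chain directly (outside Streamlit), assuming an Ollama server is running locally and the Azure Search index has already been populated by the crawler; the question text is illustrative:

from rag_system.app.rag_chain import get_rag_response

# Illustrative query; any question covered by the indexed documents works.
answer = get_rag_response("What does the responsible AI document cover?")
print(answer)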

rag_system/app/streamlit_app.py (Normal file, 9 lines added)
@@ -0,0 +1,9 @@
import streamlit as st
from rag_system.app.rag_chain import get_rag_response

st.title("RAG System")
query = st.text_input("Ask a question:")
if query:
    response = get_rag_response(query)
    st.write("### Response:")
    st.write(response)

rag_system/clear_index.py (Normal file, 10 lines added)
@@ -0,0 +1,10 @@
from rag_system.vectordb.azure_search import delete_all_documents


def main():
    print("Deleting documents...")
    delete_all_documents()


if __name__ == "__main__":
    main()

rag_system/crawler.py (Normal file, 20 lines added)
@@ -0,0 +1,20 @@
from rag_system.loaders.pdf_loader import load_pdf
from rag_system.loaders.web_loader import load_web_crawl
from rag_system.vectordb.azure_search import add_documents


def main():
    print("[1/2] Splitting and processing documents...")
    pdf_documents = load_pdf("data/verint-responsible-ethical-ai.pdf")
    # web_documents = load_web_crawl(["https://excalibur.mgmresorts.com/en.html"])
    # web_documents = load_web_crawl(["https://www.verint.com"])
    # web_documents = load_web_crawl("https://firecrawl.dev")
    print("[2/2] Generating and storing embeddings...")
    add_documents(pdf_documents)
    # add_documents(web_documents)
    print("Embeddings stored. You can now run the Streamlit app with:\n")
    print("    streamlit run rag_system/app/streamlit_app.py")


if __name__ == "__main__":
    main()
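Since this commit moves the project to Poetry, these entry points would typically be run inside the Poetry-managed environment, for example `poetry run python -m rag_system.crawler` and `poetry run streamlit run rag_system/app/streamlit_app.py`; the exact command forms are an assumption based on the package layout shown in this diff, as the pyproject.toml itself is not included here.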

rag_system/llm/__init__.py (Normal file, 0 lines added)

rag_system/llm/ollama.py (Normal file, 5 lines added)
@@ -0,0 +1,5 @@
from langchain_ollama import OllamaLLM


def load_llm():
    return OllamaLLM(model="llama3.2", base_url="http://localhost:11434", temperature=0)

rag_system/loaders/__init__.py (Normal file, 0 lines added)

rag_system/loaders/firecrawl.py (Normal file, 98 lines added)
@@ -0,0 +1,98 @@
import warnings
from typing import Iterator, Literal, Optional

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_core.utils import get_from_env


class FireCrawlLoader(BaseLoader):

    def __init__(
        self,
        url: str,
        *,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Literal["crawl", "scrape", "map", "extract"] = "crawl",
        params: Optional[dict] = None,
    ):
        """Initialize with API key and url.

        Args:
            url: The url to be crawled.
            api_key: The Firecrawl API key. If not specified, it will be read from the
                env var FIRECRAWL_API_KEY. Get an API key from https://firecrawl.dev.
            api_url: The Firecrawl API URL. If not specified, it will be read from the
                env var FIRECRAWL_API_URL or defaults to https://api.firecrawl.dev.
            mode: The mode to run the loader in. Default is "crawl".
                Options include "scrape" (single url),
                "crawl" (all accessible sub pages),
                "map" (returns list of links that are semantically related),
                "extract" (extracts structured data from a page).
            params: The parameters to pass to the Firecrawl API.
                Examples include crawlerOptions.
                For more details, visit: https://github.com/mendableai/firecrawl-py
        """

        try:
            from firecrawl import FirecrawlApp
        except ImportError:
            raise ImportError(
                "`firecrawl` package not found, please run `pip install firecrawl-py`"
            )
        if mode not in ("crawl", "scrape", "search", "map", "extract"):
            raise ValueError(
                f"""Invalid mode '{mode}'.
                Allowed: 'crawl', 'scrape', 'search', 'map', 'extract'."""
            )

        if not url:
            raise ValueError("Url must be provided")

        api_key = api_key or get_from_env("api_key", "FIRECRAWL_API_KEY")
        self.firecrawl = FirecrawlApp(api_key=api_key, api_url=api_url)
        self.url = url
        self.mode = mode
        self.params = params or {}

    def lazy_load(self) -> Iterator[Document]:
        if self.mode == "scrape":
            firecrawl_docs = [self.firecrawl.scrape_url(self.url, **self.params)]
        elif self.mode == "crawl":
            if not self.url:
                raise ValueError("URL is required for crawl mode")
            crawl_response = self.firecrawl.crawl_url(self.url, **self.params)
            firecrawl_docs = crawl_response.data or []
        elif self.mode == "map":
            if not self.url:
                raise ValueError("URL is required for map mode")
            firecrawl_docs = self.firecrawl.map_url(self.url, params=self.params)
        elif self.mode == "extract":
            if not self.url:
                raise ValueError("URL is required for extract mode")
            firecrawl_docs = [
                str(self.firecrawl.extract([self.url], params=self.params))
            ]
        elif self.mode == "search":
            raise ValueError(
                "Search mode is not supported in this version, please downgrade."
            )
        else:
            raise ValueError(
                f"""Invalid mode '{self.mode}'.
                Allowed: 'crawl', 'scrape', 'map', 'extract'."""
            )
        for doc in firecrawl_docs:
            if self.mode == "map" or self.mode == "extract":
                page_content = doc
                metadata = {}
            else:
                page_content = doc.markdown or doc.html or doc.rawHtml or ""
                metadata = doc.metadata or {}
            if not page_content:
                continue
            yield Document(
                page_content=page_content,
                metadata=metadata,
            )
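For reference, a minimal sketch of using this loader on its own in "scrape" mode; the API key and URL mirror the placeholder values used in web_loader.py below and assume a locally hosted Firecrawl instance, and the target URL is illustrative:

from rag_system.loaders.firecrawl import FireCrawlLoader

# Placeholder credentials and endpoint; substitute real values.
loader = FireCrawlLoader(
    url="https://firecrawl.dev",
    api_key="changeme",
    api_url="http://localhost:3002",
    mode="scrape",
)
# lazy_load yields LangChain Document objects one at a time.
for doc in loader.lazy_load():
    print(doc.metadata, doc.page_content[:80])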

rag_system/loaders/pdf_loader.py (Normal file, 11 lines added)
@@ -0,0 +1,11 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader


def load_pdf(file_path):
    loader = PyPDFLoader(file_path)
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    documents = loader.load_and_split(splitter)
    print(f"Loaded and split into {len(documents)} documents from {file_path}")

    return documents

rag_system/loaders/web_loader.py (Normal file, 44 lines added)
@@ -0,0 +1,44 @@
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from rag_system.loaders.firecrawl import FireCrawlLoader


def load_web_crawl(url):

    documents = []
    metadatas = []

    loader = FireCrawlLoader(
        url=url,
        api_key="changeme",
        api_url="http://localhost:3002",
        mode="crawl",
        params={
            "limit": 100,
            "include_paths": ["/.*"],
            "ignore_sitemap": True,
            "poll_interval": 5,
        },
    )
    docs = []
    docs_lazy = loader.load()
    for doc in docs_lazy:
        print(".", end="")
        docs.append(doc)
    print()

    # Load documents from the URLs
    # docs = [WebBaseLoader(url).load() for url in urls]
    # docs_list = [item for sublist in docs for item in sublist]
    # Initialize a text splitter with specified chunk size and overlap
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=250, chunk_overlap=0
    )
    # Split the documents into chunks
    splits = text_splitter.split_documents(docs)

    for split in splits:
        documents.append(split.page_content)
        metadatas.append(split.metadata)

    return (documents, metadatas)

rag_system/vectordb/__init__.py (Normal file, 0 lines added)

rag_system/vectordb/azure_search.py (Normal file, 139 lines added)
@@ -0,0 +1,139 @@
import os

from typing import Tuple
from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from dotenv import load_dotenv
from uuid import uuid4

load_dotenv()  # take environment variables
required_env_vars = [
    "AZURE_DEPLOYMENT",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "VECTOR_STORE_ADDRESS",
    "VECTOR_STORE_PASSWORD",
    "INDEX_NAME",
    "RETRY_TOTAL",
]

missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
if missing_vars:
    raise ValueError(
        f"Missing required environment variables: {', '.join(missing_vars)}"
    )

# Use AzureOpenAIEmbeddings with an Azure account
embeddings: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings(
    azure_deployment=os.getenv("AZURE_DEPLOYMENT"),
    openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
)

# Specify additional properties for the Azure client such as the following: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations
vector_store: AzureSearch = AzureSearch(
    azure_search_endpoint=os.getenv("VECTOR_STORE_ADDRESS"),
    azure_search_key=os.getenv("VECTOR_STORE_PASSWORD"),
    index_name=os.getenv("INDEX_NAME"),
    embedding_function=embeddings.embed_query,
    # Configure max retries for the Azure client (retry_total must be an int)
    additional_search_client_options={"retry_total": int(os.getenv("RETRY_TOTAL"))},
)


def get_document_id(document):
    """
    Get the document ID from the document object.
    """
    if hasattr(document, "metadata") and "id" in document.metadata:
        return document.metadata["id"]
    elif hasattr(document, "id"):
        return document.id
    else:
        raise ValueError("Document does not have a valid ID.")


def delete_all_documents():
    """
    Delete all documents from the AzureSearch vector store.
    """
    try:
        while True:
            # Fetch a batch of documents and delete them until the index is empty
            docs_to_delete = retrieve("", 10)
            if not docs_to_delete:
                break
            vector_store.delete(list(map(get_document_id, docs_to_delete)))

        print("All documents deleted successfully.")
    except Exception as e:
        print(f"Error deleting documents: {str(e)}")


def add_documents(documents):
    # uuids = [str(uuid4()) for _ in range(len(documents))]

    try:
        vector_store.add_documents(documents)
    except Exception as e:
        print(f"Error adding document to vector store: {str(e)}")


def retrieve(query_text, n_results=1):
    # Perform a similarity search
    docs = vector_store.similarity_search(
        query=query_text,
        k=n_results,
        search_type="similarity",
    )
    return docs


# def add_document_to_vector_store(document):
#     """
#     Add a document to the AzureSearch vector store.

#     Args:
#         vector_store: The initialized AzureSearch vector store instance.
#         document: A dictionary or object representing the document to be added.
#             Example format:
#             {
#                 "id": "unique_document_id",
#                 "content": "The text content of the document",
#                 "metadata": {
#                     "source": "source_url",
#                     "created": "2025-03-04T14:14:40.421666",
#                     "modified": "2025-03-04T14:14:40.421666"
#                 }
#             }
#     """
#     try:
#         # Add the document to the vector store
#         vector_store.add_documents([document])
#         print(f"Document with ID {document['id']} added successfully.")
#     except Exception as e:
#         print(f"Error adding document to vector store: {str(e)}")


# add_document_to_vector_store("https://api.python.langchain.com/en/latest/langchain_api_reference.html", None)
# Example document to add

# doc = Document(
#     page_content="This is the content of the document. For testing IVA demo integration",
#     metadata={
#         "source": "https://example.com/source",
#         "created": "2025-03-04T14:14:40.421666",
#         "modified": "2025-03-04T14:14:40.421666"
#     }
# )
# Add the document to the vector store
# add_document_to_vector_store(doc)

# result = retrieve("iva", 1)
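A minimal round-trip sketch of this module, assuming the required environment variables listed above are set and the Azure Search index exists; the document content and query are illustrative, mirroring the commented example in the file:

from langchain_core.documents import Document
from rag_system.vectordb.azure_search import add_documents, retrieve

# Illustrative document; metadata follows the commented example above.
doc = Document(
    page_content="This is the content of the document.",
    metadata={"source": "https://example.com/source"},
)
add_documents([doc])

# Fetch the closest match back out of the index.
results = retrieve("content of the document", n_results=1)
for d in results:
    print(d.page_content, d.metadata)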

rag_system/vectordb/chromadb.py (Normal file, 55 lines added)
@@ -0,0 +1,55 @@
from typing import Tuple
import chromadb
from langchain_chroma import Chroma
from uuid import uuid4

# from chromadb.utils.embedding_functions.ollama_embedding_function import (
#     OllamaEmbeddingFunction,
# )
from langchain_ollama import OllamaEmbeddings
from chromadb.api.types import Metadata, Document, OneOrMany


# Define a custom embedding function for ChromaDB using Ollama
class ChromaDBEmbeddingFunction:
    """
    Custom embedding function for ChromaDB using embeddings from Ollama.
    """

    def __init__(self, langchain_embeddings):
        self.langchain_embeddings = langchain_embeddings

    def __call__(self, input):
        # Ensure the input is in a list format for processing
        if isinstance(input, str):
            input = [input]
        return self.langchain_embeddings.embed_documents(input)


# Initialize the embedding function with Ollama embeddings
embedding = ChromaDBEmbeddingFunction(
    OllamaEmbeddings(
        model="nomic-embed-text",
        base_url="http://localhost:11434",  # Adjust the base URL as per your Ollama server configuration
    )
)


persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection(
    name="collection_name",
    metadata={"description": "A collection for RAG with Ollama - Demo1"},
    embedding_function=embedding,  # Use the custom embedding function
)


def add_documents(documents: Tuple[OneOrMany[Document], OneOrMany[Metadata]]):
    docs, metas = documents
    uuids = [str(uuid4()) for _ in range(len(docs))]
    collection.add(documents=docs, ids=uuids, metadatas=metas)


def retrieve(query_text, n_results=1):
    # return vector_store.similarity_search(query, k=3)
    results = collection.query(query_texts=[query_text], n_results=n_results)
    return results["documents"], results["metadatas"]
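For reference, a sketch of how the ChromaDB path fits together with the web loader, assuming local Firecrawl and Ollama instances at the URLs configured above; load_web_crawl returns the (documents, metadatas) tuple that this module's add_documents expects, and the crawl target and query text are illustrative:

from rag_system.loaders.web_loader import load_web_crawl
from rag_system.vectordb import chromadb as chroma_store

# Crawl a site and store the resulting chunks in the ChromaDB collection.
documents, metadatas = load_web_crawl("https://firecrawl.dev")
chroma_store.add_documents((documents, metadatas))

# Query the collection for the closest chunks.
docs, metas = chroma_store.retrieve("responsible AI", n_results=3)
print(docs)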