Adding Azure AI Search as vector store

.vscode/launch.json (vendored, 9 lines changed)
@@ -4,11 +4,20 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "justMyCode": false,
+            "console": "integratedTerminal"
+        },
         {
             "name": "Python:Streamlit",
             "type": "debugpy",
             "request": "launch",
             "module": "streamlit",
+            "justMyCode": false,
             "args": [
                 "run",
                 "app/streamlit_app.py",

@@ -1,5 +1,5 @@
 from llm.ollama import load_llm
-from vectordb.vector_store import retrieve
+from vectordb.azure_search import retrieve
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 
@@ -16,23 +16,28 @@ prompt = PromptTemplate(
     input_variables=["question", "documents"],
 )
 
 
 def get_rag_response(query):
     print("⌄⌄⌄⌄ Retrieving ⌄⌄⌄⌄")
-    retrieved_docs, metadata = retrieve(query, 10)
-    print("Query Found %d documents." % len(retrieved_docs[0]))
-    for meta in metadata[0]:
-        print("Metadata: ", meta)
+    retrieved_docs = retrieve(query, 10)
+    print("Query Found %d documents." % len(retrieved_docs))
     print("⌃⌃⌃⌃ Retrieving ⌃⌃⌃⌃")
 
     print("⌄⌄⌄⌄ Augmented Prompt ⌄⌄⌄⌄")
     llm = load_llm()
     # Create a chain combining the prompt template and LLM
     rag_chain = prompt | llm | StrOutputParser()
-    context = " ".join(retrieved_docs[0]) if retrieved_docs else "No relevant documents found."
+    context = (
+        " ".join(doc.page_content for doc in retrieved_docs)
+        if retrieved_docs
+        else "No relevant documents found."
+    )
 
     print("⌃⌃⌃⌃ Augmented Prompt ⌃⌃⌃⌃")
 
     print("⌄⌄⌄⌄ Generation ⌄⌄⌄⌄")
-    response = rag_chain.invoke({"question": query, "context": context});
+    response = rag_chain.invoke({"question": query, "context": context})
     print(response)
     print("⌃⌃⌃⌃ Generation ⌃⌃⌃⌃")
 
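
Note on the retrieval change above: the Azure AI Search retrieve returns a list of LangChain Document objects, while the old Chroma-backed retrieve returned a (documents, metadatas) tuple of raw strings, which is why the tuple unpacking and the metadata loop were dropped. A minimal sketch of the new flow (the query string is illustrative):

    from vectordb.azure_search import retrieve

    docs = retrieve("What is CX Automation?", 10)  # list of LangChain Document objects
    context = " ".join(doc.page_content for doc in docs) if docs else "No relevant documents found."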

clearIndex.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from vectordb.azure_search import delete_all_documents
+
+
+def main():
+    print("Deleting documents...")
+    delete_all_documents()
+
+
+if __name__ == "__main__":
+    main()

llm/ollama.py
@@ -1,7 +1,5 @@
 from langchain_ollama import OllamaLLM
 
 
 def load_llm():
-    return OllamaLLM(
-        model="llama3.2",
-        base_url="http://localhost:11434",
-        temperature=0)
+    return OllamaLLM(model="llama3.2", base_url="http://localhost:11434", temperature=0)
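
The reformat above is behavior-preserving: load_llm still returns an OllamaLLM pointed at a local server with temperature 0. A minimal usage sketch, assuming an Ollama server on localhost:11434 with the llama3.2 model pulled:

    from llm.ollama import load_llm

    llm = load_llm()
    # OllamaLLM is a LangChain Runnable, so it can be invoked directly or piped into a chain
    print(llm.invoke("Reply with one word: hello"))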

loaders/firecrawl.py
@@ -58,17 +58,11 @@ class FireCrawlLoader(BaseLoader):
 
     def lazy_load(self) -> Iterator[Document]:
         if self.mode == "scrape":
-            firecrawl_docs = [
-                self.firecrawl.scrape_url(
-                    self.url, **self.params
-                )
-            ]
+            firecrawl_docs = [self.firecrawl.scrape_url(self.url, **self.params)]
         elif self.mode == "crawl":
             if not self.url:
                 raise ValueError("URL is required for crawl mode")
-            crawl_response = self.firecrawl.crawl_url(
-                self.url, **self.params
-            )
+            crawl_response = self.firecrawl.crawl_url(self.url, **self.params)
             firecrawl_docs = crawl_response.data or []
         elif self.mode == "map":
             if not self.url:
@@ -94,9 +88,7 @@ class FireCrawlLoader(BaseLoader):
             page_content = doc
             metadata = {}
         else:
-            page_content = (
-                doc.markdown or doc.html or doc.rawHtml or ""
-            )
+            page_content = doc.markdown or doc.html or doc.rawHtml or ""
             metadata = doc.metadata or {}
         if not page_content:
             continue
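
lazy_load yields one Document at a time, which is what lets the web loader below print a progress dot per page as results stream in. A minimal sketch of the scrape path, assuming a self-hosted Firecrawl instance at localhost:3002 (URL and key are illustrative):

    loader = FireCrawlLoader(
        url="https://firecrawl.dev",
        api_key="changeme",
        api_url="http://localhost:3002",
        mode="scrape",
        params={},
    )
    for doc in loader.lazy_load():
        print(len(doc.page_content), doc.metadata)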

loaders/pdf_loader.py
@@ -1,16 +1,11 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import PyPDFLoader
 
 
 def load_pdf(file_path):
     loader = PyPDFLoader(file_path)
-    pages = loader.load()
-    print(f"Loaded {len(pages)} documents from {file_path}")
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-    splits = splitter.split_documents(pages)
-    documents = []
-    metadatas = []
+    documents = loader.load_and_split(splitter)
+    print(f"Loaded and Split into {len(documents)} documents from {file_path}")
 
-    for split in splits:
-        documents.append(split.page_content)
-        metadatas.append(split.metadata)
-
-    return (documents, metadatas)
+    return documents
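
load_and_split runs the loader and the splitter in one call and returns Document objects directly, replacing the manual unzip into parallel documents/metadatas lists; the Azure store's add_documents consumes Document objects as-is. A minimal sketch of the new return shape, using the PDF path from main.py:

    docs = load_pdf("data/verint-responsible-ethical-ai.pdf")
    first = docs[0]
    print(first.page_content[:80])  # chunk text, at most 500 characters per the splitter config
    print(first.metadata)           # e.g. source file and page number supplied by PyPDFLoader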

loaders/web_loader.py
@@ -3,23 +3,30 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from loaders.firecrawl import FireCrawlLoader
 
 
 
 def load_web_crawl(url):
 
     documents = []
     metadatas = []
 
     loader = FireCrawlLoader(
-        url=url, api_key="changeme", api_url="http://localhost:3002", mode="crawl", params={ "limit": 100, "include_paths": ["/.*"], "ignore_sitemap": True, "poll_interval": 5 }
+        url=url,
+        api_key="changeme",
+        api_url="http://localhost:3002",
+        mode="crawl",
+        params={
+            "limit": 100,
+            "include_paths": ["/.*"],
+            "ignore_sitemap": True,
+            "poll_interval": 5,
+        },
     )
     docs = []
     docs_lazy = loader.load()
     for doc in docs_lazy:
-        print('.', end="")
+        print(".", end="")
         docs.append(doc)
     print()
 
 
     # Load documents from the URLs
     # docs = [WebBaseLoader(url).load() for url in urls]
     # docs_list = [item for sublist in docs for item in sublist]

main.py (16 lines changed)
@@ -1,16 +1,20 @@
 from loaders.pdf_loader import load_pdf
 from loaders.web_loader import load_web_crawl
-from vectordb.vector_store import add_documents
+from vectordb.azure_search import add_documents
 
 
 def main():
     print("[1/2] Splitting and processing documents...")
-    # pdf_documents = load_pdf("data/verint-responsible-ethical-ai.pdf")
-    # web_documents = load_web(["https://excalibur.mgmresorts.com/en.html"])
-    web_documents = load_web_crawl("https://firecrawl.dev")
+    pdf_documents = load_pdf("data/verint-responsible-ethical-ai.pdf")
+    # web_documents = load_web_crawl(["https://excalibur.mgmresorts.com/en.html"])
+    # web_documents = load_web_crawl(["https://www.verint.com"])
+    # web_documents = load_web_crawl("https://firecrawl.dev")
     print("[2/2] Generating and storing embeddings...")
-    # add_documents(pdf_documents)
-    add_documents(web_documents)
+    add_documents(pdf_documents)
+    # add_documents(web_documents)
     print("Embeddings stored. You can now run the Streamlit app with:\n")
     print("  streamlit run app/streamlit_app.py")
 
 
 if __name__ == "__main__":
     main()

requirements.txt
@@ -1,6 +1,7 @@
 langchain
 langchain-community
 langchain-chroma
+langchain-openai
 chromadb
 pypdf
 streamlit
@@ -9,3 +10,7 @@ langchain_ollama
 bs4
 tiktoken
 firecrawl-py
+azure-search-documents
+azure-identity
+python-dotenv
+black

retriever/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
+node_modules/
+.env

retriever/index.js (new file, 59 lines)
@@ -0,0 +1,59 @@
+import {
+  AzureAISearchVectorStore,
+  AzureAISearchQueryType,
+} from "@langchain/community/vectorstores/azure_aisearch";
+import { OpenAIEmbeddings, AzureOpenAIEmbeddings } from "@langchain/openai";
+
+const query = process.argv[2] || "What is CX Automation?";
+
+// The RAG widget uses the OpenAIEmbeddings class, but that config does not work here because you cannot pass the api-version param. DO NOT USE.
+// const embedding = new OpenAIEmbeddings({
+//   openAIApiKey: openAIApiKey,
+//   configuration: {
+//     baseURL: `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}?api-version=${azureOpenAIApiVersion}`,
+//   },
+// })
+
+const embedding = new AzureOpenAIEmbeddings({
+  azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
+  azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
+  azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
+  azureOpenAIApiKey: process.env.AZURE_OPENAI_API_KEY,
+});
+
+const store = new AzureAISearchVectorStore(embedding, {
+  endpoint: process.env.AZURE_AISEARCH_ENDPOINT,
+  key: process.env.AZURE_AISEARCH_KEY,
+  indexName: process.env.AZURE_AISEARCH_INDEX_NAME,
+  search: {
+    type: AzureAISearchQueryType.SimilarityHybrid,
+  },
+});
+
+function getSourceId(document) {
+  if (document.metadata) {
+    const mergedMetadata = Object.values(document.metadata).join("");
+    const metadataObj = JSON.parse(mergedMetadata);
+    return metadataObj.source;
+  } else return undefined;
+}
+
+const resultDocuments = await store.similaritySearch(query);
+const sources = resultDocuments.map((doc) => ({
+  source_id: getSourceId(doc),
+  text: doc.pageContent,
+}));
+
+const cqaSources = {
+  instances: [
+    {
+      sources: sources,
+      question: query,
+      knowledgebase_description: process.env.AZURE_AISEARCH_INDEX_NAME,
+      extra_guidance: "",
+      language_code: "en-GB",
+    },
+  ],
+};
+
+console.log(JSON.stringify(cqaSources, null, 2));

retriever/package-lock.json (generated, new file, 2425 lines)
File diff suppressed because it is too large.

retriever/package.json (new file, 17 lines)
@@ -0,0 +1,17 @@
+{
+  "name": "retriever",
+  "version": "1.0.0",
+  "main": "index.js",
+  "scripts": {
+    "test": "node --env-file=.env index.js \"What is Verint?\""
+  },
+  "author": "",
+  "license": "MIT",
+  "description": "",
+  "dependencies": {
+    "@azure/search-documents": "^12.1.0",
+    "@langchain/community": "^0.3.43",
+    "@langchain/core": "^0.3.56"
+  },
+  "type": "module"
+}

shell.nix (deleted, 14 lines)
@@ -1,14 +0,0 @@
-let
-  pkgs = import <nixpkgs> {};
-in pkgs.mkShell {
-  packages = [
-    (pkgs.python3.withPackages (python-pkgs: [
-      python-pkgs.langchain
-      python-pkgs.langchain-community
-      python-pkgs.chromadb
-      python-pkgs.pypdf
-      python-pkgs.streamlit
-      python-pkgs.ollama
-    ]))
-  ];
-}

vectordb/azure_search.py (new file, 139 lines)
@@ -0,0 +1,139 @@
+import os
+
+from typing import Tuple
+from langchain_community.vectorstores.azuresearch import AzureSearch
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+from dotenv import load_dotenv
+from uuid import uuid4
+
+load_dotenv()  # take environment variables from .env
+required_env_vars = [
+    "AZURE_DEPLOYMENT",
+    "AZURE_OPENAI_API_VERSION",
+    "AZURE_ENDPOINT",
+    "AZURE_OPENAI_API_KEY",
+    "VECTOR_STORE_ADDRESS",
+    "VECTOR_STORE_PASSWORD",
+    "INDEX_NAME",
+    "RETRY_TOTAL",
+]
+
+missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
+if missing_vars:
+    raise ValueError(
+        f"Missing required environment variables: {', '.join(missing_vars)}"
+    )
+
+# Use AzureOpenAIEmbeddings with an Azure account
+embeddings: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings(
+    azure_deployment=os.getenv("AZURE_DEPLOYMENT"),
+    openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+    azure_endpoint=os.getenv("AZURE_ENDPOINT"),
+    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+)
+
+# Additional properties for the Azure client can be passed here; see
+# https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations
+vector_store: AzureSearch = AzureSearch(
+    azure_search_endpoint=os.getenv("VECTOR_STORE_ADDRESS"),
+    azure_search_key=os.getenv("VECTOR_STORE_PASSWORD"),
+    index_name=os.getenv("INDEX_NAME"),
+    embedding_function=embeddings.embed_query,
+    # Configure max retries for the Azure client
+    additional_search_client_options={"retry_total": os.getenv("RETRY_TOTAL")},
+)
+
+
+def get_document_id(document):
+    """
+    Get the document ID from the document object.
+    """
+    if hasattr(document, "metadata") and "id" in document.metadata:
+        return document.metadata["id"]
+    elif hasattr(document, "id"):
+        return document.id
+    else:
+        raise ValueError("Document does not have a valid ID.")
+
+
+def delete_all_documents():
+    """
+    Delete all documents from the AzureSearch vector store.
+    """
+    try:
+        while True:
+            # Fetch a batch of documents and delete it; stop once the index is empty
+            docs_to_delete = retrieve("", 10)
+            if not docs_to_delete:
+                break
+            vector_store.delete(list(map(get_document_id, docs_to_delete)))
+        print("All documents deleted successfully.")
+    except Exception as e:
+        print(f"Error deleting documents: {str(e)}")
+
+
+def add_documents(documents):
+    # uuids = [str(uuid4()) for _ in range(len(documents))]
+    try:
+        vector_store.add_documents(documents)
+    except Exception as e:
+        print(f"Error adding document to vector store: {str(e)}")
+
+
+def retrieve(query_text, n_results=1):
+    # Perform a similarity search
+    docs = vector_store.similarity_search(
+        query=query_text,
+        k=n_results,
+        search_type="similarity",
+    )
+    return docs
+
+
+# def add_document_to_vector_store(document):
+#     """
+#     Add a document to the AzureSearch vector store.
+#
+#     Args:
+#         vector_store: The initialized AzureSearch vector store instance.
+#         document: A dictionary or object representing the document to be added.
+#             Example format:
+#             {
+#                 "id": "unique_document_id",
+#                 "content": "The text content of the document",
+#                 "metadata": {
+#                     "source": "source_url",
+#                     "created": "2025-03-04T14:14:40.421666",
+#                     "modified": "2025-03-04T14:14:40.421666"
+#                 }
+#             }
+#     """
+#     try:
+#         # Add the document to the vector store
+#         vector_store.add_documents([document])
+#         print(f"Document with ID {document['id']} added successfully.")
+#     except Exception as e:
+#         print(f"Error adding document to vector store: {str(e)}")
+
+
+# add_document_to_vector_store("https://api.python.langchain.com/en/latest/langchain_api_reference.html", None)
+# Example document to add
+
+# doc = Document(
+#     page_content="This is the content of the document. For testing IVA demo integration",
+#     metadata={
+#         "source": "https://example.com/source",
+#         "created": "2025-03-04T14:14:40.421666",
+#         "modified": "2025-03-04T14:14:40.421666"
+#     }
+# )
+# Add the document to the vector store
+# add_document_to_vector_store(doc)
+
+# result = retrieve("iva", 1)
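
Because the embeddings client and the AzureSearch store are constructed at import time, callers only touch add_documents, retrieve, and delete_all_documents. A minimal round-trip sketch, assuming the eight required environment variables are set in .env:

    from langchain_core.documents import Document
    from vectordb.azure_search import add_documents, retrieve

    add_documents([Document(page_content="CX Automation overview", metadata={"source": "https://example.com"})])
    for doc in retrieve("CX Automation", 3):
        print(doc.page_content, doc.metadata.get("source"))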

vectordb/vector_store.py
@@ -2,11 +2,12 @@ from typing import Tuple
 import chromadb
 from langchain_chroma import Chroma
 from uuid import uuid4
+
 # from chromadb.utils.embedding_functions.ollama_embedding_function import (
 #     OllamaEmbeddingFunction,
 # )
 from langchain_ollama import OllamaEmbeddings
-from chromadb.api.types import (Metadata,Document,OneOrMany)
+from chromadb.api.types import Metadata, Document, OneOrMany
 
 
 # Define a custom embedding function for ChromaDB using Ollama
@@ -14,6 +15,7 @@ class ChromaDBEmbeddingFunction:
     """
     Custom embedding function for ChromaDB using embeddings from Ollama.
    """
+
    def __init__(self, langchain_embeddings):
        self.langchain_embeddings = langchain_embeddings
 
@@ -23,11 +25,12 @@ class ChromaDBEmbeddingFunction:
            input = [input]
        return self.langchain_embeddings.embed_documents(input)
 
+
 # Initialize the embedding function with Ollama embeddings
 embedding = ChromaDBEmbeddingFunction(
     OllamaEmbeddings(
         model="nomic-embed-text",
-        base_url="http://localhost:11434"  # Adjust the base URL as per your Ollama server configuration
+        base_url="http://localhost:11434",  # Adjust the base URL as per your Ollama server configuration
     )
 )
 
@@ -36,18 +39,17 @@ persistent_client = chromadb.PersistentClient()
 collection = persistent_client.get_or_create_collection(
     name="collection_name",
     metadata={"description": "A collection for RAG with Ollama - Demo1"},
-    embedding_function=embedding  # Use the custom embedding function)
+    embedding_function=embedding,  # Use the custom embedding function
 )
 
 
 def add_documents(documents: Tuple[OneOrMany[Document], OneOrMany[Metadata]]):
     docs, metas = documents
     uuids = [str(uuid4()) for _ in range(len(docs))]
     collection.add(documents=docs, ids=uuids, metadatas=metas)
 
 
 def retrieve(query_text, n_results=1):
     # return vector_store.similarity_search(query, k=3)
-    results = collection.query(
-        query_texts=[query_text],
-        n_results=n_results
-    )
+    results = collection.query(query_texts=[query_text], n_results=n_results)
     return results["documents"], results["metadatas"]
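
Note the asymmetry that motivated the import switches elsewhere in this commit: this Chroma retrieve still returns a (documents, metadatas) tuple of per-query lists, while the Azure AI Search retrieve returns Document objects. A small sketch of the Chroma return shape for contrast (query text illustrative):

    docs, metas = retrieve("What is CX Automation?", 3)
    # collection.query nests results per query text: docs[0] holds the hits for the first query
    print(docs[0][0], metas[0][0])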