rag-system/vectordb/vector_store.py

from typing import Tuple
from uuid import uuid4

import chromadb
from chromadb.api.types import Document, Metadata, OneOrMany
from langchain_chroma import Chroma  # retained for the commented-out similarity_search alternative below
from langchain_ollama import OllamaEmbeddings

# from chromadb.utils.embedding_functions.ollama_embedding_function import (
#     OllamaEmbeddingFunction,
# )

# Define a custom embedding function for ChromaDB using Ollama
class ChromaDBEmbeddingFunction:
    """
    Custom embedding function for ChromaDB using embeddings from Ollama.
    """

    def __init__(self, langchain_embeddings):
        self.langchain_embeddings = langchain_embeddings

    def __call__(self, input):
        # Ensure the input is in a list format for processing
        if isinstance(input, str):
            input = [input]
        return self.langchain_embeddings.embed_documents(input)
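
# A minimal sketch of the call contract above (an assumption based on how
# Chroma invokes embedding functions: it passes a list of strings and expects
# one float vector back per string; the values shown are illustrative only):
#
#   fn = ChromaDBEmbeddingFunction(OllamaEmbeddings(model="nomic-embed-text"))
#   fn(["hello world"])  # -> [[0.12, -0.03, ...]]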

# Initialize the embedding function with Ollama embeddings
embedding = ChromaDBEmbeddingFunction(
    OllamaEmbeddings(
        model="nomic-embed-text",
        base_url="http://localhost:11434",  # Adjust the base URL to match your Ollama server configuration
    )
)

persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection(
    name="collection_name",
    metadata={"description": "A collection for RAG with Ollama - Demo1"},
    embedding_function=embedding,  # Use the custom embedding function
)

def add_documents(documents: Tuple[OneOrMany[Document], OneOrMany[Metadata]]):
    """Add a batch of documents and their metadata to the collection under fresh UUIDs."""
    docs, metas = documents
    uuids = [str(uuid4()) for _ in range(len(docs))]
    collection.add(documents=docs, ids=uuids, metadatas=metas)

def retrieve(query_text, n_results=1):
    """Return the top-n_results documents and their metadata for a query string."""
    # LangChain alternative: return vector_store.similarity_search(query, k=3)
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results,
    )
    return results["documents"], results["metadatas"]
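
# A minimal end-to-end sketch (assumptions: an Ollama server is running at the
# base_url above with the "nomic-embed-text" model already pulled, and this
# module is executed directly; the sample texts and metadata are illustrative).
if __name__ == "__main__":
    sample_docs = [
        "Ollama serves local LLMs and embedding models over HTTP.",
        "ChromaDB persists vector collections on disk by default.",
    ]
    sample_metas = [{"source": "demo"}, {"source": "demo"}]
    add_documents((sample_docs, sample_metas))

    docs, metas = retrieve("Where are the vectors stored?", n_results=2)
    print(docs)   # one list of matching documents per query
    print(metas)  # the corresponding metadata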