Add ChromaDB implementation
This commit is contained in:
165
backend/chroma_client.py
Normal file
165
backend/chroma_client.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
ChromaDB Client for storing and retrieving document embeddings
|
||||
"""
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils import embedding_functions
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
class ChromaClient:
    """
    Client for storing and searching article embeddings in ChromaDB over HTTP.

    Embeddings are produced by chromadb's ``DefaultEmbeddingFunction``. The
    ``ollama_base_url`` argument is accepted and stored, but no Ollama-backed
    embedding function is wired in yet.
    """

    def __init__(self, host, port, collection_name='munich_news_articles', ollama_base_url=None):
        """
        Store connection settings; no network traffic happens until connect().

        Args:
            host: ChromaDB host (e.g. 'localhost' or 'chromadb')
            port: ChromaDB port (e.g. 8000); required, there is no default
            collection_name: Name of the collection to use
            ollama_base_url: Optional URL of an Ollama server. Currently only
                stored on the instance; embeddings do not use it yet.
        """
        self.host = host
        self.port = port
        self.collection_name = collection_name
        self.ollama_base_url = ollama_base_url
        self.client = None
        self.collection = None

        # Setup embedding function
        # We prefer using a local embedding model compatible with Ollama or SentenceTransformers
        # For simplicity in this stack, we can use the default SentenceTransformer (all-MiniLM-L6-v2)
        # which is downloaded automatically by chromadb utils.
        # Alternatively, we could define a custom function using Ollama's /api/embeddings
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()

    @staticmethod
    def _coalesce(value, default=''):
        """Return ``default`` when ``value`` is None, otherwise ``value`` unchanged."""
        return default if value is None else value

    def connect(self):
        """
        Establish the connection to ChromaDB and get or create the collection.

        Returns:
            True on success, False if the client or collection could not be set up.
        """
        try:
            self.client = chromadb.HttpClient(
                host=self.host,
                port=self.port,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
            )

            # Create or get collection (cosine distance for the HNSW index)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                embedding_function=self.embedding_function,
                metadata={"hnsw:space": "cosine"},
            )
            print(f"✓ Connected to ChromaDB at {self.host}:{self.port}")
            return True
        except Exception as e:
            # Drop half-initialised state so a later call retries from scratch
            self.client = None
            self.collection = None
            print(f"⚠ Could not connect to ChromaDB: {e}")
            return False

    def add_articles(self, articles):
        """
        Add (upsert) articles to the vector database.

        Articles without a 'link' or without 'content' are skipped. When the
        same link occurs more than once in ``articles``, only the last
        occurrence is sent.

        Args:
            articles: List of dictionaries containing article data.
                Must have 'link' (used as ID) and 'content'; 'title',
                'summary', 'source', 'category', 'published_at' and '_id'
                are optional and may be None.

        Returns:
            True if nothing needed indexing or the upsert succeeded,
            False if connecting or the upsert failed.
        """
        if not self.client or not self.collection:
            if not self.connect():
                return False

        if not articles:
            return True

        # Keyed by article ID so a link repeated within this batch is sent once.
        # NOTE(review): Chroma is expected to reject duplicate IDs in a single
        # call - confirm against the installed chromadb version.
        records = {}

        for article in articles:
            article_id = article.get('link')
            content = article.get('content')

            # Skip if critical data missing
            if not article_id or not content:
                continue

            # Keys may be present with a None value, so coalesce instead of
            # relying on dict.get defaults (which only cover missing keys).
            title = str(self._coalesce(article.get('title')))
            summary = article.get('summary') or ''
            content_snippet = content[:1000]

            # Text for embedding: Title + Summary + Start of Content,
            # which gives semantic search a good overview.
            text_to_embed = f"{title}\n\n{summary}\n\n{content_snippet}"

            # Flat metadata dict (no nested objects, no None values)
            metadata = {
                "title": title[:100],  # Truncate for metadata limits
                "url": article_id,
                "source": self._coalesce(article.get('source'), 'unknown'),
                "category": self._coalesce(article.get('category'), 'general'),
                "published_at": str(self._coalesce(article.get('published_at'))),
                "mongo_id": str(self._coalesce(article.get('_id'))),
            }

            records[article_id] = (text_to_embed, metadata)

        if not records:
            return True

        ids = list(records)
        documents = [text for text, _ in records.values()]
        metadatas = [meta for _, meta in records.values()]

        try:
            self.collection.upsert(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
            print(f"✓ Indexed {len(ids)} articles in ChromaDB")
            return True
        except Exception as e:
            print(f"✗ Failed to index in ChromaDB: {e}")
            return False

    def search(self, query_text, n_results=5, where=None):
        """
        Search for relevant articles.

        Args:
            query_text: The search query
            n_results: Number of results to return
            where: Metadata filter dict (e.g. {"category": "sports"})

        Returns:
            List of dicts with keys 'id', 'document', 'metadata' and
            'distance'. Empty list when connecting or the query fails.
        """
        if not self.client or not self.collection:
            if not self.connect():
                return []

        try:
            results = self.collection.query(
                query_texts=[query_text],
                n_results=n_results,
                where=where,
            )

            if not results or not results.get('ids'):
                return []

            # Optional fields; fall back per item when the server omits one.
            documents = results.get('documents')
            metadatas = results.get('metadatas')
            distances = results.get('distances')

            # Only one query text is sent, so only the first result row is read.
            return [
                {
                    'id': doc_id,
                    'document': documents[0][i] if documents else None,
                    'metadata': metadatas[0][i] if metadatas else {},
                    'distance': distances[0][i] if distances else 0,
                }
                for i, doc_id in enumerate(results['ids'][0])
            ]
        except Exception as e:
            print(f"✗ Search failed: {e}")
            return []
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: try to reach a ChromaDB server on localhost:8000.
    # connect() prints the outcome itself, so its return value is not used.
    client = ChromaClient('localhost', 8000)
    client.connect()
|
||||
Reference in New Issue
Block a user