FastAPI API created for similarity search
parent 0a6e41d046
commit 6bd794bdf5
main.py (Normal file, 162 lines added)
@@ -0,0 +1,162 @@
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional
from langchain_openai import OpenAIEmbeddings
from langchain_postgres import PGVector
from dotenv import load_dotenv
import uvicorn

load_dotenv()

# Initialize FastAPI app
app = FastAPI(
    title="Islamic Duas Semantic Search API",
    description="Semantic search API for Islamic duas based on tags",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to specific origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Database configuration
CONNECTION_STRING = 'postgresql+psycopg2://postgres:test@localhost:5433/vector_db'
COLLECTION_NAME = 'duas_tags_vectors'

# Initialize embeddings and vector store
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
db = PGVector(
    collection_name=COLLECTION_NAME,
    connection=CONNECTION_STRING,
    embeddings=embeddings
)
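
# (Illustrative note, not part of this commit.) The endpoints below assume the
# 'duas_tags_vectors' collection has already been populated, with the tag text
# as the embedded content and the remaining dua fields stored as metadata, e.g.:
#
#     from langchain_core.documents import Document
#     db.add_documents([Document(
#         page_content="protection evil morning safety",
#         metadata={"id": "dua-001", "translation": "...", "category": "Protection",
#                   "tags": ["protection", "morning"]},
#     )])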

# Pydantic models
class DuaResult(BaseModel):
    id: Optional[str] = None
    arabic: Optional[str] = None
    transliteration: Optional[str] = None
    translation: Optional[str] = None
    urdu: Optional[str] = None
    romanUrdu: Optional[str] = None
    category: Optional[str] = None
    occasion: Optional[str] = None
    source: Optional[str] = None
    tags: Optional[List[str]] = None
    similarity_score: float


class SearchResponse(BaseModel):
    query: str
    results_count: int
    results: List[DuaResult]


class SearchRequest(BaseModel):
    query: str = Field(..., description="Search query", example="protection from evil")
    k: int = Field(5, description="Number of results to return", ge=1, le=50)

# Health check endpoint
@app.get("/", tags=["Health"])
async def root():
    return {
        "message": "Islamic Duas Semantic Search API",
        "status": "running",
        "endpoints": {
            "search_get": "/search?query=protection&k=5",
            "search_post": "/search (POST)",
            "health": "/health"
        }
    }


@app.get("/health", tags=["Health"])
async def health_check():
    try:
        # Test database connection
        test_results = db.similarity_search_with_score("test", k=1)
        return {
            "status": "healthy",
            "database": "connected",
            "embeddings": "loaded"
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e)
        }

# GET endpoint for search
@app.get("/search", response_model=SearchResponse, tags=["Search"])
async def search_duas_get(
    query: str = Query(..., description="Search query", example="protection from evil"),
    k: int = Query(5, description="Number of results", ge=1, le=50)
):
    """
    Search for duas using semantic similarity on tags.

    - **query**: Your search query (e.g., "morning prayers", "protection from evil")
    - **k**: Number of results to return (1-50)
    """
    try:
        results = db.similarity_search_with_score(query, k=k)

        duas_results = []
        for doc, score in results:
            result = DuaResult(
                id=doc.metadata.get('id'),
                arabic=doc.metadata.get('arabic'),
                transliteration=doc.metadata.get('transliteration'),
                translation=doc.metadata.get('translation'),
                urdu=doc.metadata.get('urdu'),
                romanUrdu=doc.metadata.get('romanUrdu'),
                category=doc.metadata.get('category'),
                occasion=doc.metadata.get('occasion'),
                source=doc.metadata.get('source'),
                tags=doc.metadata.get('tags'),
                # score is a distance (lower is closer), so convert it to a similarity
                similarity_score=round(1 - score, 4)
            )
            duas_results.append(result)

        return SearchResponse(
            query=query,
            results_count=len(duas_results),
            results=duas_results
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
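

# (Illustrative sketch, not part of this commit.) The root endpoint advertises
# "/search (POST)" and the SearchRequest model is otherwise unused, so a POST
# variant would presumably reuse the GET handler along these lines:
@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_duas_post(request: SearchRequest):
    """Search for duas via a JSON body instead of query parameters."""
    return await search_duas_get(query=request.query, k=request.k)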


# Get all unique categories
@app.get("/categories", tags=["Metadata"])
async def get_categories():
    """
    Get all unique categories from the duas collection.
    """
    try:
        # This is a simple implementation
        # For better performance, you might want to cache this
        results = db.similarity_search_with_score("", k=1000)
        categories = set()
        for doc, _ in results:
            category = doc.metadata.get('category')
            if category:
                categories.add(category)

        return {
            "categories": sorted(list(categories)),
            "count": len(categories)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch categories: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(
        "main:app",  # Change "main" to your filename if different
        host="0.0.0.0",
        port=8899,
        reload=True
    )
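
Once the server is running (python main.py, port 8899) and the collection is populated, the new endpoints can be exercised with a small client. A minimal sketch, assuming the service is reachable on localhost and using the requests package (not part of this commit):

import requests

BASE_URL = "http://localhost:8899"  # matches the uvicorn.run() settings above

# Health check
print(requests.get(f"{BASE_URL}/health").json())

# Semantic search over dua tags
resp = requests.get(
    f"{BASE_URL}/search",
    params={"query": "protection from evil", "k": 3},
)
resp.raise_for_status()
data = resp.json()
print(f"{data['results_count']} results for {data['query']!r}")
for dua in data["results"]:
    print(dua["similarity_score"], dua.get("translation"), dua.get("category"))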