176 lines
5.0 KiB
Python
176 lines
5.0 KiB
Python
from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from doctr.models import ocr_predictor
|
|
from doctr.io import DocumentFile
|
|
import tempfile
|
|
import os
|
|
from typing import List
|
|
import uvicorn
|
|
|
|
# Application instance; title, description and version are passed to FastAPI
# as the service metadata.
app = FastAPI(
    title="OCR API",
    description="Extract text from images using DocTR",
    version="1.0.0",
)
|
|
|
|
# Add CORS middleware.
#
# Any origin, method and header is accepted. Credentials are deliberately
# disabled: combining wildcard origins with allow_credentials=True is
# documented by FastAPI as unsupported, and it would let any website issue
# credentialed cross-origin requests. No endpoint visible in this file reads
# cookies or authorization headers, so credentialed CORS is not needed here.
# If it ever is, list the trusted origins explicitly instead of using "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Shared OCR predictor. It stays None until the startup hook below assigns it,
# so the model is built once rather than per request.
model = None


@app.on_event("startup")
async def load_model():
    """Build the pretrained DocTR predictor and store it in the global ``model``."""
    global model
    print("Loading OCR model...")
    predictor = ocr_predictor(
        det_arch="db_resnet50",
        reco_arch="crnn_vgg16_bn",
        pretrained=True,
    )
    model = predictor
    print("Model loaded successfully!")
|
|
|
|
@app.get("/")
async def root():
    """Return a status banner and a short description of the OCR endpoints."""
    available_endpoints = {
        "POST /ocr": "Extract text from a single image",
        "POST /ocr/batch": "Extract text from multiple images",
    }
    return {
        "message": "OCR API is running",
        "endpoints": available_endpoints,
    }
|
|
|
|
@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
    """Extract text from a single uploaded image.

    The upload is written to a temporary file, run through the module-level
    OCR ``model``, and the recognized words are joined with single spaces.

    Args:
        file: The uploaded file. Its declared content type must start with
            ``image/``.

    Returns:
        A JSON response with:
        - success: always ``True`` on this path
        - text: extracted text as a single line
        - word_count: number of words extracted

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 if reading the upload, decoding the image or running OCR fails.
    """
    # content_type may be None (e.g. the part carries no Content-Type), so
    # normalise it before calling startswith to return a 400, not a crash.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    tmp_path = None
    try:
        # Read the upload before creating the temp file so a failed read
        # cannot leave an orphaned file behind.
        content = await file.read()

        # filename may be missing; fall back to a suffix-less temp file.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path first so the finally-block can always clean up.
            tmp_path = tmp.name
            tmp.write(content)

        # The temp file is closed at this point, so it can be reopened by path.
        # NOTE(review): both calls below are synchronous inside an async
        # handler; consider offloading them to a worker thread if request
        # latency under load matters.
        doc = DocumentFile.from_images(tmp_path)
        result = model(doc)

        # Flatten pages -> blocks -> lines -> words into one list of strings.
        all_words = [
            word.value
            for page in result.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        ]

        return JSONResponse({
            "success": True,
            "text": " ".join(all_words),
            "word_count": len(all_words),
        })
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup on both the success and the failure path.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
|
|
|
@app.post("/ocr/batch")
async def extract_text_batch(files: List[UploadFile] = File(...)):
    """Extract text from multiple uploaded images.

    Files are processed one at a time, in the order received. A failure on
    one file is reported in that file's entry and does not abort the batch.

    Args:
        files: The uploaded files; at most 10 are accepted per request.

    Returns:
        A JSON response ``{"results": [...]}`` with one entry per file. Each
        entry has ``filename`` and ``success``; successful entries add
        ``text`` (single line) and ``word_count``, failed entries add
        ``error``.

    Raises:
        HTTPException: 400 if more than 10 files are submitted.
    """
    if len(files) > 10:
        raise HTTPException(status_code=400, detail="Maximum 10 files allowed per batch")

    results = []

    for file in files:
        # content_type may be None (e.g. the part carries no Content-Type);
        # treat that as "not an image" rather than crashing the whole batch.
        if not (file.content_type or "").startswith("image/"):
            results.append({
                "filename": file.filename,
                "success": False,
                "error": "File must be an image",
            })
            continue

        # Reset per file so cleanup never targets a previous iteration's path.
        tmp_path = None
        try:
            # Read the upload before creating the temp file so a failed read
            # cannot leave an orphaned file behind.
            content = await file.read()

            # filename may be missing; fall back to a suffix-less temp file.
            suffix = os.path.splitext(file.filename or "")[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                # Record the path first so the finally-block can clean up.
                tmp_path = tmp.name
                tmp.write(content)

            # The temp file is closed here, so it can be reopened by path.
            # NOTE(review): both calls below are synchronous inside an async
            # handler; consider offloading them to a worker thread.
            doc = DocumentFile.from_images(tmp_path)
            result = model(doc)

            # Flatten pages -> blocks -> lines -> words into one list.
            all_words = [
                word.value
                for page in result.pages
                for block in page.blocks
                for line in block.lines
                for word in line.words
            ]

            results.append({
                "filename": file.filename,
                "success": True,
                "text": " ".join(all_words),
                "word_count": len(all_words),
            })
        except Exception as e:
            # Record the failure for this file and continue with the rest.
            results.append({
                "filename": file.filename,
                "success": False,
                "error": str(e),
            })
        finally:
            # Best-effort cleanup on both the success and the failure path.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    return JSONResponse({"results": results})
|
|
|
|
@app.get("/health")
async def health_check():
    """Report a static healthy status and whether the OCR model is loaded."""
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
|
|
|
if __name__ == "__main__":
    # Start the uvicorn server for this app when the file is run as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8999,
    )