Fix: multiple file upload issue ressolved
This commit is contained in:
parent
6e648c2282
commit
38a0929865
176
main.py
Normal file
176
main.py
Normal file
@ -0,0 +1,176 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from doctr.models import ocr_predictor
|
||||
from doctr.io import DocumentFile
|
||||
import tempfile
|
||||
import os
|
||||
from typing import List
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI(
|
||||
title="OCR API",
|
||||
description="Extract text from images using DocTR",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize the model once at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
print("Loading OCR model...")
|
||||
model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
|
||||
print("Model loaded successfully!")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"message": "OCR API is running",
|
||||
"endpoints": {
|
||||
"POST /ocr": "Extract text from a single image",
|
||||
"POST /ocr/batch": "Extract text from multiple images"
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/ocr")
|
||||
async def extract_text(file: UploadFile = File(...)):
|
||||
"""
|
||||
Extract text from a single image file.
|
||||
|
||||
Returns:
|
||||
- text: Extracted text as a single line
|
||||
- word_count: Number of words extracted
|
||||
"""
|
||||
if not file.content_type.startswith('image/'):
|
||||
raise HTTPException(status_code=400, detail="File must be an image")
|
||||
|
||||
try:
|
||||
# Save uploaded file temporarily
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
# Process the image
|
||||
doc = DocumentFile.from_images(tmp_path)
|
||||
result = model(doc)
|
||||
|
||||
# Extract all words
|
||||
all_words = []
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
for word in line.words:
|
||||
all_words.append(word.value)
|
||||
|
||||
# Join as single line
|
||||
single_line = ' '.join(all_words)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(tmp_path)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"text": single_line,
|
||||
"word_count": len(all_words)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temp file if it exists
|
||||
if 'tmp_path' in locals():
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
||||
|
||||
@app.post("/ocr/batch")
|
||||
async def extract_text_batch(files: List[UploadFile] = File(...)):
|
||||
"""
|
||||
Extract text from multiple image files.
|
||||
|
||||
Returns:
|
||||
- results: List of extracted texts with metadata
|
||||
"""
|
||||
if len(files) > 10:
|
||||
raise HTTPException(status_code=400, detail="Maximum 10 files allowed per batch")
|
||||
|
||||
results = []
|
||||
|
||||
for file in files:
|
||||
if not file.content_type.startswith('image/'):
|
||||
results.append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "File must be an image"
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
# Save uploaded file temporarily
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
# Process the image
|
||||
doc = DocumentFile.from_images(tmp_path)
|
||||
result = model(doc)
|
||||
|
||||
# Extract all words
|
||||
all_words = []
|
||||
for page in result.pages:
|
||||
for block in page.blocks:
|
||||
for line in block.lines:
|
||||
for word in line.words:
|
||||
all_words.append(word.value)
|
||||
|
||||
# Join as single line
|
||||
single_line = ' '.join(all_words)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(tmp_path)
|
||||
|
||||
results.append({
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"text": single_line,
|
||||
"word_count": len(all_words)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temp file if it exists
|
||||
if 'tmp_path' in locals():
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except:
|
||||
pass
|
||||
results.append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return JSONResponse({"results": results})
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"model_loaded": model is not None
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8999)
|
||||
Loading…
x
Reference in New Issue
Block a user