"""FastAPI service that extracts text from uploaded images using DocTR.

Endpoints:
    GET  /           -- service banner and endpoint listing
    POST /ocr        -- extract text from a single uploaded image
    POST /ocr/batch  -- extract text from up to 10 uploaded images
    GET  /health     -- reports whether the OCR model has been loaded
"""

import logging
import os
import tempfile
from typing import List, Optional

import uvicorn
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

# Upper bound on the number of files accepted by POST /ocr/batch.
MAX_BATCH_FILES = 10

app = FastAPI(
    title="OCR API",
    description="Extract text from images using DocTR",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation. The setting is left
# untouched so existing clients keep working; confirm whether credentialed
# cross-origin requests are really needed and, if not, disable them or list
# the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the model once at startup
model = None


@app.on_event("startup")
async def load_model():
    """Load the DocTR predictor into the module-level ``model`` global."""
    global model
    print("Loading OCR model...")
    model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
    print("Model loaded successfully!")


def _require_model():
    """Return the loaded predictor, or raise HTTP 503 if it is not loaded."""
    if model is None:
        raise HTTPException(status_code=503, detail="OCR model is not loaded")
    return model


def _is_image(upload: UploadFile) -> bool:
    """Return True when the upload declares an ``image/*`` content type."""
    # content_type is None when the client omits the part's Content-Type;
    # treat that as "not an image" instead of failing on None.startswith.
    content_type = upload.content_type
    return bool(content_type) and content_type.startswith('image/')


def _remove_temp_file(path: str) -> None:
    """Best-effort removal of a temporary file; never raises."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError:
        logger.warning("Could not remove temporary file %s", path, exc_info=True)


def _collect_words(result) -> List[str]:
    """Flatten a DocTR result (pages > blocks > lines > words) into word strings."""
    words: List[str] = []
    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                words.extend(word.value for word in line.words)
    return words


def _run_ocr(predictor, content: bytes, filename: Optional[str]) -> List[str]:
    """Run OCR on raw image bytes and return the recognized words in order.

    The bytes are written to a temporary file (keeping the original file
    extension, if any) which is always removed afterwards, whether or not
    OCR succeeds.

    NOTE(review): ``predictor(...)`` is a synchronous call made from async
    handlers, so it blocks the event loop for the duration of the inference.
    Consider offloading it to a worker thread if request latency matters.
    """
    # filename may be None when the client does not send one.
    suffix = os.path.splitext(filename or "")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Close the handle before DocTR reopens the file by path.
        with tmp:
            tmp.write(content)
        result = predictor(DocumentFile.from_images(tmp.name))
    finally:
        _remove_temp_file(tmp.name)
    return _collect_words(result)


async def _ocr_batch_item(predictor, upload: UploadFile) -> dict:
    """Process one batch upload and return its per-file result entry."""
    if not _is_image(upload):
        return {
            "filename": upload.filename,
            "success": False,
            "error": "File must be an image"
        }

    try:
        content = await upload.read()
        words = _run_ocr(predictor, content, upload.filename)
    except Exception as exc:
        # One bad file must not fail the whole batch; report it in place.
        return {
            "filename": upload.filename,
            "success": False,
            "error": str(exc)
        }

    return {
        "filename": upload.filename,
        "success": True,
        "text": ' '.join(words),
        "word_count": len(words)
    }


@app.get("/")
async def root():
    """Return a short banner listing the available OCR endpoints."""
    return {
        "message": "OCR API is running",
        "endpoints": {
            "POST /ocr": "Extract text from a single image",
            "POST /ocr/batch": "Extract text from multiple images"
        }
    }


@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
    """
    Extract text from a single image file.

    Returns:
    - text: Extracted text as a single line
    - word_count: Number of words extracted

    Raises:
    - 400 if the upload is not declared as an image
    - 503 if the OCR model has not been loaded
    - 500 if reading or processing the image fails
    """
    if not _is_image(file):
        raise HTTPException(status_code=400, detail="File must be an image")

    predictor = _require_model()

    try:
        content = await file.read()
        words = _run_ocr(predictor, content, file.filename)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(exc)}"
        ) from exc

    return JSONResponse({
        "success": True,
        "text": ' '.join(words),
        "word_count": len(words)
    })


@app.post("/ocr/batch")
async def extract_text_batch(files: List[UploadFile] = File(...)):
    """
    Extract text from multiple image files.

    Returns:
    - results: List of extracted texts with metadata, one entry per file in
      upload order; a file that fails is reported with success=False and an
      error message instead of failing the whole request

    Raises:
    - 400 if more than 10 files are submitted
    - 503 if the OCR model has not been loaded
    """
    if len(files) > MAX_BATCH_FILES:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_BATCH_FILES} files allowed per batch"
        )

    predictor = _require_model()

    # Files are processed sequentially, one at a time.
    results = []
    for file in files:
        results.append(await _ocr_batch_item(predictor, file))

    return JSONResponse({"results": results})


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": model is not None
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8999)