jarvis0852's picture
Update app.py
01c3821 verified
raw
history blame
5.72 kB
import gradio as gr
import numpy as np
from PIL import Image
import pickle
import os
# Load the model
# NOTE: the model is deserialized with pickle, which can execute arbitrary
# code while loading. Only point this at a trusted, locally bundled model
# file -- never at user-supplied data.
model_path = "numbers.pkl"
classifier = None


def load_model(path=None):
    """Load the pickled classifier into the module-level ``classifier``.

    Args:
        path: Pickle file to load. When omitted, the module-level
            ``model_path`` is used, so ``load_model()`` behaves as before.

    Returns:
        True if the model was loaded, False otherwise. On failure the
        previous ``classifier`` value is left untouched and the reason is
        printed instead of raised, because this runs at import time.
    """
    global classifier
    if path is None:
        path = model_path
    try:
        with open(path, 'rb') as file:
            loaded = pickle.load(file)
    except FileNotFoundError:
        # EAFP instead of os.path.exists(): no window between check and open.
        print(f"❌ Model file '{path}' not found.")
        return False
    except Exception as e:
        # Best-effort startup: report and carry on; callers check
        # ``classifier is None`` before predicting.
        print(f"❌ Error loading model: {e}")
        return False
    classifier = loaded
    print("βœ… Model loaded successfully!")
    return True


# Load model on startup
load_model()
# Result mapping: class index produced by the model -> digit label shown
# to the user (looked up by predict_digit; unknown indices become "Unknown").
ResultMap = {index: str(index) for index in range(10)}
def _preprocess_image(image):
    """Convert an input image into a (1, 64, 64, 3) float32 batch in [0, 1]."""
    # The UI passes PIL images (type="pil"); also accept array input.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB').resize((64, 64), Image.Resampling.LANCZOS)
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    # Add batch dimension
    return np.expand_dims(image_array, axis=0)


def _interpret_prediction(result):
    """Map raw ``classifier.predict`` output to ``(class_index, confidence)``.

    * 2-D or higher output (e.g. ``(1, 10)`` scores): the argmax is the class
      and the largest value is reported as confidence.
    * 1-D output with several entries: treated as a score vector without a
      batch axis (previously its first score was truncated to an int label).
    * Scalar or single-entry output: the predicted label itself; no scores
      are available, so confidence is 1.0.

    Works for numpy arrays and for plain (nested) lists alike.
    """
    scores = np.asarray(result)
    if scores.ndim > 1:
        # Per-row argmax for a single-row batch, flattened argmax otherwise.
        if scores.shape[0] == 1:
            predicted_class = int(np.argmax(scores, axis=1)[0])
        else:
            predicted_class = int(np.argmax(scores))
        # NOTE(review): reported as-is; only a probability if the model
        # outputs probabilities rather than logits -- confirm.
        return predicted_class, float(np.max(scores))
    if scores.ndim == 1 and scores.size > 1:
        # Exactly one image is sent per call, so a multi-entry 1-D output is
        # taken to be a score vector, not a batch of labels.
        return int(np.argmax(scores)), float(np.max(scores))
    return int(scores.reshape(-1)[0]), 1.0


def predict_digit(image):
    """
    Predict digit from uploaded image.

    Args:
        image: A PIL image (or anything ``Image.fromarray`` accepts), or None
            when the input was cleared.

    Returns:
        ``(message, confidence)``. On success the message is
        ``"Predicted Digit: <label>"``; on any failure it is an error message
        and confidence is 0.0.
    """
    if classifier is None:
        return "❌ Model not loaded", 0.0
    if image is None:
        return "❌ No image provided", 0.0
    try:
        image_array = _preprocess_image(image)
        print(f"πŸ“Έ Image processed: {image_array.shape}")
        # Make prediction
        result = classifier.predict(image_array)
        predicted_class, confidence = _interpret_prediction(result)
        predicted_digit = ResultMap.get(predicted_class, "Unknown")
        print(f'🎯 Prediction: {predicted_digit} (Confidence: {confidence:.2%})')
        return f"Predicted Digit: {predicted_digit}", confidence
    except Exception as e:
        # UI boundary: report the failure to the user rather than crashing
        # the request handler.
        print(f"❌ Error in prediction: {e}")
        return f"❌ Error: {e}", 0.0
# Create Gradio interface
def create_gradio_interface():
    """Build and return the Gradio Blocks app for digit recognition.

    Layout: a header; a row with the image upload and predict button on the
    left and the prediction text plus confidence number on the right; then a
    tips footer. Both pressing the button and any change to the image run
    ``predict_digit`` and fill the two output components.
    """
    page_css = (
        "\n"
        ".gradio-container {\n"
        "max-width: 700px !important;\n"
        "margin: auto !important;\n"
        "}\n"
        ".header {\n"
        "text-align: center;\n"
        "margin-bottom: 2rem;\n"
        "}\n"
        ".result-box {\n"
        "font-size: 1.5rem;\n"
        "font-weight: bold;\n"
        "text-align: center;\n"
        "padding: 1rem;\n"
        "border-radius: 10px;\n"
        "}\n"
    )
    header_html = (
        "\n"
        '<div class="header">\n'
        "<h1>πŸ”’ Digit Recognition API</h1>\n"
        "<p>Upload an image of a handwritten digit (0-9) to get AI-powered recognition</p>\n"
        "</div>\n"
    )
    tips_html = (
        "\n"
        '<div style="text-align: center; margin-top: 2rem; color: #666;">\n'
        "<p>πŸ“‹ <strong>Tips for best results:</strong></p>\n"
        "<p>β€’ Use clear images with dark digits on light backgrounds</p>\n"
        "<p>β€’ Ensure the digit fills most of the image</p>\n"
        "<p>β€’ Supported formats: PNG, JPG, GIF</p>\n"
        "</div>\n"
    )

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="πŸ”’ Digit Recognition API",
        css=page_css,
    ) as interface:
        gr.HTML(header_html)

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="πŸ“Έ Upload Digit Image", type="pil", height=300
                )
                predict_btn = gr.Button(
                    "πŸš€ Predict Digit", variant="primary", size="lg"
                )
            with gr.Column():
                prediction_output = gr.Textbox(
                    label="🎯 Prediction Result",
                    interactive=False,
                    elem_classes=["result-box"],
                )
                confidence_output = gr.Number(
                    label="πŸ“Š Confidence Score", interactive=False
                )

        # One handler serves both the explicit button press and every image
        # change (upload or clear); click is registered first, as before.
        result_targets = [prediction_output, confidence_output]
        for register in (predict_btn.click, image_input.change):
            register(fn=predict_digit, inputs=[image_input], outputs=result_targets)

        gr.HTML(tips_html)
    return interface
# API endpoint for external calls
# NOTE(review): nothing in this file appears to register this function with
# the Gradio interface -- wire it up (or remove it) if it is meant to be live.
def predict_api(image):
    """Run ``predict_digit`` and wrap the outcome in a JSON-friendly dict.

    Returns:
        On success: ``{"digit": <label str>, "confidence": float,
        "status": "success"}``.
        On failure: ``{"digit": None, "confidence": float, "status": "error",
        "error": <full message>}``. Previously the message was split on
        ``": "`` and its tail was reported as the digit, so error text leaked
        into ``digit``.
    """
    prediction, confidence = predict_digit(image)
    # Classify by the success prefix predict_digit emits, rather than by
    # sniffing for the emoji error marker.
    success_prefix = "Predicted Digit: "
    if prediction.startswith(success_prefix):
        return {
            "digit": prediction[len(success_prefix):],
            "confidence": float(confidence),
            "status": "success",
        }
    return {
        "digit": None,
        "confidence": float(confidence),
        "status": "error",
        "error": prediction,
    }
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, with no public share
    # link and with the API page enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
    }
    interface = create_gradio_interface()
    interface.launch(**launch_options)