import gradio as gr import numpy as np from PIL import Image import pickle import os # Load the model model_path = "numbers.pkl" classifier = None def load_model(): """Load the trained model""" global classifier try: if os.path.exists(model_path): with open(model_path, 'rb') as file: classifier = pickle.load(file) print("✅ Model loaded successfully!") else: print(f"❌ Model file '{model_path}' not found.") except Exception as e: print(f"❌ Error loading model: {str(e)}") # Load model on startup load_model() # Result mapping ResultMap = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'} def predict_digit(image): """ Predict digit from uploaded image """ if classifier is None: return "❌ Model not loaded", 0.0 try: if image is None: return "❌ No image provided", 0.0 # Preprocess image # Convert to PIL Image if it's not already if not isinstance(image, Image.Image): image = Image.fromarray(image) # Resize and convert image = image.convert('RGB').resize((64, 64), Image.Resampling.LANCZOS) # Convert to array and normalize image_array = np.asarray(image, dtype=np.float32) / 255.0 # Add batch dimension image_array = np.expand_dims(image_array, axis=0) print(f"📸 Image processed: {image_array.shape}") # Make prediction result = classifier.predict(image_array) # Handle different model output formats if hasattr(result, 'shape') and len(result.shape) > 1: predicted_class = np.argmax(result, axis=1)[0] if result.shape[0] == 1 else np.argmax(result) confidence = float(np.max(result)) else: predicted_class = int(result[0]) if hasattr(result, '__len__') else int(result) confidence = 1.0 predicted_digit = ResultMap.get(predicted_class, "Unknown") print(f'🎯 Prediction: {predicted_digit} (Confidence: {confidence:.2%})') return f"Predicted Digit: {predicted_digit}", confidence except Exception as e: print(f"❌ Error in prediction: {str(e)}") return f"❌ Error: {str(e)}", 0.0 # Create Gradio interface def create_gradio_interface(): """Create the Gradio interface""" with gr.Blocks( theme=gr.themes.Soft(), title="🔢 Digit Recognition API", css=""" .gradio-container { max-width: 700px !important; margin: auto !important; } .header { text-align: center; margin-bottom: 2rem; } .result-box { font-size: 1.5rem; font-weight: bold; text-align: center; padding: 1rem; border-radius: 10px; } """ ) as interface: gr.HTML("""
Upload an image of a handwritten digit (0-9) to get AI-powered recognition
📋 Tips for best results:
• Use clear images with dark digits on light backgrounds
• Ensure the digit fills most of the image
• Supported formats: PNG, JPG, GIF