Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import pickle | |
| import os | |
# Location of the pickled classifier. A relative path like this one is
# resolved by open() against the process's current working directory.
model_path = "numbers.pkl"

# Filled in by load_model(); stays None if the file is missing or unloadable.
classifier = None


def load_model():
    """Unpickle the file at ``model_path`` into the module-level ``classifier``.

    Loading is best-effort:
    - A missing file is reported with ``print`` and nothing else happens.
    - Any exception raised while opening or unpickling is also reported with
      ``print`` and swallowed.

    In both failure cases ``classifier`` keeps its previous value (``None``
    on a fresh start), which ``predict_digit`` checks before predicting.
    Always returns ``None``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file, so ``model_path`` must only ever point at a trusted file.

    NOTE(review): the leading characters of the messages below look like
    mis-encoded emoji; confirm against the original source before changing
    them.
    """
    global classifier
    try:
        # Guard clause: report and bail out early when there is no file.
        if not os.path.exists(model_path):
            print(f"β Model file '{model_path}' not found.")
            return
        with open(model_path, 'rb') as model_file:
            classifier = pickle.load(model_file)
            print("β Model loaded successfully!")
    except Exception as exc:
        print(f"β Error loading model: {exc!s}")


# Load the model once, at import time, so the interface can use it immediately.
load_model()
# Class index (0-9) produced by the classifier -> digit label shown to users.
ResultMap = {digit: str(digit) for digit in range(10)}
def _preprocess_image(image):
    """Turn an input image into a (1, 64, 64, 3) float32 batch scaled to [0, 1].

    ``image`` may be a PIL image or anything ``Image.fromarray`` accepts
    (e.g. a uint8 numpy array).
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3-channel RGB at 64x64 so every input has the same shape.
    resized = image.convert('RGB').resize((64, 64), Image.Resampling.LANCZOS)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Leading batch axis: one image per prediction call.
    return np.expand_dims(pixels, axis=0)


def _decode_prediction(result):
    """Reduce a classifier ``predict`` output for ONE image to ``(class_index, confidence)``.

    The output is normalised with ``np.asarray`` first, so lists work as well
    as arrays.

    * ndim >= 2 (per-sample scores, e.g. class probabilities): the first
      sample's scores are flattened. The arg-max is the class and the max
      score is the confidence. Only the first sample is used, because exactly
      one image is submitted.
    * ndim 0 or 1 (predicted labels): the first label is converted with
      ``int()``. The confidence is reported as 1.0 because no score is
      available.

    Raises whatever numpy/``int()`` raise for empty or non-numeric output;
    ``predict_digit`` catches these.
    """
    scores = np.asarray(result)
    if scores.ndim > 1:
        # Previously np.argmax(result) without an axis flattened a multi-row
        # output, and ndim > 2 produced an array "class" that could not be
        # looked up in ResultMap. Restrict to the first sample instead.
        first = scores.reshape(scores.shape[0], -1)[0]
        return int(np.argmax(first)), float(np.max(first))
    label = scores.item() if scores.ndim == 0 else scores[0]
    return int(label), 1.0


def predict_digit(image):
    """
    Predict digit from uploaded image.

    Args:
        image: The uploaded image, or ``None``. The Gradio input is declared
            with ``type="pil"``; any array accepted by ``Image.fromarray`` also
            works.

    Returns:
        A ``(message, confidence)`` tuple.
        - On success: ``("Predicted Digit: <d>", confidence)``, where ``<d>``
          is looked up in ``ResultMap`` (``"Unknown"`` for indices outside it).
        - If the model is missing, no image is given, or any step raises: a
          status/error message and ``0.0``. Errors are reported through the
          return value, never raised.

    NOTE(review): the leading characters of the messages below look like
    mis-encoded emoji. Other code in this module inspects these messages, so
    restore them only after confirming the originals.
    """
    if classifier is None:
        return "β Model not loaded", 0.0
    try:
        if image is None:
            return "β No image provided", 0.0
        image_array = _preprocess_image(image)
        print(f"πΈ Image processed: {image_array.shape}")
        result = classifier.predict(image_array)
        predicted_class, confidence = _decode_prediction(result)
        predicted_digit = ResultMap.get(predicted_class, "Unknown")
        print(f'π― Prediction: {predicted_digit} (Confidence: {confidence:.2%})')
        return f"Predicted Digit: {predicted_digit}", confidence
    except Exception as e:
        # Report any preprocessing/model failure through the returned message
        # rather than raising out of the UI callback.
        print(f"β Error in prediction: {str(e)}")
        return f"β Error: {str(e)}", 0.0
# Custom CSS for the app:
# - .gradio-container: cap the width at 700px with auto margins.
# - .header: style the title block.
# - .result-box: style the prediction textbox (attached via elem_classes).
_APP_CSS = """
.gradio-container {
max-width: 700px !important;
margin: auto !important;
}
.header {
text-align: center;
margin-bottom: 2rem;
}
.result-box {
font-size: 1.5rem;
font-weight: bold;
text-align: center;
padding: 1rem;
border-radius: 10px;
}
"""

# Title block rendered above the input/output row.
# NOTE(review): the leading characters in the UI strings look like
# mis-encoded emoji; confirm against the original before changing them.
_HEADER_HTML = """
<div class="header">
<h1>π’ Digit Recognition API</h1>
<p>Upload an image of a handwritten digit (0-9) to get AI-powered recognition</p>
</div>
"""

# Usage tips rendered below the input/output row.
_TIPS_HTML = """
<div style="text-align: center; margin-top: 2rem; color: #666;">
<p>π <strong>Tips for best results:</strong></p>
<p>β’ Use clear images with dark digits on light backgrounds</p>
<p>β’ Ensure the digit fills most of the image</p>
<p>β’ Supported formats: PNG, JPG, GIF</p>
</div>
"""


def create_gradio_interface():
    """Build and return (without launching) the digit-recognition Blocks app.

    Layout, top to bottom:
    1. A header.
    2. A row with two columns:
       - left: the image upload and a "Predict" button;
       - right: the read-only prediction text and confidence number.
    3. A tips footer.

    ``predict_digit`` runs both when the button is clicked and whenever the
    image input changes.
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="π’ Digit Recognition API",
        css=_APP_CSS,
    ) as interface:
        gr.HTML(_HEADER_HTML)

        with gr.Row():
            # Left column: image input and the trigger button.
            with gr.Column():
                image_input = gr.Image(label="πΈ Upload Digit Image", type="pil", height=300)
                predict_btn = gr.Button("π Predict Digit", variant="primary", size="lg")
            # Right column: read-only outputs.
            with gr.Column():
                prediction_output = gr.Textbox(
                    label="π― Prediction Result",
                    interactive=False,
                    elem_classes=["result-box"],
                )
                confidence_output = gr.Number(label="π Confidence Score", interactive=False)

        # Register the same handler on the button click first, then on image
        # change. Each registration gets its own fresh input/output lists.
        for register_event in (predict_btn.click, image_input.change):
            register_event(
                fn=predict_digit,
                inputs=[image_input],
                outputs=[prediction_output, confidence_output],
            )

        gr.HTML(_TIPS_HTML)

    return interface
# API endpoint for external calls
def predict_api(image):
    """Run ``predict_digit`` on ``image`` and return a JSON-serialisable dict.

    Returns:
        ``{"digit": str, "confidence": float, "status": "success" | "error"}``.
        - On success, ``digit`` is the label that follows ``"Predicted Digit: "``
          in ``predict_digit``'s message.
        - On error, ``digit`` is the full message ``predict_digit`` returned,
          so callers can see what went wrong.

    NOTE(review): nothing visible in this module registers this function with
    the Gradio interface. Confirm whether it was meant to be exposed as an
    endpoint.
    """
    prediction, confidence = predict_digit(image)
    # Recognise success by predict_digit's fixed success prefix. Searching for
    # an emoji character (which looks mis-encoded here) labelled any error
    # lacking it as "success". Splitting on ": " also mangled error messages
    # that themselves contain ": ".
    success_prefix = "Predicted Digit: "
    succeeded = prediction.startswith(success_prefix)
    return {
        "digit": prediction[len(success_prefix):] if succeeded else prediction,
        "confidence": float(confidence),
        "status": "success" if succeeded else "error",
    }
if __name__ == "__main__":
    # Serve the app with these launch settings:
    # - bind every network interface ("0.0.0.0") on port 7860;
    # - no public share link;
    # - API enabled (show_api=True).
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
    }
    interface = create_gradio_interface()
    interface.launch(**launch_options)