dgomes03 commited on
Commit
6d53b71
·
verified ·
1 Parent(s): 8fab2b3

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +35 -0
  2. model.py +20 -0
  3. pytorch_model.bin +3 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ from model import CNN
8
+
9
+ # Load model
10
+ model = CNN()
11
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
12
+ model.eval()
13
+
14
+ # Inference function
15
+ def predict_digit(image):
16
+ image = image.convert("L").resize((28, 28)) # Convert to grayscale
17
+ image = np.array(image) / 255.0 # Normalize
18
+ image = torch.tensor(image).unsqueeze(0).unsqueeze(0).float() # (1, 1, 28, 28)
19
+ with torch.no_grad():
20
+ logits = model(image)
21
+ probs = F.softmax(logits, dim=1).numpy().flatten()
22
+ predicted = np.argmax(probs)
23
+ return {str(i): float(probs[i]) for i in range(10)}
24
+
25
+ # Gradio UI
26
+ interface = gr.Interface(
27
+ fn=predict_digit,
28
+ inputs=gr.Image(type="pil", shape=(280, 280), tool="editor"),
29
+ outputs=gr.Label(num_top_classes=3),
30
+ title="Handwritten Digit Classifier",
31
+ description="Draw a digit or upload a digit image."
32
+ )
33
+
34
+ if __name__ == "__main__":
35
+ interface.launch()
model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class CNN(nn.Module):
2
+ def __init__(self):
3
+ super(CNN, self).__init__()
4
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
5
+ self.pool1 = nn.MaxPool2d(2)
6
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
7
+ self.pool2 = nn.MaxPool2d(2)
8
+ self.flatten = nn.Flatten()
9
+ self.fc1 = nn.Linear(64 * 5 * 5, 64)
10
+ self.fc2 = nn.Linear(64, 10)
11
+
12
+ def forward(self, x):
13
+ x = F.relu(self.conv1(x))
14
+ x = self.pool1(x)
15
+ x = F.relu(self.conv2(x))
16
+ x = self.pool2(x)
17
+ x = self.flatten(x)
18
+ x = F.relu(self.fc1(x))
19
+ x = self.fc2(x)
20
+ return x
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7606aba6abcf3685003ebbc6d8dcd6a40189edd71a393a4f46df56d00fb6146e
3
+ size 491361
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ numpy
4
+ pillow