Faham committed
Commit · 4b35e49 · 0 Parent(s)
CREATE: initialized repo
Files changed:

- .gitignore +14 -0
- .python-version +1 -0
- .streamlit/config.toml +21 -0
- README.md +271 -0
- app.py +1220 -0
- models/audio_sentiment_analysis.ipynb +0 -0
- models/vision_sentiment_analysis.ipynb +0 -0
- pyproject.toml +7 -0
- requirements.txt +12 -0
- run_app.py +65 -0
- test_audio_model.py +173 -0
- test_vision_model.py +136 -0
- uv.lock +8 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# model files
*.pth
models/*.pth
.python-version
ADDED
@@ -0,0 +1 @@
3.9
.streamlit/config.toml
ADDED
@@ -0,0 +1,21 @@
[global]
developmentMode = false

[server]
headless = false
port = 8501
enableCORS = false
enableXsrfProtection = false

[browser]
gatherUsageStats = false

[theme]
primaryColor = "#1f77b4"
backgroundColor = "#ffffff"
secondaryBackgroundColor = "#f0f2f6"
textColor = "#262730"
font = "sans serif"

[client]
showErrorDetails = true
README.md
ADDED
@@ -0,0 +1,271 @@
# Sentiment Analysis Testing Ground

A multi-page Streamlit application for testing three independent sentiment analysis models: text, audio, and vision.

## Features

- **Multi-Page Interface**: Clean navigation with a dedicated page for each model
- **Text Sentiment Analysis**: ✅ **READY TO USE** - TextBlob NLP model integrated
- **Audio Sentiment Analysis**: ✅ **READY TO USE** - Fine-tuned Wav2Vec2 model integrated
  - **File Upload**: Support for WAV, MP3, M4A, FLAC files
  - **Audio Recording**: Direct microphone recording (max 5 seconds)
  - **Smart Preprocessing**: Automatic 16 kHz resampling, 5 s max duration (CREMA-D + RAVDESS format)
- **Vision Sentiment Analysis**: ✅ **READY TO USE** - Fine-tuned ResNet-50 model integrated
  - **File Upload**: Support for PNG, JPG, JPEG, BMP, TIFF files
  - **Camera Capture**: Take photos directly with your camera
  - **Smart Preprocessing**: Automatic face detection, tight face crop (0% padding), grayscale conversion, 224x224 resize
- **Fused Model**: Combine predictions from all three models
- **Modern UI**: Responsive interface with custom styling
- **File Support**: Multiple audio and image formats

## Requirements

- Python 3.9 or higher
- Streamlit 1.28.0 or higher
- PyTorch 1.13.0 or higher
- Additional dependencies listed in `requirements.txt`

## Installation

1. **Clone the repository**:

   ```bash
   git clone <your-repo-url>
   cd sentiment-fused
   ```

2. **Create a virtual environment** (recommended):

   ```bash
   python -m venv venv

   # On Windows
   venv\Scripts\activate

   # On macOS/Linux
   source venv/bin/activate
   ```

3. **Install dependencies**:

   ```bash
   pip install -r requirements.txt
   ```

## Usage

1. **Start the Streamlit application**:

   ```bash
   streamlit run app.py
   ```

2. **Open your browser** and navigate to the URL shown in the terminal (usually `http://localhost:8501`)

3. **Navigate between pages** using the sidebar:
   - **Home**: Overview and welcome page
   - **Text Sentiment**: ✅ Ready to use - analyze text with TextBlob
   - **Audio Sentiment**: ✅ Ready to use - analyze audio with Wav2Vec2; upload audio files or record directly with the microphone using `st.audio_input`
   - **Vision Sentiment**: ✅ Ready to use - analyze images with ResNet-50; upload image files or take photos with the camera
   - **Fused Model**: Combine all three models

## Testing the Models

Before running the full app, you can check that the models load correctly:

### Vision Model Test

```bash
python test_vision_model.py
```

### Audio Model Test

```bash
python test_audio_model.py
```

These will verify that:

- The model files exist
- PyTorch can load the architectures
- The trained weights can be loaded
- Inference runs without errors
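As a rough sketch of this kind of check (not the actual test scripts), loading a checkpoint and inspecting the classifier head looks like the following; the checkpoint key names mirror the ones `app.py` checks when it loads the models:

```python
# Minimal sketch: load each checkpoint and print the number of output classes.
import torch

vision_ckpt = torch.load("models/resnet50_model.pth", map_location="cpu")
if "fc.weight" in vision_ckpt:
    print("vision model classes:", vision_ckpt["fc.weight"].shape[0])

audio_ckpt = torch.load("models/wav2vec2_model.pth", map_location="cpu")
if "classifier.weight" in audio_ckpt:
    print("audio model classes:", audio_ckpt["classifier.weight"].shape[0])
```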
### Troubleshooting Model Issues

If you encounter tensor size mismatch errors, run the diagnostic scripts:

```bash
python check_model.py       # For vision model
python test_audio_model.py  # For audio model
```

These will examine your model files and identify:

- The actual number of output classes
- Whether the architectures match the expected models
- Any compatibility issues

**Common Issues:**

- **Tensor size mismatch**: Models might have been trained with a different number of classes
- **Architecture mismatch**: Models might not match the expected architectures
- **Weight loading errors**: Corrupted or incompatible model files
- **Library dependencies**: Missing transformers, librosa, or other required libraries

## Project Structure

```
sentiment-fused/
├── app.py                  # Main Streamlit application
├── requirements.txt        # Python dependencies
├── README.md               # This file
├── test_vision_model.py    # Vision model test script
├── test_audio_model.py     # Audio model test script
├── main.py                 # Original main file
├── pyproject.toml          # Project configuration
└── models/                 # Model files and notebooks
    ├── audio_sentiment_analysis.ipynb
    ├── vision_sentiment_analysis.ipynb
    ├── wav2vec2_model.pth  # ✅ Fine-tuned Wav2Vec2 model (READY)
    └── resnet50_model.pth  # ✅ Fine-tuned ResNet-50 model (READY)
```

## Model Integration Status

### ✅ Text Sentiment Model - **READY TO USE**

- **Model**: TextBlob (Natural Language Processing)
- **Features**: Sentiment classification (Positive/Negative/Neutral) with confidence scores
- **Input**: Any text input
- **Analysis**: Real-time NLP sentiment analysis (see the sketch below)
- **Status**: Fully integrated and tested
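The text model's decision rule is simple enough to show inline; the snippet below is condensed from the TextBlob logic in `app.py` (polarity thresholds of ±0.1):

```python
# Condensed from app.py: map TextBlob polarity to the three sentiment labels.
from textblob import TextBlob

def text_sentiment(text: str):
    polarity = TextBlob(text).sentiment.polarity  # -1.0 (negative) .. 1.0 (positive)
    if polarity > 0.1:
        return "Positive", min(0.95, 0.6 + abs(polarity) * 0.3)
    if polarity < -0.1:
        return "Negative", min(0.95, 0.6 + abs(polarity) * 0.3)
    return "Neutral", 0.7 - abs(polarity) * 0.2

print(text_sentiment("I love this app"))
```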
### ✅ Vision Sentiment Model - **READY TO USE**

- **Model**: ResNet-50 fine-tuned on the FER2013 dataset
- **Training Dataset**:
  - **FER2013**: Facial Expression Recognition 2013 dataset
  - **Classes**: 7 emotions mapped to 3 sentiments (Negative, Neutral, Positive)
  - **Architecture**: ResNet-50 with ImageNet weights, fine-tuned for sentiment
- **Classes**: 3 sentiment classes (Negative, Neutral, Positive)
- **Input**: Images (PNG, JPG, JPEG, BMP, TIFF)
- **Preprocessing** (see the transform sketch below):
  - **Face Detection**: Automatic face detection using OpenCV
  - **Grayscale Conversion**: Convert to grayscale and replicate to 3 channels
  - **Face Cropping**: Crop to the face region with 0% padding (tightest crop)
  - **Resize**: Scale to 224x224 pixels (FER2013 format)
  - **Transforms**: Resize(224) → CenterCrop(224) → ToTensor → ImageNet normalization
  - **Format**: 224x224 RGB with ImageNet mean/std normalization
- **Status**: Fully integrated and tested
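For reference, the transform stack listed above corresponds to the following torchvision pipeline, condensed from `get_vision_transforms()` in `app.py`; it is applied after the face crop and grayscale step:

```python
# Condensed from app.py: inference-time transforms matching the FER2013 training setup.
from torchvision import transforms

vision_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])
```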
### ✅ Audio Sentiment Model - **READY TO USE**

- **Model**: Wav2Vec2-base fine-tuned on the RAVDESS + CREMA-D datasets
- **Training Datasets**:
  - **RAVDESS**: Ryerson Audio-Visual Database of Emotional Speech and Song
  - **CREMA-D**: Crowd-sourced Emotional Multimodal Actors Dataset
- **Classes**: 3 sentiment classes (Negative, Neutral, Positive)
- **Input**:
  - **File Upload**: Audio files (WAV, MP3, M4A, FLAC)
  - **Direct Recording**: Microphone input using `st.audio_input`
- **Preprocessing** (see the sketch below):
  - **Sampling Rate**: 16 kHz (matching CREMA-D + RAVDESS training)
  - **Duration**: Max 5 seconds (matching training max_duration_s=5.0)
  - **Feature Extraction**: AutoFeatureExtractor with truncation and padding
  - **Format**: Automatic resampling, max_length=int(5.0 * 16000)
- **Status**: Fully integrated and tested
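Condensed from `predict_audio_sentiment()` in `app.py`, the preprocessing above amounts to the following; `example.wav` is only a placeholder path:

```python
# Condensed from app.py: resample to 16 kHz and pad/truncate to 5 s
# before feeding the fine-tuned Wav2Vec2 classifier.
import librosa
from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
audio, sr = librosa.load("example.wav", sr=None)  # placeholder input file
if sr != 16000:
    audio = librosa.resample(y=audio, orig_sr=sr, target_sr=16000)
inputs = extractor(
    audio,
    sampling_rate=16000,
    max_length=int(5.0 * 16000),  # 5 seconds at 16 kHz
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
```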
### Fused Model - **FULLY READY**

The fused model combines all three integrated models: text (TextBlob), audio (Wav2Vec2), and vision (ResNet-50).
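The fusion itself is currently a simple ensemble; the sketch below condenses the majority-vote logic from `predict_fused_sentiment()` in `app.py`:

```python
# Condensed from app.py: majority voting over the per-model labels,
# with confidences averaged across the models that produced a result.
def fuse(results):
    """results: list of (sentiment, confidence) tuples, one per model."""
    counts = {}
    for sentiment, _ in results:
        counts[sentiment] = counts.get(sentiment, 0) + 1
    fused_label = max(counts, key=counts.get)
    avg_conf = sum(conf for _, conf in results) / len(results)
    return fused_label, avg_conf

print(fuse([("Positive", 0.9), ("Positive", 0.7), ("Neutral", 0.6)]))
```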
## Supported File Formats

### Audio Files

- WAV (.wav)
- MP3 (.mp3)
- M4A (.m4a)
- FLAC (.flac)

### Image Files

- PNG (.png)
- JPEG (.jpg, .jpeg)
- BMP (.bmp)
- TIFF (.tiff)

## Customization

The application includes custom CSS styling that can be modified in `app.py`. Key styling classes:

- `.main-header`: Main page headers
- `.model-card`: Information cards
- `.result-box`: Result display boxes
- `.upload-section`: File upload areas

## Troubleshooting

### Common Issues

1. **Port already in use**: Change the port with `streamlit run app.py --server.port 8502`

2. **Vision model loading errors**:

   - Ensure `models/resnet50_model.pth` exists
   - Run `python test_vision_model.py` to diagnose issues
   - Check the PyTorch installation: `python -c "import torch; print(torch.__version__)"`

3. **Memory issues**: Large audio/image files may require more memory; consider enforcing file size limits

4. **OpenCV issues**: If face detection fails, ensure `opencv-python` is installed:

   ```bash
   pip install opencv-python
   ```

5. **Dependency conflicts**: Use a virtual environment to avoid package conflicts

### Performance Tips

- Use appropriately sized audio and image files
- Consider caching model predictions
- Use GPU acceleration for the PyTorch models if available
- The vision model automatically uses the GPU if one is available

## Contributing

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Test thoroughly
5. Submit a pull request

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Acknowledgments

- Streamlit team for the web framework
- PyTorch community for deep learning tools
- Hugging Face for transformer models
- All contributors to the open-source libraries used

## Support

For questions or issues:

1. Check the troubleshooting section above
2. Run `python test_vision_model.py` for vision model issues
3. Review the model integration examples
4. Open an issue on the repository
5. Contact the development team

---

**Happy Sentiment Analysis!**

**Note**: All **THREE MODELS** are now fully integrated and ready to use!
app.py
ADDED
@@ -0,0 +1,1220 @@
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import io
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tempfile
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torchvision import transforms, models
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
# Page configuration
|
| 14 |
+
st.set_page_config(
|
| 15 |
+
page_title="Sentiment Analysis Testing Ground",
|
| 16 |
+
page_icon="π§ ",
|
| 17 |
+
layout="wide",
|
| 18 |
+
initial_sidebar_state="expanded",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Custom CSS for better styling
|
| 22 |
+
st.markdown(
|
| 23 |
+
"""
|
| 24 |
+
<style>
|
| 25 |
+
.main-header {
|
| 26 |
+
font-size: 2.5rem;
|
| 27 |
+
font-weight: bold;
|
| 28 |
+
color: #1f77b4;
|
| 29 |
+
text-align: center;
|
| 30 |
+
margin-bottom: 2rem;
|
| 31 |
+
}
|
| 32 |
+
.model-card {
|
| 33 |
+
background-color: #f0f2f6;
|
| 34 |
+
padding: 1.5rem;
|
| 35 |
+
border-radius: 10px;
|
| 36 |
+
margin: 1rem 0;
|
| 37 |
+
border-left: 4px solid #1f77b4;
|
| 38 |
+
}
|
| 39 |
+
.result-box {
|
| 40 |
+
background-color: #e8f4fd;
|
| 41 |
+
padding: 1rem;
|
| 42 |
+
border-radius: 8px;
|
| 43 |
+
border: 1px solid #1f77b4;
|
| 44 |
+
margin: 1rem 0;
|
| 45 |
+
}
|
| 46 |
+
.upload-section {
|
| 47 |
+
background-color: #f8f9fa;
|
| 48 |
+
padding: 1.5rem;
|
| 49 |
+
border-radius: 10px;
|
| 50 |
+
border: 2px dashed #dee2e6;
|
| 51 |
+
text-align: center;
|
| 52 |
+
margin: 1rem 0;
|
| 53 |
+
}
|
| 54 |
+
</style>
|
| 55 |
+
""",
|
| 56 |
+
unsafe_allow_html=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Global variables for models
|
| 61 |
+
@st.cache_resource
|
| 62 |
+
def load_vision_model():
|
| 63 |
+
"""Load the pre-trained ResNet-50 vision sentiment model"""
|
| 64 |
+
try:
|
| 65 |
+
# Check if model file exists
|
| 66 |
+
model_path = "models/resnet50_model.pth"
|
| 67 |
+
if not os.path.exists(model_path):
|
| 68 |
+
st.error(f"β Vision model file not found at: {model_path}")
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
# Load the model weights first to check the architecture
|
| 72 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 73 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 74 |
+
|
| 75 |
+
# Check the number of classes from the checkpoint
|
| 76 |
+
if "fc.weight" in checkpoint:
|
| 77 |
+
num_classes = checkpoint["fc.weight"].shape[0]
|
| 78 |
+
st.info(f"π Model checkpoint has {num_classes} output classes")
|
| 79 |
+
else:
|
| 80 |
+
# Fallback: try to infer from the last layer
|
| 81 |
+
num_classes = 3 # Default assumption
|
| 82 |
+
st.warning(
|
| 83 |
+
"β οΈ Could not determine number of classes from checkpoint, assuming 3"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Initialize ResNet-50 model with the correct number of classes
|
| 87 |
+
# Note: Your model was trained with RGB images, so we keep 3 channels
|
| 88 |
+
model = models.resnet50(weights=None) # Don't load ImageNet weights
|
| 89 |
+
|
| 90 |
+
num_ftrs = model.fc.in_features
|
| 91 |
+
model.fc = nn.Linear(num_ftrs, num_classes) # Use actual number of classes
|
| 92 |
+
|
| 93 |
+
# Load trained weights
|
| 94 |
+
model.load_state_dict(checkpoint)
|
| 95 |
+
model.to(device)
|
| 96 |
+
model.eval()
|
| 97 |
+
|
| 98 |
+
st.success(f"β
Vision model loaded successfully with {num_classes} classes!")
|
| 99 |
+
return model, device, num_classes
|
| 100 |
+
except Exception as e:
|
| 101 |
+
st.error(f"β Error loading vision model: {str(e)}")
|
| 102 |
+
return None, None, None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@st.cache_data
|
| 106 |
+
def get_vision_transforms():
|
| 107 |
+
"""Get the image transforms used during FER2013 training"""
|
| 108 |
+
return transforms.Compose(
|
| 109 |
+
[
|
| 110 |
+
transforms.Resize(224), # Match training: transforms.Resize(224)
|
| 111 |
+
transforms.CenterCrop(224), # Match training: transforms.CenterCrop(224)
|
| 112 |
+
transforms.ToTensor(),
|
| 113 |
+
transforms.Normalize(
|
| 114 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
| 115 |
+
), # ImageNet normalization
|
| 116 |
+
]
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def detect_and_preprocess_face(image, crop_tightness=0.05):
|
| 121 |
+
"""
|
| 122 |
+
Detect face in image, crop to face region, convert to grayscale, and resize to 224x224
|
| 123 |
+
to match FER2013 dataset format (grayscale converted to 3-channel RGB)
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
image: Input image (PIL Image or numpy array)
|
| 127 |
+
crop_tightness: Padding around face (0.0 = no padding, 0.3 = 30% padding)
|
| 128 |
+
"""
|
| 129 |
+
try:
|
| 130 |
+
import cv2
|
| 131 |
+
import numpy as np
|
| 132 |
+
|
| 133 |
+
# Convert PIL image to OpenCV format
|
| 134 |
+
if isinstance(image, Image.Image):
|
| 135 |
+
# Convert PIL to numpy array
|
| 136 |
+
img_array = np.array(image)
|
| 137 |
+
# Convert RGB to BGR for OpenCV
|
| 138 |
+
if len(img_array.shape) == 3:
|
| 139 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
| 140 |
+
else:
|
| 141 |
+
img_array = image
|
| 142 |
+
|
| 143 |
+
# Load face detection cascade
|
| 144 |
+
face_cascade = cv2.CascadeClassifier(
|
| 145 |
+
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Convert to grayscale for face detection (detection works better on grayscale)
|
| 149 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
|
| 150 |
+
|
| 151 |
+
# Detect faces
|
| 152 |
+
faces = face_cascade.detectMultiScale(
|
| 153 |
+
gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if len(faces) == 0:
|
| 157 |
+
st.warning("β οΈ No face detected in the image. Using center crop instead.")
|
| 158 |
+
# Fallback: center crop and resize
|
| 159 |
+
if isinstance(image, Image.Image):
|
| 160 |
+
# Convert to RGB first
|
| 161 |
+
rgb_pil = image.convert("RGB")
|
| 162 |
+
# Center crop to square
|
| 163 |
+
width, height = rgb_pil.size
|
| 164 |
+
size = min(width, height)
|
| 165 |
+
left = (width - size) // 2
|
| 166 |
+
top = (height - size) // 2
|
| 167 |
+
right = left + size
|
| 168 |
+
bottom = top + size
|
| 169 |
+
cropped = rgb_pil.crop((left, top, right, bottom))
|
| 170 |
+
# Resize to 224x224 (matching FER2013 training: transforms.Resize(224))
|
| 171 |
+
resized = cropped.resize((224, 224), Image.Resampling.LANCZOS)
|
| 172 |
+
|
| 173 |
+
# Convert to grayscale and then to 3-channel RGB
|
| 174 |
+
gray_pil = resized.convert("L")
|
| 175 |
+
# Convert back to RGB (this replicates grayscale values to all 3 channels)
|
| 176 |
+
gray_rgb_pil = gray_pil.convert("RGB")
|
| 177 |
+
return gray_rgb_pil
|
| 178 |
+
else:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
# Get the largest face (assuming it's the main subject)
|
| 182 |
+
x, y, w, h = max(faces, key=lambda rect: rect[2] * rect[3])
|
| 183 |
+
|
| 184 |
+
# Add padding around the face based on user preference
|
| 185 |
+
padding_x = int(w * crop_tightness)
|
| 186 |
+
padding_y = int(h * crop_tightness)
|
| 187 |
+
|
| 188 |
+
# Ensure we don't go out of bounds
|
| 189 |
+
x1 = max(0, x - padding_x)
|
| 190 |
+
y1 = max(0, y - padding_y)
|
| 191 |
+
x2 = min(img_array.shape[1], x + w + padding_x)
|
| 192 |
+
y2 = min(img_array.shape[0], y + h + padding_y)
|
| 193 |
+
|
| 194 |
+
# Crop to face region
|
| 195 |
+
face_crop = img_array[y1:y2, x1:x2]
|
| 196 |
+
|
| 197 |
+
# Convert BGR to RGB first
|
| 198 |
+
face_crop_rgb = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
|
| 199 |
+
|
| 200 |
+
# Convert to grayscale
|
| 201 |
+
face_gray = cv2.cvtColor(face_crop_rgb, cv2.COLOR_RGB2GRAY)
|
| 202 |
+
|
| 203 |
+
# Resize to 224x224 (matching FER2013 training: transforms.Resize(224))
|
| 204 |
+
face_resized = cv2.resize(face_gray, (224, 224), interpolation=cv2.INTER_AREA)
|
| 205 |
+
|
| 206 |
+
# Convert grayscale to 3-channel RGB (replicate grayscale values)
|
| 207 |
+
face_rgb_3channel = cv2.cvtColor(face_resized, cv2.COLOR_GRAY2RGB)
|
| 208 |
+
|
| 209 |
+
# Convert back to PIL Image
|
| 210 |
+
face_pil = Image.fromarray(face_rgb_3channel)
|
| 211 |
+
|
| 212 |
+
return face_pil
|
| 213 |
+
|
| 214 |
+
except ImportError:
|
| 215 |
+
st.error(
|
| 216 |
+
"β OpenCV not installed. Please install it with: pip install opencv-python"
|
| 217 |
+
)
|
| 218 |
+
st.info("Falling back to basic preprocessing...")
|
| 219 |
+
# Fallback: basic grayscale conversion and resize
|
| 220 |
+
if isinstance(image, Image.Image):
|
| 221 |
+
rgb_pil = image.convert("RGB")
|
| 222 |
+
resized = rgb_pil.resize((48, 48), Image.Resampling.LANCZOS)
|
| 223 |
+
# Convert to grayscale and then to 3-channel RGB
|
| 224 |
+
gray_pil = resized.convert("L")
|
| 225 |
+
gray_rgb_pil = gray_pil.convert("RGB")
|
| 226 |
+
return gray_rgb_pil
|
| 227 |
+
return None
|
| 228 |
+
except Exception as e:
|
| 229 |
+
st.error(f"β Error in face detection: {str(e)}")
|
| 230 |
+
st.info("Falling back to basic preprocessing...")
|
| 231 |
+
# Fallback: basic grayscale conversion and resize
|
| 232 |
+
if isinstance(image, Image.Image):
|
| 233 |
+
rgb_pil = image.convert("RGB")
|
| 234 |
+
resized = rgb_pil.resize((48, 48), Image.Resampling.LANCZOS)
|
| 235 |
+
# Convert to grayscale and then to 3-channel RGB
|
| 236 |
+
gray_pil = resized.convert("L")
|
| 237 |
+
gray_rgb_pil = gray_pil.convert("RGB")
|
| 238 |
+
return gray_rgb_pil
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_sentiment_mapping(num_classes):
|
| 243 |
+
"""Get the sentiment mapping based on number of classes"""
|
| 244 |
+
if num_classes == 3:
|
| 245 |
+
return {0: "Negative", 1: "Neutral", 2: "Positive"}
|
| 246 |
+
elif num_classes == 4:
|
| 247 |
+
# Common 4-class emotion mapping
|
| 248 |
+
return {0: "Angry", 1: "Sad", 2: "Happy", 3: "Neutral"}
|
| 249 |
+
elif num_classes == 7:
|
| 250 |
+
# FER2013 7-class emotion mapping
|
| 251 |
+
return {
|
| 252 |
+
0: "Angry",
|
| 253 |
+
1: "Disgust",
|
| 254 |
+
2: "Fear",
|
| 255 |
+
3: "Happy",
|
| 256 |
+
4: "Sad",
|
| 257 |
+
5: "Surprise",
|
| 258 |
+
6: "Neutral",
|
| 259 |
+
}
|
| 260 |
+
else:
|
| 261 |
+
# Generic mapping for unknown number of classes
|
| 262 |
+
return {i: f"Class_{i}" for i in range(num_classes)}
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Placeholder functions for model predictions
|
| 266 |
+
def predict_text_sentiment(text):
|
| 267 |
+
"""
|
| 268 |
+
Analyze text sentiment using TextBlob
|
| 269 |
+
"""
|
| 270 |
+
if not text or text.strip() == "":
|
| 271 |
+
return "No text provided", 0.0
|
| 272 |
+
|
| 273 |
+
try:
|
| 274 |
+
from textblob import TextBlob
|
| 275 |
+
|
| 276 |
+
# Create TextBlob object
|
| 277 |
+
blob = TextBlob(text)
|
| 278 |
+
|
| 279 |
+
# Get polarity (-1 to 1, where -1 is very negative, 1 is very positive)
|
| 280 |
+
polarity = blob.sentiment.polarity
|
| 281 |
+
|
| 282 |
+
# Get subjectivity (0 to 1, where 0 is very objective, 1 is very subjective)
|
| 283 |
+
subjectivity = blob.sentiment.subjectivity
|
| 284 |
+
|
| 285 |
+
# Convert polarity to sentiment categories
|
| 286 |
+
if polarity > 0.1:
|
| 287 |
+
sentiment = "Positive"
|
| 288 |
+
confidence = min(0.95, 0.6 + abs(polarity) * 0.3)
|
| 289 |
+
elif polarity < -0.1:
|
| 290 |
+
sentiment = "Negative"
|
| 291 |
+
confidence = min(0.95, 0.6 + abs(polarity) * 0.3)
|
| 292 |
+
else:
|
| 293 |
+
sentiment = "Neutral"
|
| 294 |
+
confidence = 0.7 - abs(polarity) * 0.2
|
| 295 |
+
|
| 296 |
+
# Round confidence to 2 decimal places
|
| 297 |
+
confidence = round(confidence, 2)
|
| 298 |
+
|
| 299 |
+
return sentiment, confidence
|
| 300 |
+
|
| 301 |
+
except ImportError:
|
| 302 |
+
st.error(
|
| 303 |
+
"β TextBlob not installed. Please install it with: pip install textblob"
|
| 304 |
+
)
|
| 305 |
+
return "TextBlob not available", 0.0
|
| 306 |
+
except Exception as e:
|
| 307 |
+
st.error(f"β Error in text sentiment analysis: {str(e)}")
|
| 308 |
+
return "Error occurred", 0.0
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
@st.cache_resource
|
| 312 |
+
def load_audio_model():
|
| 313 |
+
"""Load the pre-trained Wav2Vec2 audio sentiment model"""
|
| 314 |
+
try:
|
| 315 |
+
# Check if model file exists
|
| 316 |
+
model_path = "models/wav2vec2_model.pth"
|
| 317 |
+
if not os.path.exists(model_path):
|
| 318 |
+
st.error(f"β Audio model file not found at: {model_path}")
|
| 319 |
+
return None, None, None, None
|
| 320 |
+
|
| 321 |
+
# Load the model weights first to check the architecture
|
| 322 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 323 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 324 |
+
|
| 325 |
+
# Check the number of classes from the checkpoint
|
| 326 |
+
if "classifier.weight" in checkpoint:
|
| 327 |
+
num_classes = checkpoint["classifier.weight"].shape[0]
|
| 328 |
+
st.info(f"π Audio model checkpoint has {num_classes} output classes")
|
| 329 |
+
else:
|
| 330 |
+
num_classes = 3 # Default assumption
|
| 331 |
+
st.warning(
|
| 332 |
+
"β οΈ Could not determine number of classes from checkpoint, assuming 3"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Initialize Wav2Vec2 model with the correct number of classes
|
| 336 |
+
from transformers import AutoModelForAudioClassification
|
| 337 |
+
|
| 338 |
+
model = AutoModelForAudioClassification.from_pretrained(
|
| 339 |
+
"facebook/wav2vec2-base", num_labels=num_classes
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Load trained weights
|
| 343 |
+
model.load_state_dict(checkpoint)
|
| 344 |
+
model.to(device)
|
| 345 |
+
model.eval()
|
| 346 |
+
|
| 347 |
+
# Load feature extractor
|
| 348 |
+
from transformers import AutoFeatureExtractor
|
| 349 |
+
|
| 350 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
| 351 |
+
"facebook/wav2vec2-base"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
st.success(f"β
Audio model loaded successfully with {num_classes} classes!")
|
| 355 |
+
return model, device, num_classes, feature_extractor
|
| 356 |
+
except Exception as e:
|
| 357 |
+
st.error(f"β Error loading audio model: {str(e)}")
|
| 358 |
+
return None, None, None, None
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def predict_audio_sentiment(audio_bytes):
|
| 362 |
+
"""
|
| 363 |
+
Analyze audio sentiment using fine-tuned Wav2Vec2 model
|
| 364 |
+
Preprocessing matches CREMA-D + RAVDESS training specifications:
|
| 365 |
+
- Target sampling rate: 16kHz
|
| 366 |
+
- Max duration: 5.0 seconds
|
| 367 |
+
- Feature extraction: AutoFeatureExtractor with max_length, truncation, padding
|
| 368 |
+
"""
|
| 369 |
+
if audio_bytes is None:
|
| 370 |
+
return "No audio provided", 0.0
|
| 371 |
+
|
| 372 |
+
try:
|
| 373 |
+
# Load model if not already loaded
|
| 374 |
+
model, device, num_classes, feature_extractor = load_audio_model()
|
| 375 |
+
if model is None:
|
| 376 |
+
return "Model not loaded", 0.0
|
| 377 |
+
|
| 378 |
+
# Load and preprocess audio
|
| 379 |
+
import librosa
|
| 380 |
+
import io
|
| 381 |
+
import tempfile
|
| 382 |
+
|
| 383 |
+
# Save audio bytes to temporary file
|
| 384 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 385 |
+
tmp_file.write(audio_bytes)
|
| 386 |
+
tmp_file_path = tmp_file.name
|
| 387 |
+
|
| 388 |
+
try:
|
| 389 |
+
# Load audio with librosa
|
| 390 |
+
audio, sr = librosa.load(tmp_file_path, sr=None)
|
| 391 |
+
|
| 392 |
+
# Resample to 16kHz if needed
|
| 393 |
+
if sr != 16000:
|
| 394 |
+
audio = librosa.resample(y=audio, orig_sr=sr, target_sr=16000)
|
| 395 |
+
|
| 396 |
+
# Preprocess with feature extractor (matching CREMA-D + RAVDESS training exactly)
|
| 397 |
+
# From training: max_length=int(max_duration_s * TARGET_SAMPLING_RATE) = 5.0 * 16000
|
| 398 |
+
inputs = feature_extractor(
|
| 399 |
+
audio,
|
| 400 |
+
sampling_rate=16000,
|
| 401 |
+
max_length=int(5.0 * 16000), # 5 seconds max (matching training)
|
| 402 |
+
truncation=True,
|
| 403 |
+
padding="max_length",
|
| 404 |
+
return_tensors="pt",
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
# Move to device
|
| 408 |
+
input_values = inputs.input_values.to(device)
|
| 409 |
+
|
| 410 |
+
# Run inference
|
| 411 |
+
with torch.no_grad():
|
| 412 |
+
outputs = model(input_values)
|
| 413 |
+
probabilities = torch.softmax(outputs.logits, dim=1)
|
| 414 |
+
confidence, predicted = torch.max(probabilities, 1)
|
| 415 |
+
|
| 416 |
+
# Get sentiment mapping based on number of classes
|
| 417 |
+
if num_classes == 3:
|
| 418 |
+
sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
|
| 419 |
+
else:
|
| 420 |
+
# Generic mapping for unknown number of classes
|
| 421 |
+
sentiment_map = {i: f"Class_{i}" for i in range(num_classes)}
|
| 422 |
+
|
| 423 |
+
sentiment = sentiment_map[predicted.item()]
|
| 424 |
+
confidence_score = confidence.item()
|
| 425 |
+
|
| 426 |
+
return sentiment, confidence_score
|
| 427 |
+
|
| 428 |
+
finally:
|
| 429 |
+
# Clean up temporary file
|
| 430 |
+
os.unlink(tmp_file_path)
|
| 431 |
+
|
| 432 |
+
except ImportError as e:
|
| 433 |
+
st.error(f"β Required library not installed: {str(e)}")
|
| 434 |
+
st.info("Please install: pip install librosa transformers")
|
| 435 |
+
return "Library not available", 0.0
|
| 436 |
+
except Exception as e:
|
| 437 |
+
st.error(f"β Error in audio sentiment prediction: {str(e)}")
|
| 438 |
+
return "Error occurred", 0.0
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def predict_vision_sentiment(image, crop_tightness=0.05):
|
| 442 |
+
"""
|
| 443 |
+
Load ResNet-50 and run inference for vision sentiment analysis
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
image: Input image (PIL Image or numpy array)
|
| 447 |
+
crop_tightness: Padding around face (0.0 = no padding, 0.3 = 30% padding)
|
| 448 |
+
"""
|
| 449 |
+
if image is None:
|
| 450 |
+
return "No image provided", 0.0
|
| 451 |
+
|
| 452 |
+
try:
|
| 453 |
+
# Load model if not already loaded
|
| 454 |
+
model, device, num_classes = load_vision_model()
|
| 455 |
+
if model is None:
|
| 456 |
+
return "Model not loaded", 0.0
|
| 457 |
+
|
| 458 |
+
# Preprocess image to match FER2013 format
|
| 459 |
+
st.info(
|
| 460 |
+
"π Detecting face and preprocessing image to match training data format..."
|
| 461 |
+
)
|
| 462 |
+
preprocessed_image = detect_and_preprocess_face(image, crop_tightness=0.0)
|
| 463 |
+
|
| 464 |
+
if preprocessed_image is None:
|
| 465 |
+
return "Image preprocessing failed", 0.0
|
| 466 |
+
|
| 467 |
+
# Show preprocessed image
|
| 468 |
+
st.image(
|
| 469 |
+
preprocessed_image,
|
| 470 |
+
caption="Preprocessed Image (48x48 Grayscale β 3-channel RGB)",
|
| 471 |
+
width=200,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Get transforms
|
| 475 |
+
transform = get_vision_transforms()
|
| 476 |
+
|
| 477 |
+
# Convert preprocessed image to tensor
|
| 478 |
+
image_tensor = transform(preprocessed_image).unsqueeze(0).to(device)
|
| 479 |
+
|
| 480 |
+
# Run inference
|
| 481 |
+
with torch.no_grad():
|
| 482 |
+
outputs = model(image_tensor)
|
| 483 |
+
|
| 484 |
+
# Debug: print output shape
|
| 485 |
+
st.info(f"π Model output shape: {outputs.shape}")
|
| 486 |
+
|
| 487 |
+
probabilities = F.softmax(outputs, dim=1)
|
| 488 |
+
confidence, predicted = torch.max(probabilities, 1)
|
| 489 |
+
|
| 490 |
+
# Get sentiment mapping based on number of classes
|
| 491 |
+
sentiment_map = get_sentiment_mapping(num_classes)
|
| 492 |
+
sentiment = sentiment_map[predicted.item()]
|
| 493 |
+
confidence_score = confidence.item()
|
| 494 |
+
|
| 495 |
+
return sentiment, confidence_score
|
| 496 |
+
|
| 497 |
+
except Exception as e:
|
| 498 |
+
st.error(f"Error in vision sentiment prediction: {str(e)}")
|
| 499 |
+
st.error(
|
| 500 |
+
f"Model output shape mismatch. Expected {num_classes} classes but got different."
|
| 501 |
+
)
|
| 502 |
+
return "Error occurred", 0.0
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def predict_fused_sentiment(text=None, audio_bytes=None, image=None):
|
| 506 |
+
"""
|
| 507 |
+
TODO: Implement ensemble/fusion logic combining all three models
|
| 508 |
+
This is a placeholder function for fused sentiment analysis
|
| 509 |
+
"""
|
| 510 |
+
# Placeholder logic - replace with actual fusion implementation
|
| 511 |
+
results = []
|
| 512 |
+
|
| 513 |
+
if text:
|
| 514 |
+
text_sentiment, text_conf = predict_text_sentiment(text)
|
| 515 |
+
results.append((text_sentiment, text_conf))
|
| 516 |
+
|
| 517 |
+
if audio_bytes:
|
| 518 |
+
audio_sentiment, audio_conf = predict_audio_sentiment(audio_bytes)
|
| 519 |
+
results.append((audio_sentiment, audio_conf))
|
| 520 |
+
|
| 521 |
+
if image:
|
| 522 |
+
vision_sentiment, vision_conf = predict_vision_sentiment(image)
|
| 523 |
+
results.append((vision_sentiment, vision_conf))
|
| 524 |
+
|
| 525 |
+
if not results:
|
| 526 |
+
return "No inputs provided", 0.0
|
| 527 |
+
|
| 528 |
+
# Simple ensemble logic (replace with your fusion strategy)
|
| 529 |
+
sentiment_counts = {}
|
| 530 |
+
total_confidence = 0
|
| 531 |
+
|
| 532 |
+
for sentiment, confidence in results:
|
| 533 |
+
sentiment_counts[sentiment] = sentiment_counts.get(sentiment, 0) + 1
|
| 534 |
+
total_confidence += confidence
|
| 535 |
+
|
| 536 |
+
# Majority voting with confidence averaging
|
| 537 |
+
final_sentiment = max(sentiment_counts, key=sentiment_counts.get)
|
| 538 |
+
avg_confidence = total_confidence / len(results)
|
| 539 |
+
|
| 540 |
+
return final_sentiment, avg_confidence
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# Sidebar navigation
|
| 544 |
+
st.sidebar.title("π§ Sentiment Analysis")
|
| 545 |
+
st.sidebar.markdown("---")
|
| 546 |
+
|
| 547 |
+
# Navigation
|
| 548 |
+
page = st.sidebar.selectbox(
|
| 549 |
+
"Choose a page:",
|
| 550 |
+
[
|
| 551 |
+
"π Home",
|
| 552 |
+
"π Text Sentiment",
|
| 553 |
+
"π΅ Audio Sentiment",
|
| 554 |
+
"πΌοΈ Vision Sentiment",
|
| 555 |
+
"π Fused Model",
|
| 556 |
+
],
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# Home Page
|
| 560 |
+
if page == "π Home":
|
| 561 |
+
st.markdown(
|
| 562 |
+
'<h1 class="main-header">Sentiment Analysis Testing Ground</h1>',
|
| 563 |
+
unsafe_allow_html=True,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
st.markdown(
|
| 567 |
+
"""
|
| 568 |
+
<div class="model-card">
|
| 569 |
+
<h2>Welcome to your Multi-Modal Sentiment Analysis Testing Platform!</h2>
|
| 570 |
+
<p>This application provides a comprehensive testing environment for your three independent sentiment analysis models:</p>
|
| 571 |
+
</div>
|
| 572 |
+
""",
|
| 573 |
+
unsafe_allow_html=True,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
col1, col2, col3 = st.columns(3)
|
| 577 |
+
|
| 578 |
+
with col1:
|
| 579 |
+
st.markdown(
|
| 580 |
+
"""
|
| 581 |
+
<div class="model-card">
|
| 582 |
+
<h3>π Text Sentiment Model</h3>
|
| 583 |
+
<p>β
<strong>READY TO USE</strong> - Analyze sentiment from text input using TextBlob</p>
|
| 584 |
+
<ul>
|
| 585 |
+
<li>Process any text input</li>
|
| 586 |
+
<li>Get sentiment classification (Positive/Negative/Neutral)</li>
|
| 587 |
+
<li>View confidence scores</li>
|
| 588 |
+
<li>Real-time NLP analysis</li>
|
| 589 |
+
</ul>
|
| 590 |
+
</div>
|
| 591 |
+
""",
|
| 592 |
+
unsafe_allow_html=True,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
with col2:
|
| 596 |
+
st.markdown(
|
| 597 |
+
"""
|
| 598 |
+
<div class="model-card">
|
| 599 |
+
<h3>π΅ Audio Sentiment Model</h3>
|
| 600 |
+
<p>β
<strong>READY TO USE</strong> - Analyze sentiment from audio files using fine-tuned Wav2Vec2</p>
|
| 601 |
+
<ul>
|
| 602 |
+
<li>Upload audio files (.wav, .mp3, .m4a, .flac)</li>
|
| 603 |
+
<li>ποΈ Record audio directly with microphone (max 5s)</li>
|
| 604 |
+
<li>π Automatic preprocessing: 16kHz sampling, 5s max duration (CREMA-D + RAVDESS format)</li>
|
| 605 |
+
<li>Listen to uploaded/recorded audio</li>
|
| 606 |
+
<li>Get sentiment predictions</li>
|
| 607 |
+
<li>Real-time audio analysis</li>
|
| 608 |
+
</ul>
|
| 609 |
+
</div>
|
| 610 |
+
""",
|
| 611 |
+
unsafe_allow_html=True,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
with col3:
|
| 615 |
+
st.markdown(
|
| 616 |
+
"""
|
| 617 |
+
<div class="model-card">
|
| 618 |
+
<h3>πΌοΈ Vision Sentiment Model</h3>
|
| 619 |
+
<p>Analyze sentiment from images using fine-tuned ResNet-50</p>
|
| 620 |
+
<ul>
|
| 621 |
+
<li>Upload image files (.png, .jpg, .jpeg, .bmp, .tiff)</li>
|
| 622 |
+
<li>π Automatic face detection & preprocessing</li>
|
| 623 |
+
<li>π― Fixed 0% padding for tightest face crop</li>
|
| 624 |
+
<li>π Convert to 224x224 grayscale β 3-channel RGB (FER2013 format)</li>
|
| 625 |
+
<li>π― Transforms: Resize(224) β CenterCrop(224) β ImageNet Normalization</li>
|
| 626 |
+
<li>Preview original & preprocessed images</li>
|
| 627 |
+
<li>Get sentiment predictions</li>
|
| 628 |
+
</ul>
|
| 629 |
+
</div>
|
| 630 |
+
""",
|
| 631 |
+
unsafe_allow_html=True,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
st.markdown(
|
| 635 |
+
"""
|
| 636 |
+
<div class="model-card">
|
| 637 |
+
<h3>π Fused Model</h3>
|
| 638 |
+
<p>Combine predictions from all three models for enhanced accuracy</p>
|
| 639 |
+
<ul>
|
| 640 |
+
<li>Multi-modal input processing</li>
|
| 641 |
+
<li>Ensemble prediction strategies</li>
|
| 642 |
+
<li>Comprehensive sentiment analysis</li>
|
| 643 |
+
</ul>
|
| 644 |
+
</div>
|
| 645 |
+
""",
|
| 646 |
+
unsafe_allow_html=True,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
st.markdown("---")
|
| 650 |
+
st.markdown(
|
| 651 |
+
"""
|
| 652 |
+
<div style="text-align: center; color: #666;">
|
| 653 |
+
<p><strong>Note:</strong> This application now has <strong>ALL THREE MODELS</strong> fully integrated and ready to use! π</p>
|
| 654 |
+
<p><strong>TextBlob</strong> (Text) + <strong>Wav2Vec2</strong> (Audio) + <strong>ResNet-50</strong> (Vision)</p>
|
| 655 |
+
</div>
|
| 656 |
+
""",
|
| 657 |
+
unsafe_allow_html=True,
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
# Text Sentiment Page
|
| 661 |
+
elif page == "π Text Sentiment":
|
| 662 |
+
st.title("π Text Sentiment Analysis")
|
| 663 |
+
st.markdown("Analyze the sentiment of your text using our TextBlob-based model.")
|
| 664 |
+
|
| 665 |
+
# Text input
|
| 666 |
+
text_input = st.text_area(
|
| 667 |
+
"Enter your text here:",
|
| 668 |
+
height=150,
|
| 669 |
+
placeholder="Type or paste your text here to analyze its sentiment...",
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
# Analyze button
|
| 673 |
+
if st.button("π Analyze Sentiment", type="primary", use_container_width=True):
|
| 674 |
+
if text_input and text_input.strip():
|
| 675 |
+
with st.spinner("Analyzing text sentiment..."):
|
| 676 |
+
sentiment, confidence = predict_text_sentiment(text_input)
|
| 677 |
+
|
| 678 |
+
# Display results
|
| 679 |
+
st.markdown("### Results")
|
| 680 |
+
|
| 681 |
+
# Display results in columns
|
| 682 |
+
col1, col2 = st.columns(2)
|
| 683 |
+
with col1:
|
| 684 |
+
st.metric("Sentiment", sentiment)
|
| 685 |
+
with col2:
|
| 686 |
+
st.metric("Confidence", f"{confidence:.2f}")
|
| 687 |
+
|
| 688 |
+
# Color-coded sentiment display
|
| 689 |
+
sentiment_colors = {
|
| 690 |
+
"Positive": "π’",
|
| 691 |
+
"Negative": "π΄",
|
| 692 |
+
"Neutral": "π‘",
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
st.markdown(
|
| 696 |
+
f"""
|
| 697 |
+
<div class="result-box">
|
| 698 |
+
<h4>{sentiment_colors.get(sentiment, "β")} Sentiment: {sentiment}</h4>
|
| 699 |
+
<p><strong>Confidence:</strong> {confidence:.2f}</p>
|
| 700 |
+
<p><strong>Input Text:</strong> "{text_input[:100]}{'...' if len(text_input) > 100 else ''}"</p>
|
| 701 |
+
<p><strong>Model:</strong> TextBlob (Natural Language Processing)</p>
|
| 702 |
+
</div>
|
| 703 |
+
""",
|
| 704 |
+
unsafe_allow_html=True,
|
| 705 |
+
)
|
| 706 |
+
else:
|
| 707 |
+
st.error("Please enter some text to analyze.")
|
| 708 |
+
|
| 709 |
+
# Audio Sentiment Page
|
| 710 |
+
elif page == "π΅ Audio Sentiment":
|
| 711 |
+
st.title("π΅ Audio Sentiment Analysis")
|
| 712 |
+
st.markdown(
|
| 713 |
+
"Analyze the sentiment of your audio files using our fine-tuned Wav2Vec2 model."
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# Preprocessing information
|
| 717 |
+
st.info(
|
| 718 |
+
"βΉοΈ **Audio Preprocessing**: Audio will be automatically processed to match CREMA-D + RAVDESS training format: "
|
| 719 |
+
"16kHz sampling rate, max 5 seconds, with automatic resampling and feature extraction."
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
# Model status
|
| 723 |
+
model, device, num_classes, feature_extractor = load_audio_model()
|
| 724 |
+
if model is None:
|
| 725 |
+
st.error("β Audio model could not be loaded. Please check the model file.")
|
| 726 |
+
st.info("Expected model file: `models/wav2vec2_model.pth`")
|
| 727 |
+
else:
|
| 728 |
+
st.success(
|
| 729 |
+
f"β
Audio model loaded successfully on {device} with {num_classes} classes!"
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
# Input method selection
|
| 733 |
+
st.subheader("π€ Choose Input Method")
|
| 734 |
+
input_method = st.radio(
|
| 735 |
+
"Select how you want to provide audio:",
|
| 736 |
+
["π Upload Audio File", "ποΈ Record Audio"],
|
| 737 |
+
horizontal=True,
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
if input_method == "π Upload Audio File":
|
| 741 |
+
# File uploader
|
| 742 |
+
uploaded_audio = st.file_uploader(
|
| 743 |
+
"Choose an audio file",
|
| 744 |
+
type=["wav", "mp3", "m4a", "flac"],
|
| 745 |
+
help="Supported formats: WAV, MP3, M4A, FLAC",
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
audio_source = "uploaded_file"
|
| 749 |
+
audio_name = uploaded_audio.name if uploaded_audio else None
|
| 750 |
+
|
| 751 |
+
else: # Audio recording
|
| 752 |
+
st.markdown(
|
| 753 |
+
"""
|
| 754 |
+
<div class="model-card">
|
| 755 |
+
<h3>ποΈ Audio Recording</h3>
|
| 756 |
+
<p>Record audio directly with your microphone (max 5 seconds).</p>
|
| 757 |
+
<p><strong>Note:</strong> Make sure your microphone is accessible and you have permission to use it.</p>
|
| 758 |
+
</div>
|
| 759 |
+
""",
|
| 760 |
+
unsafe_allow_html=True,
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
# Audio recorder
|
| 764 |
+
recorded_audio = st.audio_input(
|
| 765 |
+
label="Click to start recording",
|
| 766 |
+
help="Click the microphone button to start/stop recording. Maximum recording time is 5 seconds.",
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
if recorded_audio is not None:
|
| 770 |
+
# Display recorded audio
|
| 771 |
+
st.audio(recorded_audio, format="audio/wav")
|
| 772 |
+
st.success("β
Audio recorded successfully!")
|
| 773 |
+
|
| 774 |
+
# Convert recorded audio to bytes for processing
|
| 775 |
+
uploaded_audio = recorded_audio
|
| 776 |
+
audio_source = "recorded"
|
| 777 |
+
audio_name = "Recorded Audio"
|
| 778 |
+
else:
|
| 779 |
+
uploaded_audio = None
|
| 780 |
+
audio_source = None
|
| 781 |
+
audio_name = None
|
| 782 |
+
|
| 783 |
+
if uploaded_audio is not None:
|
| 784 |
+
# Display audio player
|
| 785 |
+
if audio_source == "recorded":
|
| 786 |
+
st.audio(uploaded_audio, format="audio/wav")
|
| 787 |
+
st.info(f"ποΈ {audio_name} | Source: Microphone Recording")
|
| 788 |
+
else:
|
| 789 |
+
st.audio(
|
| 790 |
+
uploaded_audio, format=f'audio/{uploaded_audio.name.split(".")[-1]}'
|
| 791 |
+
)
|
| 792 |
+
# File info for uploaded files
|
| 793 |
+
        file_size = len(uploaded_audio.getvalue()) / 1024  # KB
        st.info(f"π File: {uploaded_audio.name} | Size: {file_size:.1f} KB")

        # Analyze button
        if st.button(
            "π Analyze Audio Sentiment", type="primary", use_container_width=True
        ):
            if model is None:
                st.error("β Model not loaded. Cannot analyze audio.")
            else:
                with st.spinner("Analyzing audio sentiment..."):
                    audio_bytes = uploaded_audio.getvalue()
                    sentiment, confidence = predict_audio_sentiment(audio_bytes)

                    # Display results
                    st.markdown("### Results")

                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Sentiment", sentiment)
                    with col2:
                        st.metric("Confidence", f"{confidence:.2f}")

                    # Color-coded sentiment display
                    sentiment_colors = {"Positive": "π’", "Negative": "π΄", "Neutral": "π‘"}

                    st.markdown(
                        f"""
                        <div class="result-box">
                            <h4>{sentiment_colors.get(sentiment, "β")} Sentiment: {sentiment}</h4>
                            <p><strong>Confidence:</strong> {confidence:.2f}</p>
                            <p><strong>Audio Source:</strong> {audio_name}</p>
                            <p><strong>Model:</strong> Wav2Vec2 (Fine-tuned on RAVDESS + CREMA-D)</p>
                        </div>
                        """,
                        unsafe_allow_html=True,
                    )
    else:
        if input_method == "π Upload Audio File":
            st.info("π Please upload an audio file to begin analysis.")
        else:
            st.info("ποΈ Click the microphone button above to record audio for analysis.")

# Vision Sentiment Page
elif page == "πΌοΈ Vision Sentiment":
    st.title("πΌοΈ Vision Sentiment Analysis")
    st.markdown(
        "Analyze the sentiment of your images using our fine-tuned ResNet-50 model."
    )

    st.info(
        "βΉοΈ **Note**: Images will be automatically preprocessed to match FER2013 format: face detection, grayscale conversion, and 224x224 resize (converted to 3-channel RGB)."
    )

    # Face cropping is set to 0% (no padding) for tightest crop
    st.info(
        "π― **Face Cropping**: Set to 0% padding for tightest crop on facial features"
    )

    # Model status
    model, device, num_classes = load_vision_model()
    if model is None:
        st.error("β Vision model could not be loaded. Please check the model file.")
        st.info("Expected model file: `models/resnet50_model.pth`")
    else:
        st.success(
            f"β Vision model loaded successfully on {device} with {num_classes} classes!"
        )

    # Input method selection
    st.subheader("πΈ Choose Input Method")
    input_method = st.radio(
        "Select how you want to provide an image:",
        ["π Upload Image File", "π· Take Photo with Camera"],
        horizontal=True,
    )

    if input_method == "π Upload Image File":
        # File uploader
        uploaded_image = st.file_uploader(
            "Choose an image file",
            type=["png", "jpg", "jpeg", "bmp", "tiff"],
            help="Supported formats: PNG, JPG, JPEG, BMP, TIFF",
        )

        if uploaded_image is not None:
            # Display image
            image = Image.open(uploaded_image)
            st.image(
                image,
                caption=f"Uploaded Image: {uploaded_image.name}",
                use_container_width=True,
            )

            # File info
            file_size = len(uploaded_image.getvalue()) / 1024  # KB
            st.info(
                f"π File: {uploaded_image.name} | Size: {file_size:.1f} KB | Dimensions: {image.size[0]}x{image.size[1]}"
            )

            # Analyze button
            if st.button(
                "π Analyze Image Sentiment", type="primary", use_container_width=True
            ):
                if model is None:
                    st.error("β Model not loaded. Cannot analyze image.")
                else:
                    with st.spinner("Analyzing image sentiment..."):
                        sentiment, confidence = predict_vision_sentiment(image)

                        # Display results
                        st.markdown("### Results")

                        col1, col2 = st.columns(2)
                        with col1:
                            st.metric("Sentiment", sentiment)
                        with col2:
                            st.metric("Confidence", f"{confidence:.2f}")

                        # Color-coded sentiment display
                        sentiment_colors = {
                            "Positive": "π’",
                            "Negative": "π΄",
                            "Neutral": "π‘",
                        }

                        st.markdown(
                            f"""
                            <div class="result-box">
                                <h4>{sentiment_colors.get(sentiment, "β")} Sentiment: {sentiment}</h4>
                                <p><strong>Confidence:</strong> {confidence:.2f}</p>
                                <p><strong>Image File:</strong> {uploaded_image.name}</p>
                                <p><strong>Model:</strong> ResNet-50 (Fine-tuned on FER2013)</p>
                            </div>
                            """,
                            unsafe_allow_html=True,
                        )

    else:  # Camera capture
        st.markdown(
            """
            <div class="model-card">
                <h3>π· Camera Capture</h3>
                <p>Take a photo directly with your camera to analyze its sentiment.</p>
                <p><strong>Note:</strong> Make sure your camera is accessible and you have permission to use it.</p>
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Camera input
        camera_photo = st.camera_input(
            "Take a photo",
            help="Click the camera button to take a photo, or use the upload button to select an existing photo",
        )

        if camera_photo is not None:
            # Display captured image
            image = Image.open(camera_photo)
            st.image(
                image,
                caption="Captured Photo",
                use_container_width=True,
            )

            # Image info
            st.info(
                f"π· Captured Photo | Dimensions: {image.size[0]}x{image.size[1]} | Format: {image.format}"
            )

            # Analyze button
            if st.button(
                "π Analyze Photo Sentiment", type="primary", use_container_width=True
            ):
                if model is None:
                    st.error("β Model not loaded. Cannot analyze image.")
                else:
                    with st.spinner("Analyzing photo sentiment..."):
                        sentiment, confidence = predict_vision_sentiment(image)

                        # Display results
                        st.markdown("### Results")

                        col1, col2 = st.columns(2)
                        with col1:
                            st.metric("Sentiment", sentiment)
                        with col2:
                            st.metric("Confidence", f"{confidence:.2f}")

                        # Color-coded sentiment display
                        sentiment_colors = {
                            "Positive": "π’",
                            "Negative": "π΄",
                            "Neutral": "π‘",
                        }

                        st.markdown(
                            f"""
                            <div class="result-box">
                                <h4>{sentiment_colors.get(sentiment, "β")} Sentiment: {sentiment}</h4>
                                <p><strong>Confidence:</strong> {confidence:.2f}</p>
                                <p><strong>Image Source:</strong> Camera Capture</p>
                                <p><strong>Model:</strong> ResNet-50 (Fine-tuned on FER2013)</p>
                            </div>
                            """,
                            unsafe_allow_html=True,
                        )

    # Show info if no image is provided
    if input_method == "π Upload Image File" and "uploaded_image" not in locals():
        st.info("π Please upload an image file to begin analysis.")
    elif input_method == "π· Take Photo with Camera" and "camera_photo" not in locals():
        st.info("π· Click the camera button above to take a photo for analysis.")

# Fused Model Page
elif page == "π Fused Model":
    st.title("π Fused Model Analysis")
    st.markdown(
        "Combine predictions from all three models for enhanced sentiment analysis."
    )

    st.markdown(
        """
        <div class="model-card">
            <h3>Multi-Modal Sentiment Analysis</h3>
            <p>This page allows you to input text, audio, and/or image data to get a comprehensive sentiment analysis
            using all three models combined.</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Input sections
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("π Text Input")
        text_input = st.text_area(
            "Enter text (optional):",
            height=100,
            placeholder="Type or paste your text here...",
        )

        st.subheader("π΅ Audio Input")

        # Audio preprocessing information for fused model
        st.info(
            "βΉοΈ **Audio Preprocessing**: Audio will be automatically processed to match CREMA-D + RAVDESS training format: "
            "16kHz sampling rate, max 5 seconds, with automatic resampling and feature extraction."
        )

        # Audio input method for fused model
        audio_input_method = st.radio(
            "Audio input method:",
            ["π Upload File", "ποΈ Record Audio"],
            key="fused_audio_method",
            horizontal=True,
        )

        if audio_input_method == "π Upload File":
            uploaded_audio = st.file_uploader(
                "Upload audio file (optional):",
                type=["wav", "mp3", "m4a", "flac"],
                key="fused_audio",
            )
            audio_source = "uploaded_file"
            audio_name = uploaded_audio.name if uploaded_audio else None
        else:
            # Audio recorder for fused model
            recorded_audio = st.audio_input(
                label="Record audio (optional):",
                key="fused_audio_recorder",
                help="Click to record audio for sentiment analysis",
            )

            if recorded_audio is not None:
                st.audio(recorded_audio, format="audio/wav")
                st.success("β Audio recorded successfully!")
                uploaded_audio = recorded_audio
                audio_source = "recorded"
                audio_name = "Recorded Audio"
            else:
                uploaded_audio = None
                audio_source = None
                audio_name = None

    with col2:
        st.subheader("πΌοΈ Image Input")

        # Face cropping is set to 0% (no padding) for tightest crop
        st.info(
            "π― **Face Cropping**: Set to 0% padding for tightest crop on facial features"
        )

        # Image input method for fused model
        image_input_method = st.radio(
            "Image input method:",
            ["π Upload File", "π· Take Photo"],
            key="fused_image_method",
            horizontal=True,
        )

        if image_input_method == "π Upload File":
            uploaded_image = st.file_uploader(
                "Upload image file (optional):",
                type=["png", "jpg", "jpeg", "bmp", "tiff"],
                key="fused_image",
            )

            if uploaded_image:
                image = Image.open(uploaded_image)
                st.image(image, caption="Uploaded Image", use_container_width=True)
        else:
            # Camera capture for fused model
            camera_photo = st.camera_input(
                "Take a photo (optional):",
                key="fused_camera",
                help="Click to take a photo for sentiment analysis",
            )

            if camera_photo:
                image = Image.open(camera_photo)
                st.image(image, caption="Captured Photo", use_container_width=True)
                # Set uploaded_image to camera_photo for processing
                uploaded_image = camera_photo

    if uploaded_audio:
        st.audio(
            uploaded_audio, format=f'audio/{uploaded_audio.name.split(".")[-1]}'
        )

    # Analyze button
    if st.button("π Run Fused Analysis", type="primary", use_container_width=True):
        if text_input or uploaded_audio or uploaded_image:
            with st.spinner("Running fused sentiment analysis..."):
                # Prepare inputs
                audio_bytes = uploaded_audio.getvalue() if uploaded_audio else None
                image = Image.open(uploaded_image) if uploaded_image else None

                # Get fused prediction
                sentiment, confidence = predict_fused_sentiment(
                    text=text_input if text_input else None,
                    audio_bytes=audio_bytes,
                    image=image,
                )

                # Display results
                st.markdown("### Fused Model Results")

                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Final Sentiment", sentiment)
                with col2:
                    st.metric("Overall Confidence", f"{confidence:.2f}")

                # Show individual model results
                st.markdown("### Individual Model Results")

                results_data = []

                if text_input:
                    text_sentiment, text_conf = predict_text_sentiment(text_input)
                    results_data.append(
                        {
                            "Model": "Text (TextBlob) β",
                            "Input": f"Text: {text_input[:50]}...",
                            "Sentiment": text_sentiment,
                            "Confidence": f"{text_conf:.2f}",
                        }
                    )

                if uploaded_audio:
                    audio_sentiment, audio_conf = predict_audio_sentiment(audio_bytes)
                    results_data.append(
                        {
                            "Model": "Audio (Wav2Vec2) β",
                            "Input": f"Audio: {audio_name}",
                            "Sentiment": audio_sentiment,
                            "Confidence": f"{audio_conf:.2f}",
                        }
                    )

                if uploaded_image:
                    # Face cropping is set to 0% (no padding) for tightest crop
                    vision_sentiment, vision_conf = predict_vision_sentiment(
                        image, crop_tightness=0.0
                    )
                    results_data.append(
                        {
                            "Model": "Vision (ResNet-50)",
                            "Input": f"Image: {uploaded_image.name}",
                            "Sentiment": vision_sentiment,
                            "Confidence": f"{vision_conf:.2f}",
                        }
                    )

                if results_data:
                    df = pd.DataFrame(results_data)
                    st.dataframe(df, use_container_width=True)

                    # Final result display
                    sentiment_colors = {"Positive": "π’", "Negative": "π΄", "Neutral": "π‘"}

                    st.markdown(
                        f"""
                        <div class="result-box">
                            <h4>{sentiment_colors.get(sentiment, "β")} Final Fused Sentiment: {sentiment}</h4>
                            <p><strong>Overall Confidence:</strong> {confidence:.2f}</p>
                            <p><strong>Models Used:</strong> {len(results_data)}</p>
                        </div>
                        """,
                        unsafe_allow_html=True,
                    )
        else:
            st.warning(
                "β οΈ Please provide at least one input (text, audio, or image) for fused analysis."
            )

# Footer
st.markdown("---")
st.markdown(
    """
    <div style="text-align: center; color: #666; padding: 1rem;">
        <p>Built with β€οΈ | by <a href="https://github.com/iamfaham">iamfaham</a></p>
    </div>
    """,
    unsafe_allow_html=True,
)
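Note: predict_fused_sentiment is defined earlier in app.py and its fusion rule is not visible in this part of the diff. As a rough orientation only, here is a minimal, hypothetical sketch of one common approach, confidence-weighted voting over whichever modalities are present; the function name, weighting rule, and example numbers below are assumptions for illustration, not the app's actual logic.

from collections import defaultdict
from typing import Optional, Tuple


def fuse_predictions(
    text_pred: Optional[Tuple[str, float]] = None,
    audio_pred: Optional[Tuple[str, float]] = None,
    vision_pred: Optional[Tuple[str, float]] = None,
) -> Tuple[str, float]:
    """Sum per-model confidences for each sentiment label and pick the highest total."""
    preds = [p for p in (text_pred, audio_pred, vision_pred) if p is not None]
    if not preds:
        return "Neutral", 0.0
    scores = defaultdict(float)
    for sentiment, confidence in preds:
        scores[sentiment] += confidence
    best = max(scores, key=scores.get)
    # Normalize so the fused confidence stays in [0, 1].
    return best, scores[best] / sum(scores.values())


# Example: two modalities say Positive, one says Neutral -> ("Positive", ~0.73)
print(fuse_predictions(("Positive", 0.9), ("Neutral", 0.6), ("Positive", 0.7)))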
models/audio_sentiment_analysis.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

models/vision_sentiment_analysis.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
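The two training notebooks cannot be rendered here. Judging from test_vision_model.py later in this commit, the vision notebook presumably fine-tunes a torchvision ResNet-50 whose final fully connected layer is replaced with a small sentiment head and saves the state dict to models/resnet50_model.pth. A rough sketch of that setup follows; the class count, optimizer, and learning rate are illustrative assumptions, not values taken from the notebook.

import torch
import torch.nn as nn
from torchvision import models

# Assumed head swap: 3 sentiment classes (e.g. Negative / Neutral / Positive).
num_classes = 3
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ... training loop over FER2013-style batches would go here ...

# Saving only the state dict matches how the test scripts and app load the weights.
torch.save(model.state_dict(), "models/resnet50_model.pth")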
pyproject.toml
ADDED
@@ -0,0 +1,7 @@
[project]
name = "sentiment-fused"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.9"
dependencies = []
requirements.txt
ADDED
@@ -0,0 +1,12 @@
streamlit>=1.28.0
pandas>=1.5.0
Pillow>=9.0.0
numpy>=1.21.0
textblob>=0.17.0
torch>=1.13.0
torchvision>=0.14.0
transformers>=4.21.0
librosa>=0.9.0
soundfile>=0.12.0
opencv-python>=4.5.0
accelerate>=0.20.0
run_app.py
ADDED
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Startup script for the Sentiment Analysis Testing Ground Streamlit application.
This script provides an easy way to launch the application with proper configuration.
"""

import subprocess
import sys
import os


def main():
    """Main function to start the Streamlit application."""

    print("π§ Starting Sentiment Analysis Testing Ground...")
    print("=" * 50)

    # Check if app.py exists
    if not os.path.exists("app.py"):
        print("β Error: app.py not found in current directory!")
        print("Please make sure you're in the correct directory.")
        sys.exit(1)

    # Check if requirements are installed
    try:
        import streamlit
        import pandas
        import PIL

        print("β Dependencies check passed")
    except ImportError as e:
        print(f"β Missing dependency: {e}")
        print("Please install requirements: pip install -r requirements.txt")
        sys.exit(1)

    print("π Launching Streamlit application...")
    print("π± The app will open in your default browser")
    print("π If it doesn't open automatically, go to: http://localhost:8501")
    print("βΉοΈ Press Ctrl+C to stop the application")
    print("=" * 50)

    try:
        # Start Streamlit with the app
        subprocess.run(
            [
                sys.executable,
                "-m",
                "streamlit",
                "run",
                "app.py",
                "--server.headless",
                "false",
                "--server.port",
                "8501",
            ]
        )
    except KeyboardInterrupt:
        print("\nπ Application stopped by user")
    except Exception as e:
        print(f"β Error starting application: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
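Usage note: assuming the packages from requirements.txt are installed, the app is started from the repository root either with `python run_app.py` (the script checks that app.py is in the current directory and launches Streamlit on port 8501) or directly with `streamlit run app.py`.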
test_audio_model.py
ADDED
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
"""
Test script for the Wav2Vec2 audio sentiment analysis model
"""

import os
import torch
import numpy as np
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import tempfile


def test_audio_model():
    """Test the audio model loading and inference"""

    print("π Testing Wav2Vec2 Audio Sentiment Model")
    print("=" * 50)

    # Check if model file exists
    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"β Audio model file not found at: {model_path}")
        return False

    print(f"β Found model file: {model_path}")

    try:
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"π₯οΈ Using device: {device}")

        # Load the model checkpoint to check architecture
        checkpoint = torch.load(model_path, map_location=device)
        print(f"π Checkpoint keys: {list(checkpoint.keys())}")

        # Check for classifier weights
        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"π Model has {num_classes} output classes")
        else:
            print("β οΈ Could not determine number of classes from checkpoint")
            num_classes = 3  # Default assumption

        # Initialize model
        print("π Initializing Wav2Vec2 model...")
        model_checkpoint = "facebook/wav2vec2-base"
        model = AutoModelForAudioClassification.from_pretrained(
            model_checkpoint, num_labels=num_classes
        )

        # Load trained weights
        print("π Loading trained weights...")
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()

        print("β Model loaded successfully!")

        # Test with dummy audio
        print("π§ͺ Testing inference with dummy audio...")

        # Create dummy audio (1 second of random noise at 16kHz)
        dummy_audio = np.random.randn(16000).astype(np.float32)

        # Load feature extractor
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

        # Preprocess audio
        inputs = feature_extractor(
            dummy_audio,
            sampling_rate=16000,
            max_length=80000,  # 5 seconds * 16000 Hz
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        # Move to device
        input_values = inputs.input_values.to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(input_values)
            probabilities = torch.softmax(outputs.logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        print(f"π Model output shape: {outputs.logits.shape}")
        print(f"π― Predicted class: {predicted.item()}")
        print(f"π Confidence: {confidence.item():.3f}")
        print(f"π All probabilities: {probabilities.squeeze().cpu().numpy()}")

        # Sentiment mapping
        sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
        predicted_sentiment = sentiment_map.get(
            predicted.item(), f"Class_{predicted.item()}"
        )
        print(f"π Predicted sentiment: {predicted_sentiment}")

        print("β Audio model test completed successfully!")
        return True

    except Exception as e:
        print(f"β Error testing audio model: {str(e)}")
        import traceback

        traceback.print_exc()
        return False


def check_audio_model_file():
    """Check the audio model file details"""

    print("\nπ Audio Model File Analysis")
    print("=" * 30)

    model_path = "models/wav2vec2_model.pth"
    if not os.path.exists(model_path):
        print(f"β Model file not found: {model_path}")
        return

    # File size
    file_size = os.path.getsize(model_path) / (1024 * 1024)  # MB
    print(f"π File size: {file_size:.1f} MB")

    try:
        # Load checkpoint
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=device)

        print(f"π Checkpoint keys ({len(checkpoint)} total):")
        for key, value in checkpoint.items():
            if isinstance(value, torch.Tensor):
                print(f"  - {key}: {value.shape} ({value.dtype})")
            else:
                print(f"  - {key}: {type(value)}")

        # Check classifier
        if "classifier.weight" in checkpoint:
            num_classes = checkpoint["classifier.weight"].shape[0]
            print(f"\nπ― Classifier output classes: {num_classes}")
            print(
                f"π Classifier weight shape: {checkpoint['classifier.weight'].shape}"
            )
            if "classifier.bias" in checkpoint:
                print(
                    f"π Classifier bias shape: {checkpoint['classifier.bias'].shape}"
                )

        # Check wav2vec2 base model
        if "wav2vec2.feature_extractor.conv_layers.0.conv.weight" in checkpoint:
            print(f"π Wav2Vec2 base model: Present")

    except Exception as e:
        print(f"β Error analyzing checkpoint: {str(e)}")


if __name__ == "__main__":
    print("π Starting Wav2Vec2 Audio Model Tests")
    print("=" * 60)

    # Check model file
    check_audio_model_file()

    print("\n" + "=" * 60)

    # Test model loading and inference
    success = test_audio_model()

    if success:
        print("\nπ All audio model tests passed!")
    else:
        print("\nπ₯ Audio model tests failed!")
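The test above feeds random noise straight into the feature extractor. For real recordings, app.py's info text says audio is resampled to 16 kHz and capped at 5 seconds before feature extraction. A hypothetical helper along those lines, built on the librosa and soundfile dependencies from requirements.txt, is sketched below; the helper name and exact steps are assumptions, not the code in app.py's predict_audio_sentiment.

import io

import librosa
import numpy as np


def load_audio_for_model(
    audio_bytes: bytes, target_sr: int = 16000, max_seconds: int = 5
) -> np.ndarray:
    """Decode raw audio bytes, resample to 16 kHz mono, and cap at 5 seconds."""
    waveform, _ = librosa.load(io.BytesIO(audio_bytes), sr=target_sr, mono=True)
    return waveform[: target_sr * max_seconds].astype(np.float32)

# The resulting waveform can then go through the same feature_extractor call as
# the dummy audio above (sampling_rate=16000, max_length=80000).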
test_vision_model.py
ADDED
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
Test script for the vision sentiment analysis model.
This script verifies that the ResNet-50 model can be loaded and run inference.
"""

import os
import sys
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np


def get_sentiment_mapping(num_classes):
    """Get the sentiment mapping based on number of classes"""
    if num_classes == 3:
        return {0: "Negative", 1: "Neutral", 2: "Positive"}
    elif num_classes == 4:
        # Common 4-class emotion mapping
        return {0: "Angry", 1: "Sad", 2: "Happy", 3: "Neutral"}
    elif num_classes == 7:
        # FER2013 7-class emotion mapping
        return {0: "Angry", 1: "Disgust", 2: "Fear", 3: "Happy", 4: "Sad", 5: "Surprise", 6: "Neutral"}
    else:
        # Generic mapping for unknown number of classes
        return {i: f"Class_{i}" for i in range(num_classes)}


def test_vision_model():
    """Test the vision sentiment analysis model"""

    print("π§ Testing Vision Sentiment Analysis Model")
    print("=" * 50)

    # Check if model file exists
    model_path = "models/resnet50_model.pth"
    if not os.path.exists(model_path):
        print(f"β Model file not found: {model_path}")
        print("Please ensure the model file exists in the models/ directory")
        return False

    print(f"β Model file found: {model_path}")

    try:
        # Load the model weights first to check the architecture
        print("π₯ Loading model checkpoint...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=device)

        # Check the number of classes from the checkpoint
        if 'fc.weight' in checkpoint:
            num_classes = checkpoint['fc.weight'].shape[0]
            print(f"π Model checkpoint has {num_classes} output classes")
        else:
            # Fallback: try to infer from the last layer
            num_classes = 3  # Default assumption
            print("β οΈ Could not determine number of classes from checkpoint, assuming 3")

        # Initialize ResNet-50 model with the correct number of classes
        print("π§ Initializing ResNet-50 model...")
        model = models.resnet50(weights=None)  # Don't load ImageNet weights
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)  # Use actual number of classes

        print(f"π₯ Loading trained weights for {num_classes} classes...")
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()

        print(f"β Model loaded successfully with {num_classes} classes!")
        print(f"π₯οΈ Using device: {device}")

        # Test with a dummy image
        print("π§ͺ Testing inference with dummy image...")

        # Create a dummy image (224x224 RGB)
        dummy_image = Image.fromarray(
            np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        )

        # Apply transforms
        transform = transforms.Compose(
            [
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        image_tensor = transform(dummy_image).unsqueeze(0).to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(image_tensor)
            print(f"π Model output shape: {outputs.shape}")

            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # Get sentiment mapping based on number of classes
        sentiment_map = get_sentiment_mapping(num_classes)
        sentiment = sentiment_map[predicted.item()]
        confidence_score = confidence.item()

        print(f"π― Test prediction: {sentiment} (confidence: {confidence_score:.3f})")
        print(f"π Available classes: {list(sentiment_map.values())}")
        print("β Inference test passed!")

        return True

    except Exception as e:
        print(f"β Error testing model: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Main function"""
    success = test_vision_model()

    if success:
        print("\nπ All tests passed! The vision model is ready to use.")
        print("You can now run the Streamlit app with: streamlit run app.py")
    else:
        print("\nπ₯ Tests failed. Please check the error messages above.")
        sys.exit(1)


if __name__ == "__main__":
    main()
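This test skips the FER2013-style preprocessing that app.py's info messages describe (face detection, grayscale conversion, 224x224 resize replicated to 3 channels, 0% crop padding) by using a random dummy image. A hypothetical sketch of that preprocessing with OpenCV's bundled Haar cascade is shown below; the helper and its defaults are assumptions, and predict_vision_sentiment's real implementation may differ.

import cv2
import numpy as np
from PIL import Image

# OpenCV ships this frontal-face Haar cascade; everything else here is an assumption.
_face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)


def preprocess_face(image: Image.Image, size: int = 224) -> Image.Image:
    """Crop the largest detected face (no padding), grayscale, resize, replicate to 3 channels."""
    gray = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2GRAY)
    faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
    if len(faces) > 0:
        x, y, w, h = max(faces, key=lambda f: f[2] * f[3])  # largest face, 0% padding
        gray = gray[y : y + h, x : x + w]
    resized = cv2.resize(gray, (size, size))
    return Image.fromarray(np.stack([resized] * 3, axis=-1))  # 3-channel image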
uv.lock
ADDED
@@ -0,0 +1,8 @@
version = 1
revision = 2
requires-python = ">=3.9"

[[package]]
name = "sentiment-fused"
version = "0.1.0"
source = { virtual = "." }
|