Spaces:
Build error
Build error
| import streamlit as st | |
| from flask.Emotion_spotting_service import _Emotion_spotting_service | |
| from flask.Genre_spotting_service import _Genre_spotting_service | |
| from flask.Beat_tracking_service import _Beat_tracking_service | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import os | |
| import logging | |
| import psutil | |
# Module-scoped logger used by the helpers in this script.
logger = logging.getLogger(__name__)
# Root logging configuration: emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
def print_memory_info():
    """Log the available CPU memory and, when CUDA is usable, GPU memory.

    The CPU figure is ``psutil.virtual_memory().available``. The GPU figure
    is device 0's total memory minus the memory currently reserved through
    ``torch.cuda.memory_reserved(0)``, so it reflects this process only.
    Both values are converted from bytes to GB (1024 ** 3 bytes) and logged
    at INFO level with two decimals.
    """
    bytes_per_gb = 1024 ** 3

    free_cpu_mem = psutil.virtual_memory().available / bytes_per_gb
    logger.info(f"Free CPU Memory: {free_cpu_mem:.2f} GB")

    # Querying device 0 raises on hosts without a usable CUDA device, which
    # would take the whole app down for what is only a diagnostic message.
    if not torch.cuda.is_available():
        logger.info("Free GPU Memory: unavailable (no CUDA device detected)")
        return

    total_gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    reserved_gpu_bytes = torch.cuda.memory_reserved(0)
    free_gpu_mem = (total_gpu_bytes - reserved_gpu_bytes) / bytes_per_gb
    logger.info(f"Free GPU Memory: {free_gpu_mem:.2f} GB")
# Module-level lists for emotion, genre and tempo values; all start empty.
# NOTE(review): none of these is referenced in the visible part of this file.
emo_list, gen_list, tempo_list = [], [], []
def load_emo_model():
    """Return an emotion-spotting service built from ``flask/emotion_model.h5``."""
    return _Emotion_spotting_service("flask/emotion_model.h5")
def load_genre_model():
    """Return a genre-spotting service built from ``flask/Genre_classifier_model.h5``."""
    return _Genre_spotting_service("flask/Genre_classifier_model.h5")
def load_beat_model():
    """Return a new beat-tracking service (constructed without arguments)."""
    return _Beat_tracking_service()
# @st.cache_resource
# NOTE(review): the caching decorator above was left disabled in the original.
# Without it the pipeline is rebuilt each time this function is called;
# confirm whether it was turned off on purpose before re-enabling it.
def load_image_model():
    """Build the Stable Diffusion v1.5 pipeline with the project's LoRA weights.

    Logs memory usage first, then loads the ``fp16`` variant of
    ``runwayml/stable-diffusion-v1-5``, moves it to the GPU when CUDA is
    available (CPU otherwise), and applies the LoRA weights stored in
    ``Weights/pytorch_lora_weights.safetensors``.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    print_memory_info()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Release cached allocator blocks before the large model allocation.
        torch.cuda.empty_cache()

    # `variant="fp16"` only selects which checkpoint files are downloaded;
    # without an explicit half-precision `torch_dtype` the weights are upcast
    # to float32 on load, doubling memory use. Half precision is kept for GPU
    # only, so the CPU fallback stays in float32.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        variant="fp16",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    # Fall back to CPU rather than crashing on hosts without a CUDA device.
    pipeline.to("cuda" if use_cuda else "cpu")
    pipeline.load_lora_weights(
        "Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
def _analyse_upload(uploaded_file, analyse):
    """Run ``analyse`` on the uploaded file; warn and return None if no file is uploaded."""
    if uploaded_file is None:
        st.warning("Please upload an audio file first.")
        return None
    # The same upload object is handed to several services in turn; rewind it
    # so each one reads from the start rather than from wherever the previous
    # reader left the stream.
    uploaded_file.seek(0)
    return analyse(uploaded_file)


def _describe_music(genre, beat, emotion):
    """Compose the text description of the track from its genre, tempo and emotion."""
    return (
        f"This piece of music falls under the {genre[0]} genre. "
        f"It is of tempo {int(beat)} and evokes a sense of {emotion}."
    )


# Make sure every analysis result has a slot in the session state.
for _state_key in ("emotion", "genre", "beat"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

emotion_service = load_emo_model()
genre_service = load_genre_model()
beat_service = load_beat_model()
image_service = load_image_model()

st.title("Music2Image webpage")
user_input = st.file_uploader(
    "Upload your wav/mp3 files here",
    type=["wav", "mp3"],
    key="file_uploader",
)
st.caption("Generate images from your audio file")
# Only render the audio player once a file has actually been uploaded.
if user_input is not None:
    st.audio(user_input)

c1, c2, c3 = st.columns([1, 1, 1])

with c1:
    if st.button("Generate emotion"):
        emotion = _analyse_upload(user_input, emotion_service.predict)
        if emotion is not None:
            st.session_state.emotion = emotion
    # Show the stored result only once there is one, instead of the text "None".
    if st.session_state.emotion is not None:
        st.text(st.session_state.emotion)

with c2:
    if st.button("Generate genre"):
        genre = _analyse_upload(user_input, genre_service.predict)
        if genre is not None:
            st.session_state.genre = genre
    if st.session_state.genre is not None:
        st.text(st.session_state.genre)

with c3:
    if st.button("Generate beat"):
        beat = _analyse_upload(user_input, beat_service.get_beat)
        if beat is not None:
            st.session_state.beat = beat
    if st.session_state.beat is not None:
        st.text(st.session_state.beat)

# Identity checks against None: `!= None` would call the result object's own
# comparison operator, which need not return a plain bool.
if (
    st.session_state.emotion is not None
    and st.session_state.genre is not None
    and st.session_state.beat is not None
):
    text_output = None
    if st.button("Generate text description to be fed into stable diffusion"):
        st.caption("Text description of your music file")
        text_output = _describe_music(
            st.session_state.genre,
            st.session_state.beat,
            st.session_state.emotion,
        )
        st.text(text_output)
    # NOTE(review): image generation is still disabled. If it is re-enabled,
    # text_output is reassigned to None above on every run of this section,
    # so it will likely need to be kept in st.session_state — confirm.
    # if text_output:
    #     if st.button("Generate image from text description"):
    #         image = image_service(text_output)
    #         st.image(image)