Permalink
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
App_segmentacion_imagenes_vision_transformer/app.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
195 lines (161 sloc)
7.31 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import os | |
from PIL import Image | |
#Importamos pytorch | |
import torch | |
#Transformadores de imagenes para preprocesar la entrada | |
from torchvision import transforms | |
#Cargadores de Datos que introducen los datos al modelo | |
from torch.utils.data import DataLoader | |
#OpenCV para recortar la imagen al final | |
import cv2 as cv | |
#Numpy para trabajar con arrays y matplotlib para ver las imagenes | |
import numpy as np | |
import matplotlib.pyplot as plt | |
#Para cargar las imagenes | |
from PIL import Image | |
import appUtils | |
from libs.vit_seg_modeling import VisionTransformer as ViT_seg | |
from libs.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg | |
# ============================== Code ==============================

# ---- Constants ----
# (selector label, checkpoint path) pairs. The checkpoint filenames appear to
# encode "<epochs> epochs - <mlp_dim> - <num_heads> - <num_layers>", which
# matches the architectures built for each label in cargar_modelo.
_MODEL_CHECKPOINTS = (
    ("Reducido - 1 epoch", "./models/1 epochs - 3072 - 4 - 8.pth"),
    ("Reducido - 5 epoch", "./models/5 epochs - 3072 - 4 - 8.pth"),
    ("Completo - 1 epoch", "./models/1 epochs - 3072 - 12 - 12.pth"),
    ("Completo - 6 epoch", "./models/6 epochs - 3072 - 12 - 12.pth"),
)
# Label shown in the model selector -> checkpoint path (insertion order kept).
MODELS = dict(_MODEL_CHECKPOINTS)

# ---- Module state (filled in by the Streamlit script below) ----
MODEL = BACKGROUND = THRESHOLD = None

# ---- Functions ----
def _config_vit(num_heads, num_layers):
    """Return the R50-ViT-B_16 config with this app's overrides applied.

    NOTE: the entry in ``CONFIGS_ViT_seg`` is modified in place (not copied),
    so every call rewrites the same shared config object.
    """
    config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
    config_vit.n_classes = 1
    config_vit.n_skip = 3
    # Dropout inside the transformer section
    config_vit.transformer.dropout_rate = 0.2
    # Width of the feed-forward network that follows the attention heads
    config_vit.transformer.mlp_dim = 3072
    # Number of attention heads
    config_vit.transformer.num_heads = num_heads
    # Number of attention (transformer) layers
    config_vit.transformer.num_layers = num_layers
    return config_vit


def cargar_modelo(modelo):
    """Build the ViT segmentation model for a selector label and load its weights.

    Args:
        modelo: one of the keys of ``MODELS``. The first two keys ("Reducido")
            use 4 attention heads / 8 layers; the last two ("Completo") use
            12 heads / 12 layers.

    Returns:
        An ``appUtils.VisionTransformerModel`` with the checkpoint at
        ``MODELS[modelo]`` loaded, placed on CUDA when available (else CPU)
        and switched to evaluation mode.

    Raises:
        ValueError: if ``modelo`` does not name a known checkpoint.
    """
    # (num_heads, num_layers) for each entry of MODELS, matched by position.
    arquitecturas = dict(zip(MODELS, ((4, 8), (4, 8), (12, 12), (12, 12))))
    if modelo not in arquitecturas:
        raise ValueError(
            f"Modelo desconocido: {modelo!r}. Opciones: {list(arquitecturas)}")
    num_heads, num_layers = arquitecturas[modelo]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelo_vit = appUtils.VisionTransformerModel(
        _config_vit(num_heads, num_layers)).to(device)
    modelo_vit.load_state_dict(torch.load(MODELS[modelo], map_location=device))
    # Inference only: eval() stops training-only layers such as dropout
    # (dropout_rate is set to 0.2 above) from acting on predictions.
    # Assumes VisionTransformerModel is a torch.nn.Module, as its use of
    # .to()/.load_state_dict() suggests.
    modelo_vit.eval()
    return modelo_vit
def procesar_imagenes(dataloader, modelo, destaca=True, threshold=60, background=None):
    """Segment every image yielded by ``dataloader`` and post-process the result.

    Args:
        dataloader: iterable of image batches (tensors). Every image in every
            batch is processed, not just the first one of each batch.
        modelo: segmentation model, called on each batch after moving it to
            CUDA when available (else CPU). It is switched to eval mode here.
        destaca: if True, overlay the predicted mask on the image with
            ``appUtils.highlight_object``; otherwise cut the object out with
            ``appUtils.cut_object``.
        threshold: forwarded to ``appUtils.cut_object`` when ``destaca`` is False.
        background: fill colour forwarded to ``appUtils.cut_object``; defaults
            to a fresh ``[255, 255, 255]`` on each call.

    Returns:
        A list with one post-processed numpy array per input image, in
        dataloader order.
    """
    # A fresh list per call avoids sharing a mutable default between calls.
    if background is None:
        background = [255, 255, 255]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Inference only: disable training-only layers (e.g. dropout).
    modelo.eval()
    postprocessed = []
    # no_grad: no autograd graph is needed for prediction, which saves memory.
    with torch.no_grad():
        for batch in dataloader:
            # assumes the model returns (B, 1, H, W); squeeze(1) drops the channel
            pred_masks = modelo(batch.to(device)).squeeze(1).cpu()
            for img, pred_mask in zip(batch, pred_masks):
                img_np = np.array(appUtils.tensor_to_image_3(img))
                mask_np = np.array(appUtils.tensor_to_image_1(pred_mask))
                if destaca:
                    postprocessed.append(appUtils.highlight_object(img_np, mask_np))
                else:
                    postprocessed.append(
                        appUtils.cut_object(img_np, mask_np, threshold, background))
    return postprocessed
# ============================ Application ============================
# Streamlit UI. Streamlit re-runs this script top to bottom on every widget
# interaction.
st.title("Segmentación de imágenes usando vision transformers")

modelo_seleccionado = st.selectbox("Selecciona un modelo", list(MODELS.keys()))
if modelo_seleccionado:
    # NOTE(review): the checkpoint is reloaded on every rerun; consider caching
    # it (e.g. st.cache_resource) if the installed Streamlit version has it.
    MODEL = cargar_modelo(modelo_seleccionado)

st.header("Carga tus imágenes")
imagenes_cargadas = st.file_uploader("Sube una o más imágenes", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
if imagenes_cargadas:
    st.write("Imágenes cargadas:")
    for imagen in imagenes_cargadas:
        st.image(Image.open(imagen), caption=imagen.name)

st.header("Configura el procesamiento")
accion = st.radio("Selecciona una acción", options=["destacar", "recortar"])

if accion == "recortar":
    # Cut-out options: mask threshold and background fill colour.
    st.markdown(
        """
#### Recortar
##### Umbral
Al realizar la predicción el modelo marca en escala de grises la segmentación, siendo el negro (0), una seguridad total de que ese pixel pertenece al objeto y siendo blanco (255) una seguridad nula
##### Fondo
El color de fondo es el color que se le da a la parte que no cumple los requisitos del umbral.
""")
    THRESHOLD = st.slider("Selecciona el umbral.", min_value=0, max_value=255, value=60, step=1)
    color_picker = st.color_picker("Selecciona el color de fondo.", value="#FFFFFF")
    # "#RRGGBB" -> (R, G, B) as integers 0-255.
    BACKGROUND = tuple(int(color_picker[i:i+2], 16) for i in (1, 3, 5))
else:
    st.markdown(
        """
#### Destacar
Con esta opcion el modelo realizara una prediccion en escala de grises de la segmentacion. Esta segmentacion se superpondrá a la imagen original para ver donde pone el modelo la atención.
""")

st.header("Selecciona la carpeta de destino para imágenes procesadas")
carpeta_destino = st.text_input("Carpeta de destino:", value="")

if st.button("Procesar imágenes"):
    if imagenes_cargadas:
        dataloader = DataLoader(
            appUtils.TestDataset(imagenes_cargadas, appUtils.get_conf()["TRANSFORM"]),
            batch_size=1)
        if accion == "recortar":
            procesadas = procesar_imagenes(dataloader,
                                           MODEL,
                                           destaca=False,
                                           threshold=THRESHOLD,
                                           background=BACKGROUND)
        else:
            procesadas = procesar_imagenes(dataloader, MODEL)

        st.write("Imágenes procesadas:")
        for imagen in procesadas:
            st.image(Image.fromarray(imagen))

        # Any non-blank path is valid, including one-character names like ".".
        carpeta_destino = carpeta_destino.strip()
        if carpeta_destino:
            try:
                # Create the destination if needed so save() does not fail.
                os.makedirs(carpeta_destino, exist_ok=True)
                for i, img_array in enumerate(procesadas, start=1):
                    ruta_imagen = os.path.join(carpeta_destino, f"imagen_{i}.png")
                    Image.fromarray(img_array).save(ruta_imagen)
            except OSError as err:
                st.error(f"No se pudieron guardar las imágenes: {err}")
        else:
            st.warning("Las imagenes no se estan guardando")
    else:
        st.warning("No hay imágenes cargadas para procesar.")