Skip to content
Permalink
803fccfd50
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
Latest commit 9a86ad2 Dec 2, 2024 History
1 contributor

Users who have contributed to this file

195 lines (161 sloc) 7.31 KB
import streamlit as st
import os
from PIL import Image
#Importamos pytorch
import torch
#Transformadores de imagenes para preprocesar la entrada
from torchvision import transforms
#Cargadores de Datos que introducen los datos al modelo
from torch.utils.data import DataLoader
#OpenCV para recortar la imagen al final
import cv2 as cv
#Numpy para trabajar con arrays y matplotlib para ver las imagenes
import numpy as np
import matplotlib.pyplot as plt
#Para cargar las imagenes
from PIL import Image
import appUtils
from libs.vit_seg_modeling import VisionTransformer as ViT_seg
from libs.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
####### CODIGO #######
#### CONSTANTES ####
MODELS = {
"Reducido - 1 epoch" : "./models/1 epochs - 3072 - 4 - 8.pth",
"Reducido - 5 epoch" : "./models/5 epochs - 3072 - 4 - 8.pth",
"Completo - 1 epoch" : "./models/1 epochs - 3072 - 12 - 12.pth",
"Completo - 6 epoch" : "./models/6 epochs - 3072 - 12 - 12.pth"}
#### VARIABLES ####
MODEL = None
BACKGROUND = None
THRESHOLD = None
#### Funciones ####
def cargar_modelo(modelo):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if modelo == list(MODELS.keys())[0]:
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 1
config_vit.n_skip = 3
## Dropout en la seccion del transformer
config_vit.transformer.dropout_rate = 0.2
#FFN tras attention heads
config_vit.transformer.mlp_dim = 3072
## Numero de attention heads
config_vit.transformer.num_heads = 4
## Numero de capas de attentions heads
config_vit.transformer.num_layers = 8
MODEL = appUtils.VisionTransformerModel(config_vit).to(device)
MODEL.load_state_dict(torch.load(MODELS[modelo],map_location=device))
elif modelo == list(MODELS.keys())[1]:
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 1
config_vit.n_skip = 3
## Dropout en la seccion del transformer
config_vit.transformer.dropout_rate = 0.2
#FFN tras attention heads
config_vit.transformer.mlp_dim = 3072
## Numero de attention heads
config_vit.transformer.num_heads = 4
## Numero de capas de attentions heads
config_vit.transformer.num_layers = 8
MODEL = appUtils.VisionTransformerModel(config_vit).to(device)
MODEL.load_state_dict(torch.load(MODELS[modelo],map_location=device))
elif modelo == list(MODELS.keys())[2]:
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 1
config_vit.n_skip = 3
## Dropout en la seccion del transformer
config_vit.transformer.dropout_rate = 0.2
#FFN tras attention heads
config_vit.transformer.mlp_dim = 3072
## Numero de attention heads
config_vit.transformer.num_heads = 12
## Numero de capas de attentions heads
config_vit.transformer.num_layers = 12
MODEL = appUtils.VisionTransformerModel(config_vit).to(device)
MODEL.load_state_dict(torch.load(MODELS[modelo],map_location=device))
elif modelo == list(MODELS.keys())[3]:
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 1
config_vit.n_skip = 3
## Dropout en la seccion del transformer
config_vit.transformer.dropout_rate = 0.2
#FFN tras attention heads
config_vit.transformer.mlp_dim = 3072
## Numero de attention heads
config_vit.transformer.num_heads = 12
## Numero de capas de attentions heads
config_vit.transformer.num_layers = 12
MODEL = appUtils.VisionTransformerModel(config_vit).to(device)
MODEL.load_state_dict(torch.load(MODELS[modelo],map_location=device))
return MODEL
def procesar_imagenes(dataloader,modelo, destaca=True, threshold=60, background=[255, 255, 255]):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
postprocessed = []
for img in dataloader:
pred_mask = modelo(img.to(device)).squeeze(1).cpu().detach()
img2 = appUtils.tensor_to_image_3(img[0])
pred_mask2 = appUtils.tensor_to_image_1(pred_mask[0])
img2 = np.array(img2)
pred_mask2 = np.array(pred_mask2)
if destaca:
postprocessed.append(appUtils.highlight_object(img2, pred_mask2))
else:
postprocessed.append(appUtils.cut_object(img2, pred_mask2, threshold, background))
return postprocessed
#### Aplicacion ####
# Interfaz de Streamlit
st.title("Segmentación de imágenes usando vision transformers")
modelo_seleccionado = st.selectbox("Selecciona un modelo", list(MODELS.keys()))
if modelo_seleccionado:
MODEL = cargar_modelo(modelo_seleccionado)
st.header("Carga tus imágenes")
imagenes_cargadas = st.file_uploader("Sube una o más imágenes", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
if imagenes_cargadas:
st.write("Imágenes cargadas:")
for imagen in imagenes_cargadas:
st.image(Image.open(imagen), caption=imagen.name)
st.header("Configura el procesamiento")
accion = st.radio("Selecciona una acción", options=["destacar", "recortar"])
# Opciones para recortar
if accion == "recortar":
st.markdown(
"""
#### Recortar
##### Umbral
Al realizar la predicción el modelo marca en escala de grises la segmentación, siendo el negro (0), una seguridad total de que ese pixel pertenece al objeto y siendo blanco (255) una seguridad nula
##### Fondo
El color de fondo es el color que se le da a la parte que no cumple los requisitos del umbral.
""")
THRESHOLD = st.slider("Selecciona el umbral.", min_value=0, max_value=255, value=60, step=1)
color_picker = st.color_picker("Selecciona el color de fondo.", value="#FFFFFF")
BACKGROUND = tuple(int(color_picker[i:i+2], 16) for i in (1, 3, 5))
else:
st.markdown(
"""
#### Destacar
Con esta opcion el modelo realizara una prediccion en escala de grises de la segmentacion. Esta segmentacion se superpondrá a la imagen original para ver donde pone el modelo la atención.")
""")
st.header("Selecciona la carpeta de destino para imágenes procesadas")
carpeta_destino = st.text_input("Carpeta de destino:", value="")
if st.button("Procesar imágenes"):
if imagenes_cargadas:
if accion == "recortar":
procesadas = procesar_imagenes(DataLoader(appUtils.TestDataset(imagenes_cargadas, appUtils.get_conf()["TRANSFORM"]),batch_size=1),
MODEL,
destaca=False,
threshold=THRESHOLD,
background=BACKGROUND)
else:
procesadas = procesar_imagenes(DataLoader(appUtils.TestDataset(imagenes_cargadas, appUtils.get_conf()["TRANSFORM"]),batch_size=1), MODEL)
st.write("Imágenes procesadas:")
for imagen in procesadas:
st.image(Image.fromarray(imagen))
if len(carpeta_destino) > 1:
for i, img_array in enumerate(procesadas):
img = Image.fromarray(img_array)
ruta_imagen = os.path.join(carpeta_destino, f"imagen_{i + 1}.png")
img.save(ruta_imagen)
else:
st.warning("Las imagenes no se estan guardando")
else:
st.warning("No hay imágenes cargadas para procesar.")