forked from sail-sg/EditAnything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sam2semantic.py
174 lines (150 loc) · 6.51 KB
/
sam2semantic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
# pip install mmcv
from torchvision.utils import save_image
from PIL import Image
import subprocess
from collections import OrderedDict
import numpy as np
import cv2
import textwrap
import torch
import os
from annotator.util import resize_image, HWC3
import mmcv
import random
# Runtime configuration.
# To use a GPU instead (needs > 15GB of GPU memory):
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
use_blip = True
use_gradio = True
# float32 on the CPU; float16 on any other device.
data_type = torch.float32 if device == 'cpu' else torch.float16
# Diffusion init using diffusers.
# diffusers==0.14.0 required.
from diffusers.utils import load_image
base_model_path = "stabilityai/stable-diffusion-2-inpainting"
# Ordered mapping: human-readable model description -> model id
# (presumably Hugging Face Hub repo ids; insertion order is preserved).
config_dict = OrderedDict()
config_dict['SAM Pretrained(v0-1): Good Natural Sense'] = 'shgao/edit-anything-v0-1-1'
config_dict['LAION Pretrained(v0-3): Good Face'] = 'shgao/edit-anything-v0-3'
config_dict['SD Inpainting: Not keep position'] = base_model_path
# Segment-Anything init.
# pip install git+https://github.com/facebookresearch/segment-anything.git
import sys

sam_checkpoint = "models/sam_vit_h_4b8939.pth"
sam_checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    print('segment_anything not installed')
    # Run pip through the current interpreter so the package is installed into
    # the environment this script runs in, not whichever `pip` is first on PATH.
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install',
         'git+https://github.com/facebookresearch/segment-anything.git'],
        check=True)
    print(f'Install segment_anything {result}')
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

if not os.path.exists(sam_checkpoint):
    os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
    # Download to a temporary name and rename only after wget succeeds, so an
    # interrupted download is never mistaken for a complete checkpoint on the
    # next run (the existence check above would otherwise skip re-downloading).
    partial_path = sam_checkpoint + '.part'
    result = subprocess.run(['wget', sam_checkpoint_url, '-O', partial_path], check=True)
    os.replace(partial_path, sam_checkpoint)
    print(f'Download sam_vit_h_4b8939.pth {result}')

# NOTE(review): "default" is assumed to select the ViT-H builder that matches
# the sam_vit_h checkpoint above — confirm against segment_anything's registry.
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
# BLIP2 init.
if use_blip:
    # need the latest transformers
    # pip install git+https://github.com/huggingface/transformers.git
    from transformers import AutoProcessor, Blip2ForConditionalGeneration
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=data_type)
    # region_classify_w_blip2 moves its inputs to `device`; the model must be on
    # the same device or generate() fails whenever device is not "cpu".
    blip_model.to(device)
def region_classify_w_blip2(image):
    """
    Caption a single image region with BLIP-2.

    Args:
        image: region crop accepted by the BLIP-2 ``processor``.
    Returns:
        str: first decoded caption (at most 15 new tokens), with special tokens
        removed and surrounding whitespace stripped.
    """
    model_inputs = processor(image, return_tensors="pt")
    model_inputs = model_inputs.to(device, data_type)
    output_ids = blip_model.generate(**model_inputs, max_new_tokens=15)
    captions = processor.batch_decode(output_ids, skip_special_tokens=True)
    return captions[0].strip()
def region_level_semantic_api(image, topk=5):
    """
    Segment ``image`` with SAM, keep the ``topk`` largest regions and caption each.

    Regions are ranked by mask area, largest first. Each selected region is cut
    out by zeroing the pixels outside its mask and cropping to its bounding box.
    The crop is then captioned with BLIP-2.

    Args:
        image: numpy array passed to ``mask_generator.generate``; assumed to be
            an HxWxC uint8 image (TODO confirm against SamAutomaticMaskGenerator).
        topk: int, maximum number of regions to label; values <= 0 yield [].
    Returns:
        topk_region_w_class_label: list of the selected SAM annotation dicts,
            largest area first. Each dict is updated in place with a
            'class_label' key.
    """
    anns = mask_generator.generate(image)
    if len(anns) == 0:
        return []
    # Rank by area and iterate the *sorted* list so the "top k" really are the
    # largest regions (indexing the unsorted list here would pick arbitrary ones).
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    topk_region_w_class_label = []
    for ann in sorted_anns[:max(topk, 0)]:
        m = ann['segmentation']
        bbox = ann['bbox']
        # Broadcast the 2-D mask over the channel axis to blank out background.
        masked = image * m[:, :, np.newaxis]
        # bbox is (x, y, w, h); mmcv.imcrop takes (x1, y1, x2, y2).
        region = mmcv.imcrop(
            masked,
            np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]),
            scale=1)
        ann['class_label'] = region_classify_w_blip2(region)
        print(ann['class_label'], str(bbox))
        topk_region_w_class_label.append(ann)
    return topk_region_w_class_label
_LABEL_MAX_WIDTH = 20  # max characters per wrapped label line
_LABEL_LINE_SPACING = 40  # vertical distance in pixels between wrapped lines
_LABEL_FONT_SCALE = 1.2
_LABEL_FONT_THICKNESS = 2
# One-pixel offsets around the anchor; drawing the text at each one makes the
# label look bolder.
_LABEL_BOLD_OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def _draw_centered_label(img, text, x_center, y_center):
    """Draw word-wrapped ``text`` on ``img`` in place, centred on (x_center, y_center)."""
    wrapped_text = textwrap.wrap(text, width=_LABEL_MAX_WIDTH)
    if not wrapped_text:
        return
    font = cv2.FONT_HERSHEY_SIMPLEX
    # One random colour per label.
    font_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    # Shift the first line up so the whole block of lines is vertically centred.
    top = y_center - (len(wrapped_text) - 1) * _LABEL_LINE_SPACING // 2
    for idx, line in enumerate(wrapped_text):
        y_offset = top + idx * _LABEL_LINE_SPACING
        text_width = cv2.getTextSize(line, font, _LABEL_FONT_SCALE, _LABEL_FONT_THICKNESS)[0][0]
        x_offset = x_center - text_width // 2
        for off_x, off_y in _LABEL_BOLD_OFFSETS:
            cv2.putText(img, line, (x_offset + off_x, y_offset + off_y), font,
                        _LABEL_FONT_SCALE, font_color, _LABEL_FONT_THICKNESS, cv2.LINE_AA)


def show_semantic_image_label(anns):
    """
    Render a colour-coded region map with each region's class label drawn on it.

    Every region is filled with a random colour. Then each region's
    ``class_label`` is word-wrapped and drawn centred on the region's centroid.
    Text is drawn in a second pass so later masks cannot paint over earlier
    labels. Regions whose mask is empty get no text, because they have no
    centroid.

    Args:
        anns: non-empty list of dicts, each with a 2-D ``'segmentation'`` mask
            (non-zero = inside the region) and a ``'class_label'`` string. All
            masks are assumed to share the shape of the first one.
    Returns:
        full_img: float64 numpy array of shape (H, W, 3), values in [0, 255].
    Raises:
        ValueError: if ``anns`` is empty, since there is no mask to size the
            output image from.
    """
    if not anns:
        raise ValueError('show_semantic_image_label needs at least one annotation')
    height, width = anns[0]['segmentation'].shape[:2]
    full_img = np.zeros((height, width, 3))
    # Pass 1: paint each region with a random colour in [0, 1).
    for ann in anns:
        full_img[ann['segmentation'] != 0] = np.random.random(3)
    full_img = full_img * 255
    # Pass 2: write each label at its region's centroid.
    for ann in anns:
        ys, xs = np.nonzero(ann['segmentation'])
        if xs.size == 0:
            continue
        _draw_centered_label(full_img, ann['class_label'], int(np.mean(xs)), int(np.mean(ys)))
    return full_img
# Demo: segment one image, caption its largest regions, and save the input
# image next to the labelled region map.
image_path = "images/sa_224577.jpg"
# convert('RGB') yields a 3-channel image even for grayscale, palette or RGBA
# files. The `with` closes the file handle once the pixels are loaded.
with Image.open(image_path) as pil_image:
    input_image = pil_image.convert('RGB')
detect_resolution = 1024
input_image = resize_image(np.array(input_image, dtype=np.uint8), detect_resolution)
region_level_annots = region_level_semantic_api(input_image, topk=5)
if not region_level_annots:
    raise SystemExit(f'No regions were segmented in {image_path}; nothing to label.')
output = show_semantic_image_label(region_level_annots)

input_image = resize_image(input_image, 512)
output = resize_image(output, 512)
input_image = np.array(input_image, dtype=np.uint8)
# The label map is float; resampling can overshoot [0, 255], so clip before the
# uint8 cast instead of letting out-of-range values wrap around.
output = np.clip(output, 0, 255).astype(np.uint8)

image_list = [torch.tensor(input_image).float(), torch.tensor(output).float()]
for each in image_list:
    print(each.shape, type(each))
    print(each.max(), each.min())
# HWC -> CHW, as torchvision's save_image expects (N, C, H, W).
image_list = torch.stack(image_list).permute(0, 3, 1, 2)
print(image_list.shape)
save_image(image_list, "images/sample_semantic.jpg", nrow=2,
           normalize=True)