[Diffusion] Training a simple diffusion model on Pokémon
sillon
2024. 1. 12. 17:39
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import plotly.io
import plotly.graph_objects as go
# colab: colab
# jupyter lab: jupyterlab
# jupyter notebook, quarto blog: notebook
plotly.io.renderers.default = "notebook"
In [ ]:
# set the device to use
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
Out[ ]:
'cuda'
In [ ]:
sprites = np.load('pokemon.npy')
sprites.shape
Out[ ]:
(819, 256, 256, 3)
In [ ]:
# new target image size
new_width, new_height = 64, 64
# list to hold the downsampled images
resized_images = []
# loop over every image
for img in sprites:
    # convert the NumPy array to a PIL image
    pil_img = Image.fromarray(img)
    # resize the image
    resized_img = pil_img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    # convert the resized image back to a NumPy array and append it to the list
    resized_images.append(np.array(resized_img))
# convert the list back to a NumPy array
resized_sprites = np.array(resized_images)
sprites = resized_sprites
print("Resized shape:", resized_sprites.shape)
Resized shape: (819, 64, 64, 3)
In [ ]:
# take one of the prepared (64,64) Pokémon images and
# normalize its pixel values to 0~1
x0 = (sprites / 255)[12]
# check the normalization
print(x0.min(), x0.max())
0.1411764705882353 1.0
In [ ]:
plt.imshow(x0)
plt.show()
In [ ]:
sprites.shape, sprites.min(), sprites.max()
Out[ ]:
((819, 64, 64, 3), 0, 255)
In [ ]:
H = 64
W = 64
C = 3
In [ ]:
class MyDataset(Dataset):
    # beta follows the settings of the original DDPM paper
    beta_1 = 1e-4
    beta_T = 0.02
    # the number of time steps follows the deeplearning.ai short course "How Diffusion Models Work"
    T = 500
    # a dummy entry torch.tensor([0.]) is prepended so that beta can be indexed from 1 to T
    beta = torch.cat([torch.tensor([0.]), torch.linspace(beta_1, beta_T, T)], axis=0)
    alpha = 1 - beta
    alpha_bar = torch.exp(torch.cumsum(torch.log(alpha), axis=0))

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        x_0 = self.data[i]
        # normalize to [-1, 1]
        if self.transform:
            x_0 = self.transform(x_0)
        # add noise: pick a random time step t and diffuse x_0 to x_t in closed form
        t = np.random.randint(1, MyDataset.T+1)
        eps = torch.randn_like(x_0)
        x_t = torch.sqrt(MyDataset.alpha_bar[t]) * x_0 + torch.sqrt(1 - MyDataset.alpha_bar[t]) * eps
        return x_0, x_t, eps, t
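For reference, __getitem__ uses the standard DDPM closed-form forward process: with $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$ (computed above as a cumulative product in log space), a noisy sample at an arbitrary time step $t$ can be drawn directly from the clean image $x_0$ as

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

which is exactly what the last two lines of __getitem__ compute.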
In [ ]:
transform = transforms.Compose([
    # transforms.Resize((64, 64)),        # the images were already resized above
    transforms.ToTensor(),                # from [0, 255] to the range [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,))  # to the range [-1, 1]
])
train_ds = MyDataset(sprites, transform)
m = 4  # batch size
train_loader = DataLoader(train_ds, batch_size=m, shuffle=True)
train_loader_iter = iter(train_loader)
In [ ]:
samples = next(train_loader_iter)
x_0s = samples[0][:6].numpy()
x_ts = samples[1][:6].numpy()
epss = samples[2][:6].numpy()
ts = samples[3][:6].numpy()
In [ ]:
fig, axs = plt.subplots(figsize=(10,5), nrows=3, ncols=6)
i = 0
for (x_0, x_t, eps, t) in zip(x_0s, x_ts, epss, ts):
    x_0 = x_0.transpose(1,2,0)
    x_0 = ((x_0 - x_0.min()) / (x_0.max() - x_0.min())).clip(0,1)
    axs[0][i].imshow(x_0)
    axs[0][i].set_title(f"t={t}")
    axs[0][i].set_xticks([])
    axs[0][i].set_yticks([])

    eps = eps.transpose(1,2,0)
    eps = ((eps - eps.min()) / (eps.max() - eps.min())).clip(0,1)
    axs[1][i].imshow(eps)
    axs[1][i].set_xticks([])
    axs[1][i].set_yticks([])

    x_t = x_t.transpose(1,2,0)
    x_t = ((x_t - x_t.min()) / (x_t.max() - x_t.min())).clip(0,1)
    axs[2][i].imshow(x_t)
    axs[2][i].set_xticks([])
    axs[2][i].set_yticks([])

    if i == 0:
        axs[0][i].set_ylabel('x_0')
        axs[1][i].set_ylabel('eps')
        axs[2][i].set_ylabel('x_t')
    i += 1
plt.show()
In [ ]:
class DDPM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb_1 = torch.nn.Linear(in_features=1, out_features=32)
        self.emb_2 = torch.nn.Linear(in_features=32, out_features=64)
        self.down_conv1_32 = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.down_conv2_32 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.down_conv3_64 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.down_conv4_128 = torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.up_conv1_64 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.up_conv2_32 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.up_conv3_32 = torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()
        self.gelu = torch.nn.GELU()

    def forward(self, x, t):
        # x: (N, C, H, W)
        # t: (N, 1)
        batch_size = t.shape[0]
        # time embedding
        t = self.relu( self.emb_1(t) )  # (N, 32)
        t = self.relu( self.emb_2(t) )  # (N, 64)
        t = t.reshape(batch_size, -1, 1, 1)  # (N, 64, 1, 1)
        # image down conv
        x = self.gelu( self.down_conv1_32(x) )  # (N, 32, 64, 64)
        x_32 = self.gelu( self.down_conv2_32(x) )  # (N, 32, 64, 64)
        size_32 = x_32.shape
        x = torch.nn.functional.max_pool2d(x_32, (2,2))  # (N, 32, 32, 32)
        x = self.gelu( self.down_conv3_64(x) )  # (N, 64, 32, 32)
        size_64 = x.shape
        x = torch.nn.functional.max_pool2d(x, (2,2))  # (N, 64, 16, 16)
        x = x + t  # (N, 64, 16, 16) + (N, 64, 1, 1) = (N, 64, 16, 16) via broadcasting
        x = self.gelu( self.down_conv4_128(x) )  # (N, 128, 16, 16)
        # image up conv
        x = self.gelu( self.up_conv1_64(x, output_size=size_64) )  # (N, 64, 32, 32)
        x = self.gelu( self.up_conv2_32(x, output_size=size_32) )  # (N, 32, 64, 64)
        x = torch.cat([x, x_32], dim=1)  # (N, 64, 64, 64), skip connection
        out = self.up_conv3_32(x)  # (N, 3, 64, 64)
        return out
In [ ]:
model = DDPM()
model = nn.DataParallel(model)
model.to(device)
Out[ ]:
DataParallel(
  (module): DDPM(
    (emb_1): Linear(in_features=1, out_features=32, bias=True)
    (emb_2): Linear(in_features=32, out_features=64, bias=True)
    (down_conv1_32): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv2_32): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv3_64): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv4_128): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (up_conv1_64): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (up_conv2_32): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (up_conv3_32): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU()
    (gelu): GELU(approximate='none')
  )
)
In [ ]:
output = model(samples[0].to(device), samples[3].reshape(-1,1).float().to(device))
print(output.shape)
torch.Size([4, 3, 64, 64])
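As a quick sanity check (not part of the original notebook), you could also print the number of trainable parameters of this small network:

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_params:,}")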
Train
In [ ]:
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
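Using the MSE between the sampled noise eps and the network output eps_theta corresponds to the simplified DDPM training objective

$$L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right],$$

so the loop below simply trains the network to predict the noise that was added to each image.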
In [ ]:
epochs = 1000
losses = []
for e in range(epochs):
    epoch_loss = 0.0
    epoch_mae = 0.0
    for i, data in enumerate(tqdm(train_loader)):
        x_0, x_t, eps, t = data
        x_t = x_t.to(device)
        eps = eps.to(device)
        t = t.to(device)

        optimizer.zero_grad()
        eps_theta = model(x_t, t.reshape(-1,1).float())
        loss = loss_func(eps_theta, eps)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            epoch_loss += loss.item()
            epoch_mae += torch.nn.functional.l1_loss(eps_theta, eps)

    epoch_loss /= len(train_loader)
    epoch_mae /= len(train_loader)
    print(f"Epoch: {e+1:2d}: loss:{epoch_loss:.4f}, mae:{epoch_mae:.4f}")
    losses.append(epoch_loss)
0%| | 0/205 [00:00<?, ?it/s]100%|██████████| 205/205 [00:04<00:00, 44.35it/s]
Epoch: 1: loss:0.7763, mae:0.6940
100%|██████████| 205/205 [00:04<00:00, 51.07it/s]
Epoch: 2: loss:0.2555, mae:0.3651
100%|██████████| 205/205 [00:03<00:00, 51.34it/s]
Epoch: 1000: loss:0.0530, mae:0.1376
In [ ]:
plt.plot(losses)
plt.show()
Sampling
In [ ]:
alpha = MyDataset.alpha.to(device)
alpha_bar = MyDataset.alpha_bar.to(device)
beta = MyDataset.beta.to(device)
T = MyDataset.T
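Each pass through the loop below performs one DDPM reverse (sampling) step: given the current sample $x_t$ and the predicted noise $\epsilon_\theta(x_t, t)$,

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\,z,$$

where $z \sim \mathcal{N}(0, I)$ while $t > 1$ and $z = 0$ at the final step, matching the if t > 1 branch in the code.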
In [ ]:
# prepare lists to store images generated at regular intervals during sampling
interval = 20  # record one frame every 20 time steps
X = []  # generated images
saved_frame = []  # time steps at which an image was saved
N = 5  # number of samples to generate
# sample the initial noise
x = torch.randn(size=(N, C, H, W)).to(device)
for t in range(T, 0, -1):
    if t > 1:
        z = torch.randn(size=(N,C,H,W)).to(device)
    else:
        z = torch.zeros((N,C,H,W)).to(device)
    t_torch = torch.tensor([[t]]*N, dtype=torch.float32).to(device)
    eps_theta = model(x, t_torch)
    x = (1 / torch.sqrt(alpha[t])) * \
        (x - ((1-alpha[t])/torch.sqrt(1-alpha_bar[t]))*eps_theta) + torch.sqrt(beta[t])*z
    if (T - t) % interval == 0 or t == 1:
        # save the x_{t-1} image produced from the current time step
        saved_frame.append(t)
        x_np = x.detach().cpu().numpy()
        # (N,C,H,W) -> (H,N,W,C) -> (H, N*W, C): lay the samples out side by side
        x_np = x_np.transpose(2,0,3,1).reshape(H,-1,C)
        x_np = ((x_np - x_np.min()) / (x_np.max() - x_np.min())).clip(0,1)
        X.append( x_np*255.0 )  # 0 ~ 1 -> 0 ~ 255
X = np.array(X, dtype=np.uint8)
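If you also want the sampling trajectory saved to disk rather than only shown in the interactive Plotly figure below, a minimal sketch using Pillow (already imported above) could look like this; pokemon_sampling.gif is just a placeholder file name:

frames = [Image.fromarray(frame) for frame in X]  # X: (num_frames, H, N*W, C), dtype uint8
frames[0].save(
    "pokemon_sampling.gif",   # placeholder output path
    save_all=True,            # write every frame, not just the first
    append_images=frames[1:],
    duration=100,             # milliseconds per frame
    loop=0,                   # 0 = loop forever
)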
In [ ]:
fig = go.Figure(
data = [ go.Image(z=X[0]) ],
layout = go.Layout(
# title="Generated image",
autosize = False,
width = 800, height = 400,
margin = dict(l=0, r=0, b=0, t=30),
xaxis = {"title": f"Generated Image: x_{T-1}"},
updatemenus = [
dict(
type="buttons",
buttons=[
# play button
dict(
label="Play", method="animate",
args=[
None,
{
"frame": {"duration": 50, "redraw": True},
"fromcurrent": True,
"transition": {"duration": 50, "easing": "quadratic-in-out"}
}
]
),
# pause button
dict(
label="Pause", method="animate",
args=[
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0}
}
]
)
],
direction="left", pad={"r": 10, "t": 87}, showactive=False,
x=0.1, xanchor="right", y=0, yanchor="top"
)
], # updatemenus = [
), # layout = go.Layout(
frames = [
{
'data':[go.Image(z=X[t])],
'name': t,
'layout': {
'xaxis': {'title': f"Generated Image: x_{saved_frame[t]-1}"}
}
} for t in range(len(X))
]
)
################################################################################
# slider setup
sliders_dict = {
"active": 0, "yanchor": "top", "xanchor": "left",
"currentvalue": {
"font": {"size": 15}, "prefix": "input time:",
"visible": True, "xanchor": "right"
},
"transition": {"duration": 100, "easing": "cubic-in-out"},
"pad": {"b": 10, "t": 50},
"len": 0.9, "x": 0.1, "y": 0,
"steps": []
}
for t in range(len(X)):
    slider_step = {
        "label": f"{saved_frame[t]}", "method": "animate",
        "args": [
            [t],  # must match the frame name for the slider step to control that frame
            {
                "frame": {"duration": 100, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 100}
            }
        ],
    }
    sliders_dict["steps"].append(slider_step)
fig["layout"]["sliders"] = [sliders_dict]
################################################################################
fig.show()
Final result

The original code is from https://metamath1.github.io/blog/posts/diffusion/ddpm_part2-2.html?utm_source=pytorchkr