[Diffusion] Training a Simple Diffusion Model on Pokémon Sprites

2024. 1. 12. 17:39
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm

import plotly.io
import plotly.graph_objects as go

# colab: colab
# jupyter lab: jupyterlab
# jupyter notebook, quarto blog: notebook
plotly.io.renderers.default = "notebook"
In [ ]:
# Set the device to use
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
Out[ ]:
'cuda'
In [ ]:
sprites = np.load('pokemon.npy')
sprites.shape
Out[ ]:
(819, 256, 256, 3)
In [ ]:
# Target image size
new_width, new_height = 64, 64

# List to hold the downsampled images
resized_images = []

# Iterate over each image
for img in sprites:
    # Convert the NumPy array to a PIL image
    pil_img = Image.fromarray(img)

    # Resize (Image.ANTIALIAS is deprecated; Resampling.LANCZOS is its replacement)
    resized_img = pil_img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Convert the resized image back to a NumPy array and append it
    resized_images.append(np.array(resized_img))

# Convert the list back to a NumPy array
resized_sprites = np.array(resized_images)
sprites = resized_sprites
print("Resized shape:", resized_sprites.shape)
Resized shape: (819, 64, 64, 3)
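As an aside, the same downsampling could be done in a single batched call with torchvision instead of a per-image PIL loop. A minimal sketch, not from the original post: it assumes a reasonably recent torchvision, and sprites_256 is a hypothetical name for the original (819, 256, 256, 3) array loaded from pokemon.npy.

from torchvision.transforms.functional import resize

# sprites_256: the original (819, 256, 256, 3) uint8 array from pokemon.npy
batch = torch.from_numpy(sprites_256).permute(0, 3, 1, 2).float()  # (N,H,W,C) -> (N,C,H,W)
batch = resize(batch, [64, 64], antialias=True)                    # antialiased bilinear resize
sprites_alt = batch.round().clamp(0, 255).byte().permute(0, 2, 3, 1).numpy()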
In [ ]:
# Take one of the prepared (64,64) Pokémon sprites and
# normalize its pixel values to [0, 1]
x0 = (sprites / 255)[12]

# Check the normalization
print(x0.min(), x0.max())
0.1411764705882353 1.0
In [ ]:
plt.imshow(x0)
plt.show()
In [ ]:
sprites.shape, sprites.min(), sprites.max()
Out[ ]:
((819, 64, 64, 3), 0, 255)
In [ ]:
H = 64
W = 64
C = 3
In [ ]:
class MyDataset(Dataset):
    # beta follows the settings in the original DDPM paper
    beta_1 = 1e-4
    beta_T = 0.02
    # the number of time steps follows the deeplearning.ai short course "How Diffusion Models Work"
    T = 500

    # beta is indexed from 1 to T, so a dummy entry torch.tensor([0.]) is prepended at index 0
    beta = torch.cat([ torch.tensor([0.]), torch.linspace(beta_1, beta_T, T)], axis=0)
    alpha = 1 - beta
    alpha_bar = torch.exp(torch.cumsum(torch.log(alpha), axis=0))

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        x_0 = self.data[i]

        # normalize to [-1, 1]
        if self.transform:
            x_0 = self.transform(x_0)

        # add noise via the closed-form forward process q(x_t | x_0)
        t = np.random.randint(1, MyDataset.T+1)
        eps = torch.randn_like(x_0)
        x_t = torch.sqrt(MyDataset.alpha_bar[t]) * x_0 + torch.sqrt(1 - MyDataset.alpha_bar[t]) * eps

        return x_0, x_t, eps, t
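A quick aside on the schedule: alpha_bar computed via exp(cumsum(log(alpha))) is just a numerically stable cumulative product, and the noising line in __getitem__ is the closed-form forward process q(x_t | x_0). A small sanity check (mine, not part of the original notebook):

# alpha_bar via exp(cumsum(log(alpha))) should equal a plain cumulative product
alpha_bar_ref = torch.cumprod(MyDataset.alpha, dim=0)
assert torch.allclose(MyDataset.alpha_bar, alpha_bar_ref, atol=1e-5)

# the forward process mixes signal and noise so the variance stays ~1:
# sqrt(a)^2 + sqrt(1-a)^2 = 1 for any a in [0, 1]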
In [ ]:
transform = transforms.Compose([
    # transforms.Resize((64, 64)),        # image resizing could also be done here
    transforms.ToTensor(),                # from [0, 255] to [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,))  # from [0, 1] to [-1, 1]
])

train_ds = MyDataset(sprites, transform)

m = 4  # batch size

train_loader = DataLoader(train_ds, batch_size=m, shuffle=True)
train_loader_iter = iter(train_loader)
In [ ]:
samples = next(train_loader_iter)

# take up to 6 examples from the batch (with m=4, four are available)
x_0s = samples[0][:6].numpy()
x_ts = samples[1][:6].numpy()
epss = samples[2][:6].numpy()
ts = samples[3][:6].numpy()
In [ ]:
fig, axs = plt.subplots(figsize=(10,5), nrows=3, ncols=len(x_0s))  # one column per example

i = 0
for (x_0, x_t, eps, t) in zip(x_0s, x_ts, epss, ts):
    x_0 = x_0.transpose(1,2,0)
    x_0 = ((x_0 - x_0.min()) / (x_0.max() - x_0.min())).clip(0,1)
    axs[0][i].imshow(x_0)
    axs[0][i].set_title(f"t={t}")
    axs[0][i].set_xticks([])
    axs[0][i].set_yticks([])

    eps = eps.transpose(1,2,0)
    eps = ((eps - eps.min()) / (eps.max() - eps.min())).clip(0,1)
    axs[1][i].imshow(eps)
    axs[1][i].set_xticks([])
    axs[1][i].set_yticks([])

    x_t = x_t.transpose(1,2,0)
    x_t = ((x_t - x_t.min()) / (x_t.max() - x_t.min())).clip(0,1)
    axs[2][i].imshow(x_t)
    axs[2][i].set_xticks([])
    axs[2][i].set_yticks([])

    if i == 0:
        axs[0][i].set_ylabel('x_0')
        axs[1][i].set_ylabel('eps')
        axs[2][i].set_ylabel('x_t')

    i += 1

plt.show()
In [ ]:
class DDPM(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.emb_1 = torch.nn.Linear(in_features=1, out_features=32)
        self.emb_2 = torch.nn.Linear(in_features=32, out_features=64)

        self.down_conv1_32 = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.down_conv2_32 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.down_conv3_64 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.down_conv4_128 = torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        self.up_conv1_64 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.up_conv2_32 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.up_conv3_32 = torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)

        self.relu = torch.nn.ReLU()
        self.gelu = torch.nn.GELU()

    def forward(self, x, t):
        # x: (N, C, H, W)
        # t: (N, 1)
        batch_size = t.shape[0]

        # time embedding
        t = self.relu( self.emb_1(t) ) # (N, 32)
        t = self.relu( self.emb_2(t) ) # (N, 64)
        t = t.reshape(batch_size, -1, 1, 1) # (N, 64, 1, 1)

        # image down conv
        x = self.gelu( self.down_conv1_32(x) )    # (N, 32, 64, 64)
        x_32 = self.gelu( self.down_conv2_32(x) ) # (N, 32, 64, 64)
        size_32 = x_32.shape
        x = torch.nn.functional.max_pool2d(x_32, (2,2)) # (N, 32, 32, 32)
        x = self.gelu( self.down_conv3_64(x) ) # (N, 64, 32, 32)
        size_64 = x.shape
        x = torch.nn.functional.max_pool2d(x, (2,2)) # (N, 64, 16, 16)

        x = x + t # (N, 64, 16, 16) + (N, 64, 1, 1) -> (N, 64, 16, 16) by broadcasting
        x = self.gelu( self.down_conv4_128(x) ) # (N, 128, 16, 16)

        # image up conv (skip connection from x_32 is concatenated before the last conv)
        x = self.gelu( self.up_conv1_64(x, output_size=size_64) ) # (N, 64, 32, 32)
        x = self.gelu( self.up_conv2_32(x, output_size=size_32) ) # (N, 32, 64, 64)
        x = torch.cat([x, x_32], dim=1) # (N, 64, 64, 64)

        out = self.up_conv3_32(x) # (N, 3, 64, 64)

        return out
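Note that the scalar time step t is embedded here with two small Linear layers. The original DDPM paper uses a sinusoidal (Transformer-style) embedding instead; a minimal sketch of that variant, in case you want to swap it in (the dim=64 default is my choice to match the channel count where t is added, not something from this post):

import math

def sinusoidal_embedding(t, dim=64):
    # t: (N, 1) float tensor of time steps -> (N, dim) embedding
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t * freqs.to(t.device)  # (N, 1) * (half,) -> (N, half) by broadcasting
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)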
In [ ]:
model = DDPM()
model = nn.DataParallel(model)
model.to(device)
Out[ ]:
DataParallel(
  (module): DDPM(
    (emb_1): Linear(in_features=1, out_features=32, bias=True)
    (emb_2): Linear(in_features=32, out_features=64, bias=True)
    (down_conv1_32): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv2_32): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv3_64): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_conv4_128): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (up_conv1_64): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (up_conv2_32): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (up_conv3_32): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU()
    (gelu): GELU(approximate='none')
  )
)
In [ ]:
output = model(samples[0].to(device), samples[3].reshape(-1,1).float().to(device))

print(output.shape)
torch.Size([4, 3, 64, 64])

Train

In [ ]:
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
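For reference, this MSE between the predicted and true noise is exactly the simplified DDPM training objective:

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t\right)\right\|^2\right]$$

so the network learns to predict the noise that was mixed into x_t, not to reconstruct the clean image itself.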
In [ ]:
epochs = 1000
losses = []

for e in range(epochs):
    epoch_loss = 0.0
    epoch_mae = 0.0

    for i, data in enumerate(tqdm(train_loader)):
        x_0, x_t, eps, t = data
        x_t = x_t.to(device)
        eps = eps.to(device)
        t = t.to(device)

        optimizer.zero_grad()
        eps_theta = model(x_t, t.reshape(-1,1).float())
        loss = loss_func(eps_theta, eps)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            epoch_loss += loss.item()
            epoch_mae += torch.nn.functional.l1_loss(eps_theta, eps).item()

    epoch_loss /= len(train_loader)
    epoch_mae /= len(train_loader)

    print(f"Epoch: {e+1:2d}: loss:{epoch_loss:.4f}, mae:{epoch_mae:.4f}")
    losses.append(epoch_loss)
Epoch:  1: loss:0.7763, mae:0.6940
Epoch:  2: loss:0.2555, mae:0.3651
...
Epoch: 1000: loss:0.0530, mae:0.1376
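Since 1000 epochs take a while, it is worth saving the trained weights before moving on. A short sketch, not in the original notebook (the file name is an arbitrary choice):

# save the trained weights; model is wrapped in DataParallel,
# so save the underlying module's state_dict
torch.save(model.module.state_dict(), 'ddpm_pokemon.pt')

# to restore later:
# restored = DDPM()
# restored.load_state_dict(torch.load('ddpm_pokemon.pt', map_location=device))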

In [ ]:
plt.plot(losses)
plt.show()

Sampling

In [ ]:
alpha = MyDataset.alpha.to(device)
alpha_bar = MyDataset.alpha_bar.to(device)
beta = MyDataset.beta.to(device)
T = MyDataset.T
In [ ]:
# Prepare lists to record the generated images at regular intervals during sampling
interval = 20    # record one frame every 20 time steps
X = []           # generated images
saved_frame = [] # time steps at which frames were saved
N = 5            # number of samples fed to the model


# Sample the initial noise, then run the reverse process.
# model.eval() and torch.no_grad() keep the 500-step loop from building up an autograd graph.
model.eval()
with torch.no_grad():
    x = torch.randn(size=(N, C, H, W)).to(device)

    for t in range(T, 0, -1):
        if t > 1:
            z = torch.randn(size=(N, C, H, W)).to(device)
        else:
            z = torch.zeros((N, C, H, W)).to(device)

        t_torch = torch.tensor([[t]] * N, dtype=torch.float32).to(device)
        eps_theta = model(x, t_torch)
        x = (1 / torch.sqrt(alpha[t])) * \
            (x - ((1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t])) * eps_theta) + torch.sqrt(beta[t]) * z

        if (T - t) % interval == 0 or t == 1:
            # record the image produced at this step, i.e. x_{t-1}
            saved_frame.append(t)
            x_np = x.cpu().numpy()

            # (N,C,H,W) -> (H, N*W, C): stitch the N samples side by side
            x_np = x_np.transpose(2, 0, 3, 1).reshape(H, -1, C)
            x_np = ((x_np - x_np.min()) / (x_np.max() - x_np.min())).clip(0, 1)
            X.append(x_np * 255.0)  # [0, 1] -> [0, 255]

X = np.array(X, dtype=np.uint8)
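Each pass through the loop above is one step of the DDPM sampling rule (Algorithm 2 in the paper):

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\,z$$

with z ~ N(0, I) for t > 1 and z = 0 at the final step t = 1, which is why z is zeroed out in the else branch.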
In [ ]:
fig = go.Figure(
    data = [ go.Image(z=X[0]) ],
    layout = go.Layout(
        # title="Generated image",
        autosize = False,
        width = 800, height = 400,
        margin = dict(l=0, r=0, b=0, t=30),
        xaxis = {"title": f"Generated Image: x_{T-1}"},
        updatemenus = [
            dict(
                type="buttons",
                buttons=[
                    # play button
                    dict(
                        label="Play", method="animate",
                        args=[
                            None,
                            {
                                "frame": {"duration": 50, "redraw": True},
                                "fromcurrent": True,
                                "transition": {"duration": 50, "easing": "quadratic-in-out"}
                            }
                        ]
                    ),
                    # pause button
                    dict(
                        label="Pause", method="animate",
                        args=[
                            [None],
                            {
                                "frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0}
                            }
                        ]
                    )
                ],
                direction="left", pad={"r": 10, "t": 87}, showactive=False,
                x=0.1, xanchor="right", y=0, yanchor="top"
            )
        ], # updatemenus = [
    ), # layout = go.Layout(
    frames = [
        {
            'data':[go.Image(z=X[t])],
            'name': t,
            'layout': {
                'xaxis': {'title': f"Generated Image: x_{saved_frame[t]-1}"}
            }
        } for t in range(len(X))
    ]
)

################################################################################
# Slider setup
sliders_dict = {
    "active": 0, "yanchor": "top", "xanchor": "left",
    "currentvalue": {
        "font": {"size": 15}, "prefix": "input time:",
        "visible": True, "xanchor": "right"
    },
    "transition": {"duration": 100, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9, "x": 0.1, "y": 0,
    "steps": []
}

for t in range(len(X)):
    slider_step = {
        "label": f"{saved_frame[t]}", "method": "animate",
        "args": [
            [t], # must match the frame name for the slider to control the animation
            {
                "frame": {"duration": 100, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 100}
            }
        ],
    }

    sliders_dict["steps"].append(slider_step)

fig["layout"]["sliders"] = [sliders_dict]
################################################################################

fig.show()
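If you want the animation outside the notebook, plotly can export the figure as a standalone HTML file. A one-liner sketch (the file name is my choice):

# export the interactive animation as a standalone HTML file
fig.write_html('ddpm_sampling.html')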
Final result: the animation above shows the denoising process from pure noise to generated Pokémon.

Original code: https://metamath1.github.io/blog/posts/diffusion/ddpm_part2-2.html?utm_source=pytorchkr
