일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- DB
- postgresql
- ML
- html
- PRISMA
- Git
- frontend
- ts
- C++
- Express
- SOLID
- CV
- Three
- API
- UI
- PyTorch
- CSS
- opencv
- GAN
- js
- vscode
- mongo
- react
- Linux
- sqlite
- python
- nodejs
- figma
- review
- ps
- Today
- Total
아카이브
[Pytorch] 조건 적대적 생성 모델(CGAN) 구현하기 - MNIST를 기반으로 본문
다음 블로그의 글을 참고하였습니다.
https://ddongwon.tistory.com/124
[Pytorch] GAN 구현 및 학습
1. 개요 https://github.com/godeastone/GAN-torch Pytorch 로 구현한 GAN 전체 코드는 위 git repository에서 확인할 수 있다. 2. GAN GAN은 2014년 Ian Goodfellow 님에 의해 개발되었다. GAN 논문에 대한 자세한 정보는 아래
ddongwon.tistory.com
이전에 구현했던 적대적 생성 모델(GAN)에 몇 가지 변형을 추가하여 조건 적대적 생성 모델(Condition GAN, CGAN)을 구현하겠습니다.
0. CGAN 개요
CGAN은 GAN과는 달리 위조값 생성의 조건에 해당되는 condition 변수(y)가 추가됩니다.
이로 인해 이전 처럼 무작위적 위조값 생성이 아닌 자신이 원하는 범주에 해당되는 위조값을 생성할 수 있습니다.
조건 y는 generator와 discriminator에 모두 사용되며, noise vector z 와 generator의 결과값인 fake data와 합쳐집니다.
MINST는 0에서 9까지 숫자의 손글씨 에 대한 데이터셋입니다.
따라서 생성하는 이미지에 부여할 수 있는 조건은 숫자의 종류이고, 0에서 9까지 총 10가지 경우를 가집니다.
이를 one-hot-encoding으로 나타내면 condition y는 길이가 10인 1차원 tensor로 나타낼 수 있습니다.
1. 모델 설정
Generator와 Discriminator 모두 첫 번째 계층에서 입력값에 cond_size를 추가합니다.
이 외에는 기존의 GAN과 모두 동일하기 때문에 모델과 hyperparameter에 대한 설명은 생략합니다.
2. 코드 (GAN에서의 변경사항)
3.1. Hyperparameter
# hyperparamter
max_epoch = 200
batch_size = 100
lr = 0.0002
cond_size = 10 ## <newly added> ##
img_size = 784
noise_size = 100
hidden_size1 = 256
hidden_size2 = 512
조건에 해당되는 cond_size 가 추가되었습니다.
3.1. Generator & Discriminator
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(noise_size + cond_size, hidden_size1) ## <newly added> ##
self.linear2 = nn.Linear(hidden_size1, hidden_size2)
self.linear3 = nn.Linear(hidden_size2, img_size)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.tanh(self.linear3(x))
return x
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(img_size + cond_size, hidden_size2) ## <newly added> ##
self.linear2 = nn.Linear(hidden_size2, hidden_size1)
self.linear3 = nn.Linear(hidden_size1, 1)
self.relu = nn.LeakyReLU(0.2)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
Generator와 Discriminator 모두 첫 번째 계층에서 입력값에 cond_size를 추가하기 때문에 해당 길이를 더해줍니다.
3.2. Dataset iteration
1개 batch를 처리할 때 진행되는 흐름입니다.
붉은 부분이 discriminator의 학습 단계이고, 푸른 부분이 generator의 학습 단계입니다.
보라색 부분은 discriminator와 generator가 공통으로 거치는 단계입니다.
GAN과 마찬가지로 CGAN에서도 discriminator와 generator를 따로 학습시킵니다.
순서는 상관 없으나, 둘 중 하나의 모델이 갱신된 후 다른 하나가 갱신되어야 합니다.
달라진 점은
- Dataloader에서 호출되는 label을 조건 tensor인 condition으로 사용하며
- 이를 z, real_images, fake_images에 합쳐 입력값으로 사용한다는 것입니다.
3.2.1. Batch initialization
# 3.7. dataset iteration
for i, (imgs, label) in enumerate(dataloader):
label_real = torch.full((batch_size, 1), 1, dtype=torch.float32).to(device)
label_fake = torch.full((batch_size, 1), 0, dtype=torch.float32).to(device)
condition = nn.functional.one_hot(label, num_classes=10).to(device) # one-hot-encoded label
real_imgs = imgs.reshape(batch_size, -1).to(device) # B * 784
real_imgs_cond = torch.concat((real_imgs, condition), 1)
label의 값을 기반으로 one-hot-encoding 형태의 condition tensor를 생성합니다.
one_hot(tensor, num_classes) | tensor의 각 원소에 대해 num_classes길이의 one-hot-encoded tesnor를 반환한다 | |
tensor | Tensor | one-hot-encoding으로 변환할 요소들의 tensor |
num_classes | int | one-hot-encoding 길이 |
이 condition을 real_imgs tensor의 옆에 붙여 real_imgs_cond를 만드는데, 이 때는 concat( ) 함수를 사용합니다.
concat(tensors, dim) | 입력한 tensor 들을 하나로 합친다 | |
tensors | list | 합칠 tensor들의 list |
dim | int | 합쳐질 차원(방향) |
3.2.2. Generator training
## Generator ##
optim_G.zero_grad()
optim_D.zero_grad()
z = torch.randn(batch_size, noise_size).to(device)
z_cond = torch.concat((z, condition), 1) ## <newly added> ##
fake_imgs = generator(z_cond)
fake_imgs_cond = torch.concat((fake_imgs, condition), 1) ## <newly added> ##
loss_G = loss(discriminator(fake_imgs_cond), label_real)
loss_G.backward()
optim_G.step()
같은 방법으로 condition을 noise vector z와 합쳐 z_cond로 만들고, fake_imgs와 합쳐 fake_imgs_cond로 만듭니다.
3.2.3. Discriminator training
## Discriminator ##
optim_G.zero_grad()
optim_D.zero_grad()
z = torch.randn(batch_size, noise_size).to(device)
z_cond = torch.concat((z, condition), 1) ## <newly added> ##
fake_imgs = generator(z_cond)
fake_imgs_cond = torch.concat((fake_imgs, condition), 1) ## <newly added> ##
loss_fake = loss(discriminator(fake_imgs_cond), label_fake)
loss_real = loss(discriminator(real_imgs_cond), label_real)
loss_D = (loss_fake + loss_real) / 2
loss_D.backward()
optim_D.step()
같은 방법으로 condition을 noise vector z와 합쳐 z_cond로 만들고, fake_imgs와 합쳐 fake_imgs_cond로 만듭니다.
또한 3.2.1.절에서 만든 real_imgs_cond도 사용합니다.
3.3. Show final result
GAN에서는 학습에 따라 데이터셋을 닮은 이미지가 생성되는지만 확인했습니다.
CGAN에서는 실제 조건을 제시하여 그 조건에 맞는 이미지를 올바르게 생성하는지까지 확인해야 합니다.
# show final result
tensor_order = torch.tensor(list(i for i in range(10)))
label_order = nn.functional.one_hot(tensor_order, num_classes=10).to(device)
sample_z = torch.randn(10, noise_size).to(device)
sample_z_cond = torch.concat((sample_z, label_order), 1)
fin_imgs = generator(sample_z_cond)
final = fin_imgs.reshape(10, 1, 28, 28)
save_image(final, os.path.join(res_path, '{} final.png'.format(res_file)))
0부터 9까지의 라벨이 하나씩만 주어진 조건 label_order를 만든 뒤
무작위로 생성된 sample_z와 합쳐 sample_z_cond를 만듭니다.
이를 학습된 generator에 입력하고 라벨에 따라 0에서 부터 9까지의 이미지가 올바르게 생성되는 지 확인합니다.
3.3. Main code
import time
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
# Device setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("current device : {}".format(device))
# Ouput Directory setting
res_file = "CGAN"
res_path = "results/" + res_file
if not os.path.exists(res_path):
os.makedirs(res_path)
# hyperparamter
max_epoch = 200
batch_size = 100
lr = 0.0002
cond_size = 10 ## <newly added> ##
img_size = 784
noise_size = 100
hidden_size1 = 256
hidden_size2 = 512
loss = nn.BCELoss()
# nn.Module : Base class for all neural network modules.
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(noise_size + cond_size, hidden_size1) ## <newly added> ##
self.linear2 = nn.Linear(hidden_size1, hidden_size2)
self.linear3 = nn.Linear(hidden_size2, img_size)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.tanh(self.linear3(x))
return x
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(img_size + cond_size, hidden_size2) ## <newly added> ##
self.linear2 = nn.Linear(hidden_size2, hidden_size1)
self.linear3 = nn.Linear(hidden_size1, 1)
self.relu = nn.LeakyReLU(0.2)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
# Dataset - MNIST
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)])
dataset = MNIST(root='../data',
train=True,
transform=transform,
download=True)
# Dataloader
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
# Model
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Optimizer
optim_G = optim.Adam(generator.parameters(), lr=lr)
optim_D = optim.Adam(discriminator.parameters(), lr=lr)
start = time.time()
for epoch in range(max_epoch):
print("epoch: {}".format(epoch+1))
# dataset iteration
for i, (imgs, label) in enumerate(dataloader):
label_real = torch.full((batch_size, 1), 1, dtype=torch.float32).to(device)
label_fake = torch.full((batch_size, 1), 0, dtype=torch.float32).to(device)
condition = nn.functional.one_hot(label, num_classes=10).to(device) # one-hot-encoded label
real_imgs = imgs.reshape(batch_size, -1).to(device) # B * 784
real_imgs_cond = torch.concat((real_imgs, condition), 1)
## Generator ##
optim_G.zero_grad()
optim_D.zero_grad()
z = torch.randn(batch_size, noise_size).to(device)
z_cond = torch.concat((z, condition), 1) ## <newly added> ##
fake_imgs = generator(z_cond)
fake_imgs_cond = torch.concat((fake_imgs, condition), 1) ## <newly added> ##
loss_G = loss(discriminator(fake_imgs_cond), label_real)
loss_G.backward()
optim_G.step()
## Discriminator ##
optim_G.zero_grad()
optim_D.zero_grad()
z = torch.randn(batch_size, noise_size).to(device)
z_cond = torch.concat((z, condition), 1) ## <newly added> ##
fake_imgs = generator(z_cond)
fake_imgs_cond = torch.concat((fake_imgs, condition), 1) ## <newly added> ##
loss_fake = loss(discriminator(fake_imgs_cond), label_fake)
loss_real = loss(discriminator(real_imgs_cond), label_real)
loss_D = (loss_fake + loss_real) / 2
loss_D.backward()
optim_D.step()
performance_D = discriminator(real_imgs_cond).mean()
performance_G = discriminator(fake_imgs_cond).mean()
if (i + 1) % 150 == 0:
print("Epoch [ {}/{} ] Step [ {}/{} ] d_loss : {:.5f} g_loss : {:.5f}"
.format(epoch+1, max_epoch, i+1, len(dataloader), loss_D.item(), loss_G.item()))
# print performance
print(" Epoch {}'s discriminator performance : {:.2f} generator performance : {:.2f}"
.format(epoch, performance_D, performance_G))
# Save fake images in each epoch
res_imgs = fake_imgs.reshape(batch_size, 1, 28, 28)
save_image(res_imgs, os.path.join(res_path, '{} {}.png'.format(res_file, epoch + 1)))
elasped_time = time.time() - start
print("[Done] Time performance : {} s".format(elasped_time))
# show final result
tensor_order = torch.tensor(list(i for i in range(10)))
label_order = nn.functional.one_hot(tensor_order, num_classes=10).to(device)
sample_z = torch.randn(10, noise_size).to(device)
sample_z_cond = torch.concat((sample_z, label_order), 1)
fin_imgs = generator(sample_z_cond)
final = fin_imgs.reshape(10, 1, 28, 28)
save_image(final, os.path.join(res_path, '{} final.png'.format(res_file)))
3. 결과
이미지가 완벽히 생성되지는 않았지만, 제시된 조건(label)에 따라 적절한 모양의 이미지를 생성해낸 것 같습니다.
'Pytorch' 카테고리의 다른 글
[Pytorch] 적대적 생성 모델(GAN)의 hyperparameter값에 따른 변화 (0) | 2023.02.14 |
---|---|
[Pytorch] 적대적 생성 모델(GAN) 구현하기 - MINST를 기반으로 (0) | 2023.02.11 |
[Pytorch] Pytorch 함수 정리 (0) | 2023.02.06 |
[Pytorch] pip로 Pytorch 설치하기 (0) | 2023.01.24 |