class GeneratorUpsampleConv(nn.Module):
    """DCGAN-style generator built from Upsample + Conv2d stages.

    Maps a latent tensor of shape ``[B, latent_dim, 1, 1]`` to an image of
    shape ``[B, out_channels, 28, 28]``; the final ``Tanh`` bounds values to
    ``[-1, 1]``.

    Args:
        latent_dim: Number of input (latent) channels. When ``None``, the
            module-level ``latent_dim_size`` is used (original behaviour).
        num_features: Base feature-map width. When ``None``, the module-level
            ``generator_num_features`` is used (original behaviour).
        out_channels: Channels of the generated image (1 = grayscale).
    """

    def __init__(self, *, latent_dim=None, num_features=None, out_channels=1):
        super().__init__()
        # Fall back to the module-level hyper-parameters so the original
        # zero-argument constructor keeps working unchanged.
        if latent_dim is None:
            latent_dim = latent_dim_size
        if num_features is None:
            num_features = generator_num_features
        # NOTE: layer order (incl. ReLU before BatchNorm) is kept as-is so
        # Sequential indices -- and hence state_dict keys -- stay compatible
        # with existing checkpoints.
        self.main = nn.Sequential(
            # input ~ [bs, latent_dim, 1, 1]
            nn.Upsample([7, 7]),  # [bs, latent_dim, 7, 7]
            nn.Conv2d(latent_dim, num_features * 8,
                      kernel_size=3, padding=1),  # [bs, nf*8, 7, 7]
            nn.ReLU(True),
            nn.Upsample([14, 14]),  # [bs, nf*8, 14, 14]
            nn.BatchNorm2d(num_features * 8),
            nn.Conv2d(num_features * 8, num_features,
                      kernel_size=3, padding=1),  # [bs, nf, 14, 14]
            nn.ReLU(True),
            nn.Upsample([28, 28]),  # [bs, nf, 28, 28]
            nn.BatchNorm2d(num_features),
            nn.Conv2d(num_features, out_channels,
                      kernel_size=3, padding=1),  # [bs, out_channels, 28, 28]
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate images from latent codes of shape [B, latent_dim, 1, 1]."""
        return self.main(input)
class GeneratorConvTranspose2d(nn.Module):
    """DCGAN-style generator built from ConvTranspose2d stages.

    Maps a latent tensor of shape ``[B, latent_dim, 1, 1]`` to an image of
    shape ``[B, out_channels, 28, 28]``; the final ``Tanh`` bounds values to
    ``[-1, 1]``.

    Args:
        latent_dim: Number of input (latent) channels. When ``None``, the
            module-level ``latent_dim_size`` is used (original behaviour).
        num_features: Base feature-map width. When ``None``, the module-level
            ``generator_num_features`` is used (original behaviour).
        out_channels: Channels of the generated image (1 = grayscale).
    """

    def __init__(self, *, latent_dim=None, num_features=None, out_channels=1):
        super().__init__()
        # Fall back to the module-level hyper-parameters so the original
        # zero-argument constructor keeps working unchanged.
        if latent_dim is None:
            latent_dim = latent_dim_size
        if num_features is None:
            num_features = generator_num_features
        # Spatial sizes follow the ConvTranspose2d formula
        #   H_out = (H_in - 1) * stride - 2 * padding + (kernel_size - 1) + 1
        # NOTE: layer order (incl. ReLU before BatchNorm) is kept as-is so
        # state_dict keys stay compatible with existing checkpoints.
        self.main = nn.Sequential(
            # input ~ [bs, latent_dim, 1, 1]
            # 1 -> 7: (1 - 1) * 1 - 0 + 6 + 1 = 7
            nn.ConvTranspose2d(latent_dim, num_features * 8,
                               kernel_size=7, bias=False),  # [bs, nf*8, 7, 7]
            nn.ReLU(True),
            nn.BatchNorm2d(num_features * 8),
            # 7 -> 14: (7 - 1) * 2 - 2 + 3 + 1 = 14
            nn.ConvTranspose2d(num_features * 8, num_features,
                               kernel_size=4, stride=2, padding=1,
                               bias=False),  # [bs, nf, 14, 14]
            nn.ReLU(True),
            nn.BatchNorm2d(num_features),
            # 14 -> 28: (14 - 1) * 2 - 2 + 3 + 1 = 28
            nn.ConvTranspose2d(num_features, out_channels,
                               kernel_size=4, stride=2, padding=1,
                               bias=False),  # [bs, out_channels, 28, 28]
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate images from latent codes of shape [B, latent_dim, 1, 1]."""
        return self.main(input)