Python 使用 Keras 训练 GAN（生成对抗网络）生成手写数字的代码
代码语言:python
所属分类:人工智能
代码描述:Python 使用 Keras 训练 GAN（生成对抗网络）生成手写数字的代码
代码标签: Python keras 训练 GAN 生成 对抗 网络 输出 手写 数字 代码
下面为部分代码预览，完整代码请点击下载或在 bfwstudio webide 中打开。
#!/usr/local/python3/bin/python3
# -*- coding: utf-8 -*-
# Train a simple fully-connected GAN on MNIST with tf.keras and save grids of
# generated digits.
import os

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
# LeakyReLU must come from the public tf.keras API: the original imported it
# from the private legacy package `tensorflow.python.keras`, whose layers are a
# different class hierarchy and are rejected by tf.keras Sequential models on
# current TensorFlow releases.
from tensorflow.keras.layers import Dense, Dropout, Input, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

# Length of the generator's input noise vector. The original repeated the
# literal 100 in build_generator, build_GAN and draw_images; one constant keeps
# them consistent.
LATENT_DIM = 100

# Optional: when running in Google Colab, mount Drive to save results there.
# from google.colab import drive
# drive.mount('/content/gdrive')
# path = 'gdrive/My Drive/Project/Practice/Result_GAN/'


# Load the dataset.
def load_data():
    """Load the MNIST training split, flattened and scaled for a tanh generator.

    Returns:
        (x_train, y_train): ``x_train`` is float32 with shape (n, 784), pixel
        values mapped by (x - 127.5) / 127.5; ``y_train`` holds the labels as
        returned by ``mnist.load_data``.
    """
    (x_train, y_train), (_, _) = mnist.load_data()
    # (x - 127.5) / 127.5 maps 8-bit pixels [0, 255] onto [-1, 1], the output
    # range of the generator's tanh layer.
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    # Flatten each 28x28 image; -1 infers the sample count instead of
    # hard-coding 60000.
    x_train = x_train.reshape(-1, 784)
    return (x_train, y_train)


X_train, y_train = load_data()
print(X_train.shape, y_train.shape)


def build_generator():
    """Build and compile the generator: LATENT_DIM noise -> 784 values in [-1, 1]."""
    model = Sequential()
    model.add(Dense(units=256, input_dim=LATENT_DIM))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(units=1024))
    model.add(LeakyReLU(alpha=0.2))
    # tanh matches the [-1, 1] scaling applied to real images in load_data.
    model.add(Dense(units=784, activation='tanh'))
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model


generator = build_generator()
generator.summary()


def build_discriminator():
    """Build and compile the discriminator: 784-value image -> sigmoid score."""
    model = Sequential()
    model.add(Dense(units=1024, input_dim=784))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(units=256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(units=1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model


discriminator = build_discriminator()
discriminator.summary()


def build_GAN(discriminator, generator):
    """Chain generator -> discriminator into one compiled model.

    The discriminator is frozen *before* the combined model is compiled, so
    training the returned model updates only the generator's weights.

    Args:
        discriminator: model returned by build_discriminator (already compiled).
        generator: model returned by build_generator.

    Returns:
        A compiled ``Model`` mapping (batch, LATENT_DIM) noise to the
        discriminator's score.
    """
    # tf.keras captures `trainable` when a model is compiled: the standalone
    # discriminator was compiled while trainable, so it can still be trained
    # on its own; only this combined model sees it as frozen.
    discriminator.trainable = False
    GAN_input = Input(shape=(LATENT_DIM,))
    x = generator(GAN_input)
    GAN_output = discriminator(x)
    GAN = Model(inputs=GAN_input, outputs=GAN_output)
    GAN.compile(loss='binary_crossentropy',
                optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return GAN


# Build the GAN: the generator and discriminator networks chained together.
GAN = build_GAN(discriminator, generator)
GAN.summary()


def draw_images(generator, epoch, examples=25, dim=(5, 5), figsize=(10, 10),
                save_dir='/data/wwwroot/default/Data'):
    """Generate ``examples`` digits and save them as a grid PNG.

    Args:
        generator: model mapping (examples, LATENT_DIM) noise to
            (examples, 784) images.
        epoch: epoch number, embedded in the output filename.
        examples: number of images to generate.
        dim: (rows, cols) of the subplot grid; must hold ``examples`` images.
        figsize: matplotlib figure size in inches.
        save_dir: directory for ``image_at_epoch_<epoch>.png``; the default is
            the path the original code hard-coded.

    Raises:
        ValueError: if ``examples`` does not fit in the ``dim`` grid.
    """
    rows, cols = dim
    if examples > rows * cols:
        raise ValueError('examples=%d does not fit in a %dx%d grid'
                         % (examples, rows, cols))

    noise = np.random.normal(loc=0, scale=1, size=[examples, LATENT_DIM])
    generated_images = generator.predict(noise)
    # Reshape by `examples`; the original hard-coded 25 here, which broke any
    # other value of the parameter.
    generated_images = generated_images.reshape(examples, 28, 28)

    fig = plt.figure(figsize=figsize)
    try:
        for i, image in enumerate(generated_images):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(image, interpolation='nearest', cmap='Greys')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'image_at_epoch_%d.png' % epoch))
    finally:
        # This runs once per epoch; without closing, every figure stays open
        # and memory grows for the whole training run.
        plt.close(fig)


# NOTE(review): the published listing is truncated at this point (the site
# says the full code requires logging in and downloading). The remainder of
# train_GAN is unavailable, so the visible fragment is kept verbatim below
# rather than reconstructed.
# def train_GAN(epochs=1, batch_size=128): #Loading the data X_train, y_train = load_data() # Creating GAN generator= build_generator() discriminator= buil...
网友评论0