# Checkpointing: the Checkpoint below tracks both networks and both of their
# optimizers, so training can be resumed or the trained weights reloaded.
# Saves written via `checkpoint_prefix` land in `checkpoint_dir` with a
# 'ckpt' file-name prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)
Now we define our training loop.
def train(dataset, epochs, checkpoint_every=15, loss_file='./data/lossfile.txt'):
    """Train the GAN, checkpoint periodically, and write per-step losses.

    Calls ``train_step`` on every batch of ``dataset`` for ``epochs`` epochs
    and records the generator and discriminator loss of each step.

    Args:
        dataset: Iterable of image batches. Each batch is passed to
            ``train_step``, which must return ``(gen_loss, disc_loss)``
            tensors (values exposing ``.numpy()``).
        epochs: Number of passes over ``dataset``.
        checkpoint_every: Save ``checkpoint`` under ``checkpoint_prefix``
            after every this many epochs (1-based). The final epoch is
            always saved too, so the latest checkpoint holds the fully
            trained model.
        loss_file: Path of the text file that receives the generator loss
            list, then the discriminator loss list, each followed by a
            blank line. Its parent directory is created if missing.

    Raises:
        ValueError: If ``checkpoint_every`` is smaller than 1.
    """
    if checkpoint_every < 1:
        raise ValueError(f'checkpoint_every must be >= 1, got {checkpoint_every}')

    generator_loss_list = []
    discriminator_loss_list = []
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            # .tolist() converts NumPy values to plain Python numbers, so
            # str() of the list stays "[0.69, ...]" and does not become
            # "[np.float32(0.69), ...]" under NumPy >= 2.
            generator_loss_list.append(gen_loss.numpy().tolist())
            discriminator_loss_list.append(disc_loss.numpy().tolist())
        # Optional: generate_and_save_images(generator, epoch + 1, seed_images)

        # Save on the regular interval AND after the last epoch; otherwise
        # the tail of training is lost whenever `epochs` is not a multiple
        # of `checkpoint_every`, and tf.train.latest_checkpoint() would
        # return stale weights.
        is_last_epoch = epoch + 1 == epochs
        if (epoch + 1) % checkpoint_every == 0 or is_last_epoch:
            checkpoint.save(file_prefix=checkpoint_prefix)
        # Report 1-based epoch numbers, consistent with the checks above.
        print(f'Time for epoch {epoch + 1} is {time.time() - start}')
    # Optional: generate_and_save_images(generator, epochs, seed_images)

    # open(..., 'w') fails if the target directory does not exist yet.
    loss_dir = os.path.dirname(loss_file)
    if loss_dir:
        os.makedirs(loss_dir, exist_ok=True)
    with open(loss_file, 'w') as outfile:
        outfile.write(str(generator_loss_list))
        outfile.write('\n\n')
        outfile.write(str(discriminator_loss_list))
        outfile.write('\n\n')
To train, simply call this function. Warning: training can take a long time, so the repository already includes a folder containing a pretrained network.
# Kick off training on the full dataset for the configured number of epochs.
train(dataset=train_dataset, epochs=EPOCHS)
And here is the result of training our model for 100 epochs:
To avoid retraining, which can take a while depending on your hardware, we now load the pretrained model that produced the GIF above.
# Load the most recent checkpoint in `checkpoint_dir` into the objects
# tracked by `checkpoint`.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    # latest_checkpoint() returns None when nothing has been saved, and
    # checkpoint.restore(None) silently keeps the untrained weights, so fail
    # loudly instead of continuing with a fresh model.
    raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}')
checkpoint.restore(latest_checkpoint)
# Checkpoint.restore loads values into the already-tracked objects, so these
# attributes refer to the same model objects passed in when `checkpoint` was
# created, now holding the saved weights.
restored_generator = checkpoint.generator
restored_discriminator = checkpoint.discriminator
print(restored_generator)
print(restored_discriminator)