From 0b295af80696e736933b0348a95be4dfcd2e3408 Mon Sep 17 00:00:00 2001 From: Ramvignesh Pasupathy Date: Wed, 10 Feb 2021 21:32:39 +0530 Subject: [PATCH] replaced print() w/ progbar for model training --- train.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/train.py b/train.py index 854e8d8..9c436bc 100644 --- a/train.py +++ b/train.py @@ -67,30 +67,25 @@ def valid_step(images, labels): # start training for epoch in range(config.EPOCHS): + print (f"\nEpoch: {epoch + 1}/{config.EPOCHS}") + train_loss.reset_states() train_accuracy.reset_states() valid_loss.reset_states() valid_accuracy.reset_states() - step = 0 - for images, labels in train_dataset: - step += 1 + + progbar = tf.keras.utils.Progbar(len(train_dataset), stateful_metrics=["train_loss", "train_acc", "val_loss", "val_acc"]) + + for idx, (images, labels) in enumerate(train_dataset): train_step(images, labels) - print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(epoch + 1, - config.EPOCHS, - step, - math.ceil(train_count / config.BATCH_SIZE), - train_loss.result(), - train_accuracy.result())) + + values = [("train_loss", train_loss.result()), ("train_acc", train_accuracy.result())] + progbar.update(idx+1, values=values) for valid_images, valid_labels in valid_dataset: valid_step(valid_images, valid_labels) - print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, " - "valid loss: {:.5f}, valid accuracy: {:.5f}".format(epoch + 1, - config.EPOCHS, - train_loss.result(), - train_accuracy.result(), - valid_loss.result(), - valid_accuracy.result())) + values = [ ("train_loss", train_loss.result()), ("train_acc", train_accuracy.result()), ("val_loss", valid_loss.result()), ("val_acc", valid_accuracy.result()) ] + progbar.update(idx+1, values=values) model.save_weights(filepath=config.save_model_dir, save_format='tf')