Fashion-MNIST 是一个类似于 MNIST 的图像数据集, 涵盖 10 种类别、共 7 万张 (6 万训练集 + 1 万测试集) 不同商品的图片.

TensorBoard 是 TensorFlow 的一个可视化工具.

创建 summary

我们可以通过 `tf.summary.create_file_writer(log_dir)` 来创建一个新的 summary writer 实例, 它会把日志写入 `log_dir` 目录, 供 TensorBoard 读取.


  # Use the current time as the run's sub-directory name, e.g. logs/20210614-180127.
  # %Y%m%d-%H%M%S = year, month, day, hour, minute, second. The previous
  # "%y%m%d-%h%m%s" used %h (month abbreviation) and %s (non-portable epoch seconds).
  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  # Directory that TensorBoard will watch.
  log_dir = 'logs/' + current_time
  # Create the summary writer for this run.
  summary_writer = tf.summary.create_file_writer(log_dir)


通过 `tf.summary.scalar` 我们可以向 summary 写入标量数据 (需要在 `summary_writer.as_default()` 的上下文中调用, 数据才会写入该 writer).


  # Reference signature: record scalar `data` under `name` at `step`
  # in the summary writer that is currently set as default.
  tf.summary.scalar(name, data, step=None, description=None)


  # Make our writer the default, then record this batch's loss at `step`.
  with summary_writer.as_default():
      loss_value = float(cross_entropy)
      tf.summary.scalar("train-loss", loss_value, step=step)


  # Reference signature of the running-mean metric (class name is capitalized).
  tf.keras.metrics.Mean(name='mean', dtype=None)


  # Loss meter: running mean of the per-batch loss values.
  loss_meter = tf.keras.metrics.Mean()



  # Reference signature of the accuracy metric (class name is capitalized).
  tf.keras.metrics.Accuracy(name='accuracy', dtype=None)


  # Accuracy meter: fraction of predictions equal to the true labels.
  acc_meter = tf.keras.metrics.Accuracy()

变量更新 &重置

我们可以通过 `update_state` 来实现变量更新, 通过 `reset_state` 来实现变量重置.


  # Fold the current loss value into the meter's running mean.
  loss_meter.update_state(values=cross_entropy)
  # Clear the accumulated totals so the next reporting window starts empty.
  loss_meter.reset_state()


pre_process 函数

  def pre_process(x, y):
      """
      Preprocess one batch of samples.

      :param x: batch of images (presumably uint8 pixels from load_data -- confirm)
      :param y: batch of integer class labels
      :return: (features, labels) -- features as float32 divided by 255 and
               reshaped to rows of 784 values; labels one-hot encoded with depth 10
      """
      # Labels: integer class ids -> one-hot vectors over 10 classes.
      labels = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)

      # Features: scale pixel values by 1/255, then flatten each image to 784 values.
      features = tf.reshape(tf.cast(x, tf.float32) / 255, [-1, 784])

      return features, labels

get_data 函数

  def get_data():
      """
      Load Fashion-MNIST and build batched, preprocessed datasets.

      Relies on the module-level ``batch_size`` and ``pre_process``.

      :return: (train_db, test_db) -- tf.data datasets yielding (x, y) batches
               that have already been passed through ``pre_process``
      """
      # Load the raw train/test splits.
      (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

      # Training set: shuffle over the whole split, then batch and preprocess.
      # The class is tf.data.Dataset (capital D); the lowercase name does not exist.
      train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(len(x_train), seed=0)
      train_db = train_db.batch(batch_size).map(pre_process)

      # Test set: same pipeline (shuffling does not affect the accuracy result).
      test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(len(x_test), seed=0)
      test_db = test_db.batch(batch_size).map(pre_process)

      return train_db, test_db

train 函数

  def train(epoch, train_db):
      """
      Run one training epoch over ``train_db``.

      Uses the module-level ``model``, ``optimizer``, ``loss_meter`` and
      ``summary_writer``. Every 100 batches it prints the mean loss since the
      last report, writes that same mean to TensorBoard as "train-loss", and
      resets the meter.

      :param epoch: zero-based epoch index (kept for interface compatibility;
                    the TensorBoard step comes from the optimizer's counter)
      :param train_db: batched dataset yielding (x, y) with one-hot y
      :return: None
      """
      for step, (x, y) in enumerate(train_db):
          with tf.GradientTape() as tape:
              # Raw (unnormalized) model outputs; softmax is applied inside the loss.
              logits = model(x)
              # Cross-entropy summed (not averaged) over the batch.
              cross_entropy = tf.reduce_sum(
                  tf.losses.categorical_crossentropy(y, logits, from_logits=True)
              )

          # Metric bookkeeping does not need to be recorded on the tape.
          loss_meter.update_state(cross_entropy)

          grads = tape.gradient(cross_entropy, model.trainable_variables)

          # Global step across epochs, read before this update is applied so the
          # first batch ever is step 0. This replaces the hard-coded
          # `epoch * 235 + step`, which was only right for one batch size.
          global_step = int(optimizer.iterations)
          optimizer.apply_gradients(zip(grads, model.trainable_variables))

          # Every 100 batches, report the mean loss of the window.
          if step % 100 == 0:
              mean_loss = loss_meter.result().numpy()
              print("step:", step, "cross_entropy:", mean_loss)

              loss_meter.reset_state()

              # Log the same value that was printed, not just the last batch's loss.
              with summary_writer.as_default():
                  tf.summary.scalar("train-loss", float(mean_loss), step=global_step)

test 函数

  def test(epoch, test_db):
      """
      Evaluate the model on ``test_db`` and report its accuracy.

      Uses the module-level ``model``, ``acc_meter`` and ``summary_writer``.
      Prints the accuracy as a percentage and writes it to TensorBoard as
      "val-acc" at step ``epoch * 235``. (235 is presumably the number of
      training batches per epoch, i.e. ceil(60000 / 256) -- confirm against
      batch_size.)

      :param epoch: zero-based epoch index (printed as epoch + 1)
      :param test_db: batched dataset yielding (x, y) with one-hot y
      :return: None
      """
      acc_meter.reset_state()

      for images, labels_one_hot in test_db:
          # Predicted class = index of the largest logit.
          predicted = tf.argmax(model(images), axis=1)
          # True class = index of the 1 in the one-hot label.
          actual = tf.argmax(labels_one_hot, axis=1)
          acc_meter.update_state(actual, predicted)

      accuracy = acc_meter.result().numpy()
      print("epoch:", epoch + 1, "accuracy:", accuracy * 100, "%")

      with summary_writer.as_default():
          tf.summary.scalar("val-acc", accuracy, step=epoch * 235)

main 函数

  def main():
      """
      Entry point: load the data, then alternate a training epoch and an
      evaluation pass for ``iteration_num`` epochs (module-level setting).

      :return: None
      """
      train_db, test_db = get_data()

      # One epoch = a full training pass followed by an evaluation pass.
      for epoch_index in range(iteration_num):
          train(epoch_index, train_db)
          test(epoch_index, test_db)


Model: "sequential"
Layer (type) Output Shape Param #
dense (Dense) (None, 256) 200960
dense_1 (Dense) (None, 128) 32896
dense_2 (Dense) (None, 64) 8256
dense_3 (Dense) (None, 32) 2080
dense_4 (Dense) (None, 10) 330
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
2021-06-14 18:01:27.399812: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
step: 0 cross_entropy: 591.5974
step: 100 cross_entropy: 196.49309
step: 200 cross_entropy: 125.2562
epoch: 1 accuracy: 84.72999930381775 %
step: 0 cross_entropy: 107.64579
step: 100 cross_entropy: 105.854385
step: 200 cross_entropy: 99.545975
epoch: 2 accuracy: 85.83999872207642 %
step: 0 cross_entropy: 95.42945
step: 100 cross_entropy: 91.366234
step: 200 cross_entropy: 90.84072
epoch: 3 accuracy: 86.69999837875366 %
step: 0 cross_entropy: 82.03317
step: 100 cross_entropy: 83.20552
step: 200 cross_entropy: 81.57012
epoch: 4 accuracy: 86.11000180244446 %
step: 0 cross_entropy: 82.94046
step: 100 cross_entropy: 77.56677
step: 200 cross_entropy: 76.996346
epoch: 5 accuracy: 87.27999925613403 %
step: 0 cross_entropy: 75.59219
step: 100 cross_entropy: 71.70899
step: 200 cross_entropy: 74.15144
epoch: 6 accuracy: 87.29000091552734 %
step: 0 cross_entropy: 76.65844
step: 100 cross_entropy: 70.09151
step: 200 cross_entropy: 70.84446
epoch: 7 accuracy: 88.27999830245972 %
step: 0 cross_entropy: 67.50707
step: 100 cross_entropy: 64.85907
step: 200 cross_entropy: 68.63099
epoch: 8 accuracy: 88.41999769210815 %
step: 0 cross_entropy: 65.50318
step: 100 cross_entropy: 62.2706
step: 200 cross_entropy: 63.80803
epoch: 9 accuracy: 86.21000051498413 %
step: 0 cross_entropy: 66.95486
step: 100 cross_entropy: 61.84385
step: 200 cross_entropy: 62.18851
epoch: 10 accuracy: 88.45999836921692 %
step: 0 cross_entropy: 59.779297
step: 100 cross_entropy: 58.602314
step: 200 cross_entropy: 59.837025
epoch: 11 accuracy: 88.66000175476074 %
step: 0 cross_entropy: 58.10068
step: 100 cross_entropy: 55.097878
step: 200 cross_entropy: 59.906315
epoch: 12 accuracy: 88.70999813079834 %
step: 0 cross_entropy: 57.584858
step: 100 cross_entropy: 54.95376
step: 200 cross_entropy: 55.797752
epoch: 13 accuracy: 88.44000101089478 %
step: 0 cross_entropy: 53.54782
step: 100 cross_entropy: 53.62939
step: 200 cross_entropy: 54.632828
epoch: 14 accuracy: 87.02999949455261 %
step: 0 cross_entropy: 54.387398
step: 100 cross_entropy: 52.323734
step: 200 cross_entropy: 53.968185
epoch: 15 accuracy: 88.98000121116638 %
step: 0 cross_entropy: 50.468914
step: 100 cross_entropy: 50.79311
step: 200 cross_entropy: 51.296227
epoch: 16 accuracy: 88.67999911308289 %
step: 0 cross_entropy: 48.753258
step: 100 cross_entropy: 46.809692
step: 200 cross_entropy: 48.08208
epoch: 17 accuracy: 89.10999894142151 %
step: 0 cross_entropy: 46.830627
step: 100 cross_entropy: 47.208813
step: 200 cross_entropy: 48.671318
epoch: 18 accuracy: 88.77999782562256 %
step: 0 cross_entropy: 46.15514
step: 100 cross_entropy: 45.026627
step: 200 cross_entropy: 45.371685
epoch: 19 accuracy: 88.7399971485138 %
step: 0 cross_entropy: 47.696465
step: 100 cross_entropy: 41.52749
step: 200 cross_entropy: 46.71362
epoch: 20 accuracy: 89.56000208854675 %


