跳转至

Pytorch 6 :Visualization

导言
  • Loss,网络结构等可视化。

可视化

网络结构可视化

自动 https://stackoverflow.com/questions/52468956/how-do-i-visualize-a-net-in-pytorch

或者手动drawio

误差实时可视化TensorBoard

https://www.cnblogs.com/sddai/p/14516691.html

原理: 通过读取保存的log文件来可视化数据

标量可视化

记录数据,默认在当前目录下一个名为'runs/'的文件夹中。

from torch.utils.tensorboard import SummaryWriter

# 写log的东西
log_writer = SummaryWriter('./path/to/log')
# 第一个参数是名称,第二个参数是y值,第三个参数是x值。
log_writer.add_scalar('Loss/train', float(loss), epoch)

运行 tensorboard --logdir=runs/ --port 8123 在某端口打开,比如 https://127.0.0.1:6006

网络结构可视化

在tensorboard的基础上使用tensorboardX

from tensorboardX import SummaryWriter

with SummaryWriter(comment='LeNet') as w:
    w.add_graph(net, (net_input, ))

PR曲线

什么是PR曲线

log_writer.add_pr_curve("pr_curve", label_batch, predict, epoch)

x,y轴分别是recall和precision。应该有可能有矛盾的数据,或者网络分不开,对于不同的阈值,可以划分出PR图。

与ROC曲线左上凸不同的是,PR曲线是右上凸效果越好。