PyTorch 介绍

PyTorch-logo
PyTorch-logo

PyTorch是一个开源的Python机器学习库,基于Torch,底层由C++实现,应用于人工智能领域,如计算机视觉和自然语言处理。它主要由Meta Platforms的人工智能研究团队开发。著名的用途有:特斯拉自动驾驶,Uber最初发起而现属Linux基金会项目的概率编程软件Pyro,Lightning。

TF1.x的主要问题是混乱的api设计以及难以debug的静态图机制,每个小版本之间都可能是完全不同的API调用,然后TF2.x又把所有API炸的一干二净。而PyTorch的易用性直接薄纱TF并且在经年来也是越来越火基本占据了头号位置,并且拥有较为完好的整体生态。

官方安装方式

PyTorch的安装方式相较于TF可谓是相当简单,直接访问官网选择配置即可

注意:推荐使用conda进行环境隔离,请参加往期文章

官方安装方式
官方安装方式

运行命令进行安装:

1
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

安装过去版本的PyTorch

注意:推荐使用conda进行环境隔离,请参加往期文章

在代码复现时通常需要首先复现其代码环境,打开PyTorch的历史发行页面

滚动到需要的版本号附近,复制命令并安装

1
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116

V1.12.0

PyTorch的安装极为简单,但PyTorch默认不安装几个扩展包,例如:

  • torch_cluster
  • torch_scatter
  • torch_sparse
  • torch_spline_conv

对于这几种包,必须手动引导安装,如直接输入命令,通常情况下会导致安装失败

1
pip install torch_sparse==0.6.16

正确做法为,在torch的包分发页面选择对应的版本号

选择版本号
选择版本号

复制链接,配置-f选项,输入如下的命令即可安装成功

1
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.0%2Bcu116.html