Pytorch与Mxnet/Gluon炼丹对比

mxnet的gluon接口是模仿pytorch的接口设计的,所以两者有很多共通之处。总结一下两者接口的对比。

数据在设备之间的传输

  1. Pytorch中,有torch.cuda.is_available()方法判断是否可用gpu,而mxnet则没有,所以在代码中,我们转么定义了一个try_gpu()函数来判断gpu是否可用。

  2. mxnet中使用as_in_contextcopyto来将数据在cpu、gpu设备之间传递。如果源变量和目标变量的context一致,as_in_context函数使目标变量和源变量共享源变量的内存;而copyto函数总是为目标变量创建新的内存。

    1
    2
    y = x.copyto(mx.gpu())
    z = x.as_in_context(mx.gpu())
  3. Pytorch中使用.to()来将数据在cpu、gpu设备之间传递,与as_in_context效果一样。除此之外,它还负责Tensor的类型转换。

    1
    2
    3
    4
    tensor = torch.randn(2, 2)  
    tensor.to(torch.float64)
    cuda0 = torch.device('cuda:0')
    tensor.to(cuda0)

数据的flatten与reshape

gluon的nn模块中,如果想把多维的ndarray拉伸为一维,只需添加一个nn.Flatten()就可以了,这在后面接全连接的卷积神经网络的构建中很有用。

pytorch中没有提供类似的模块,但可以通过view这种reshape的办法来实现。

1
2
3
x = torch.randn(4, 4)
y = x.view(16)
z = x.view(-1, 8)

当然了,mxnet中的reshape就是pytorch中的view

其他

  1. 返回标量
    mxnet: nd.ndarray.NDArray.asscalar()
    pytorch: torch.Tensor.item()

  2. 返回某维度上的最大值
    mxnet: predicted = output.argmax(axis=1)
    pytorch: _, predicted = torch.max(outputs.data, 1)

持续技术分享,您的支持将鼓励我继续创作!