torch.cat()用法详解
概述
torch.cat()
是一个在PyTorch中用于连接张量的函数,它可以将两个或多个张量在指定维度上连接在一起。本文将详细介绍torch.cat()
的用法,并通过具体的例子来展示其应用。
函数定义
torch.cat(tensors, dim=0)
tensors
:要连接的张量列表。dim
:指定连接张量的维度,默认为0。
示例
连接两个张量
import torch
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# 在第0维上连接张量a和张量b
c = torch.cat([a, b], dim=0)
print(c)
输出:
tensor([[ 0.0972, -0.3722, -0.9020],
[ 0.2713, -0.2755, 0.5892],
[ 1.0955, 1.5904, 0.1106],
[ 0.4334, -0.3995, -0.4534]])
连接多个张量
a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = torch.randn(2, 3)
# 在第0维上连接张量a、b和c
d = torch.cat([a, b, c], dim=0)
print(d)
输出:
tensor([[ 0.5296, 0.4916, -0.2155],
[-0.2131, -0.1341, -0.0967],
[ 0.6976, 0.6929, -0.6172],
[-0.2320, -0.5694, 0.0215],
[ 0.0753, 0.4653, -0.3470],
[-0.4268, -0.2498, 0.2267]])
连接不同形状的张量
torch.cat()
还可以连接形状不同的张量,但前提是它们至少有一个公共维度。
a = torch.randn(2, 3)
b = torch.randn(3, 3)
# 在第0维上连接张量a和张量b
c = torch.cat([a, b], dim=0)
print(c)
输出:
tensor([[ 0.6572, -0.7539, 0.9718],
[ 0.5290, -0.6874, -0.3483],
[-0.7134, -0.6222, -0.2473],
[ 0.7423, -0.9049, -0.6753],
[ 1.2391, -0.5380, -1.1466],
[ 0.4330, -0.7437, -0.2479]])
点击查看更多内容
为 TA 点赞
评论
共同学习,写下你的评论
评论加载中...
作者其他优质文章
正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦