Replacing most of PyTorch's tensor-manipulation functions with the einops package

It is just four functions: rearrange, reduce, repeat, and einsum.
But they are really handy.

Einstein summation notation: Einops tutorial, part 1: basics - Einops

edit time: 2026-02-09 22:36:31

rearrange, reduce, repeat

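rearrange: reorder, merge, or split dimensions in a more readable way.
It fits computer-vision code particularly well: the axes carry real meaning (batch, channel, height, width), and keeping those names while rearranging makes the code much easier to read.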

# y = x.permute(0, 2, 3, 1)  # plain PyTorch equivalent
y = rearrange(x, 'b c h w -> b h w c')

reduce: this one reads really well; you can tell at a glance that it takes the mean over the channel dimension.

# x.mean(-1)
reduce(x, 'b h w c -> b h w', 'mean')

repeat: likewise, you can tell at a glance that a new axis of length 5 is inserted in the middle, repeating the data 5 times.

repeat(ims[0], "h w c -> h 5 w c").shape

einsum

(1) Next, einsum can be used for matrix multiplication.

Y = D @ A.T # basic implementation
Y = einsum(D, A, "... d_in, d_out d_in -> ... d_out")
# the einsum version is easier to read

(2) It can of course also be used for broadcasting.
We have a batch of images, and for each image we want to generate 10 dimmed versions based on some scaling factor:

import torch
from einops import rearrange, einsum

images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)
## Reshape and multiply
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dimmed_images = images_rearr * dim_value

## Or:
dimmed_images = einsum(
	images, dim_by,
	"b h w c, dim_value -> b dim_value h w c"
)

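(3) Suppose we need to apply a linear transformation over the two middle dimensions (height and width) of a batch of images, independently for each channel.
Note: dimensions that should be handled independently are usually placed in front.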

channels_last = torch.randn(64, 32, 32, 3) # (batch, height, width, channel)
B = torch.randn(32*32, 32*32)

channels_first = rearrange(
	channels_last,
	"b h w c -> b c (h w)"
)
channels_first_transformed = einsum(
	channels_first, B,
	"b c pixel_in, pixel_out pixel_in -> b c pixel_out"
)
channels_last_transformed = rearrange(
	channels_first_transformed,
	"b c (h w) -> b h w c",
	h=32, w=32  # sizes are required to split the merged pixel axis
)

That's all. These are the basics; go and practice. This package is surprisingly pleasant to use.