A Detailed Walkthrough of the Wan2.1 Video Generation Model Code
The first draft of this post was written eight months ago, right when Wan2.1 was open-sourced. Since then Wan2.1 has become the most widely used base model for video generation research; by an informal count, more than 100 top-conference and top-journal papers built on Wan2.1 have appeared within half a year. This post walks through the principles and code of this modern video generation model in detail.
VAE Design
Even today, in comparisons across many open-source VAEs, the reconstruction metrics of the Wan2.1 and Wan2.2 VAEs remain highly competitive.
A design comparison between the Wan2.2 and Wan2.1 VAEs:

- Much higher compression ratio
  - Wan2.1: 8×8×4 (spatial 8×8, temporal 4; 256× overall)
  - Wan2.2: 16×16×4 (1024× overall)
  - At the same resolution and duration, a Wan2.2 latent therefore has only 1/4 as many spatial-temporal positions as a Wan2.1 latent (and, counting the wider channels below, roughly 3/4 as many values overall), which shrinks the downstream DiT's token count and activation memory accordingly.
- Reconstruction quality goes up rather than down
  - Through an "asymmetric encoder/decoder + residual sampling" structure, Wan2.2 still achieves slightly better PSNR at the higher compression ratio.
  - Official tests report 32.5 dB PSNR on 720P video, more than 2 dB above Wan2.1's 30.1 dB.
- Wider latent channels
  - Wan2.1: 16 latent channels
  - Wan2.2: 48 latent channels
  - The extra channels compensate for the information lost to the higher compression, so fine detail is preserved better.
- Speed / memory gains
  - On an RTX 4090, Wan2.2-TI2V-5B with the new VAE cuts 5 s 720P video generation from the minutes-level of Wan2.1-14B down to about 155 s (multi-GPU) or 534 s (single GPU), and fits within 24 GB of VRAM.
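To make the ratios concrete, here is a rough sanity check under the factors quoted above (illustrative arithmetic only; the channel counts and strides come from the comparison, not from running either model):

def latent_shape(channels, T, H, W, t_down, s_down):
    # a video [3, T, H, W] becomes a latent [channels, 1 + (T - 1) // t_down, H // s_down, W // s_down]
    return (channels, 1 + (T - 1) // t_down, H // s_down, W // s_down)

video = (3, 81, 720, 1280)                      # 81 frames of 720P
wan21 = latent_shape(16, 81, 720, 1280, 4, 8)   # (16, 21, 90, 160)
wan22 = latent_shape(48, 81, 720, 1280, 4, 16)  # (48, 21, 45, 80): 1/4 the positions, 3/4 the values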
Code Walkthrough
The heart of the Wan2.1 VAE is a "causal + streaming" VAE.

- Initialization: all hyperparameters are fixed up front

dim=128                                  # base channel count
z_dim=4                                  # final latent channel count (16 in the report; the 4 here is a single "component", doubled to 2*z_dim later for mu/log_var)
dim_mult=[1, 2, 4, 4]                    # 4 stages; channels go 128→256→512→512
temperal_downsample=[True, True, False]
# which of the 4 stages also subsample time: stages 2 and 3 each do 2×, stage 4 does not
Encoder3d and Decoder3d are the "3D causal residual network" described in the report; internally the kernels are already split into "spatially 2D + temporally 1D causal" pieces, so no layer ever peeks at future frames.
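The repo's CausalConv3d is more elaborate (it also manages the feat_cache discussed below), but the causal part boils down to asymmetric padding on the time axis. A minimal sketch of that idea, not the repo's class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalConv3d(nn.Conv3d):
    """Pad only towards the past on the time axis; keep symmetric padding in space."""

    def forward(self, x):
        kt, kh, kw = self.kernel_size
        # pad order is (W_left, W_right, H_left, H_right, T_left, T_right)
        x = F.pad(x, (kw // 2, kw // 2, kh // 2, kh // 2, kt - 1, 0))
        return super().forward(x)

conv = TinyCausalConv3d(3, 8, kernel_size=(3, 3, 3))
y = conv(torch.randn(1, 3, 9, 32, 32))   # -> [1, 8, 9, 32, 32]; output frame t only sees frames <= t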
- Forward pass: the classic three-stage VAE

x_recon, mu, log_var = model(x)

- encode produces μ and σ
- reparameterize draws a sample
- decode maps the latent back to video

The only special part: both encode and decode run chunk by chunk internally, so the whole video never has to sit in GPU memory at once.
- encode: split the video into causal chunks of "1 + 4×n" frames

t = x.shape[2]
iter_ = 1 + (t - 1) // 4                        # send 1 frame first, then one chunk per 4 frames
for i in range(iter_):
    if i == 0:
        out = encoder(x[:, :, :1, ...])         # frame 0 goes through on its own
    else:
        out_ = encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, ...])
        out = cat([out, out_], 2)               # concatenate back along the time axis

This enforces temporal causality: no chunk's features ever depend on frames that come after it.
- feat_cache / feat_idx are the "hidden state" cache used inside CausalConv3d: when crossing a chunk boundary, the last hidden states of the previous chunk are carried forward, exactly like an RNN's h_t. The toy example below shows why this makes chunked processing equivalent to processing the whole sequence at once.
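A 1-D toy illustration of the mechanism (my own sketch, not Wan's code): cache the last kernel_size - 1 inputs of each chunk and prepend them to the next chunk, and the chunked causal convolution reproduces the full-sequence result exactly.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = 3
w = torch.randn(1, 1, k)                    # one causal conv kernel
x = torch.randn(1, 1, 13)                   # 13 "frames": 1 + 4 * 3

full = F.conv1d(F.pad(x, (k - 1, 0)), w)    # causal conv over the whole sequence at once

outs, cache = [], torch.zeros(1, 1, k - 1)  # cache plays the role of feat_cache
for chunk in x.split([1, 4, 4, 4], dim=2):  # the same 1 + 4n chunking as encode()
    inp = torch.cat([cache, chunk], dim=2)
    outs.append(F.conv1d(inp, w))
    cache = inp[:, :, -(k - 1):]            # carry the last k-1 frames forward
chunked = torch.cat(outs, dim=2)

assert torch.allclose(full, chunked, atol=1e-6)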
- Latent normalization: scale[0] = mean, scale[1] = 1/std

mu = (mu - scale[0]) * scale[1]

During training, scale holds EMA-tracked global mean/std statistics; at inference you can simply pass 0/1, or feed in dataset statistics for offline normalization, so the diffusion model always sees inputs close to N(0, 1).
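In shapes, the convention looks like this (a sketch with placeholder statistics; the real per-channel values are hard-coded in the WanVAE wrapper further down):

import torch

mean = torch.randn(16)                   # placeholder per-channel statistics
inv_std = 1.0 / torch.rand(16).add(0.5)  # stored as 1/std, matching scale = [mean, 1/std]
scale = [mean, inv_std]

def normalize(mu):                       # what encode() does before returning mu
    return (mu - scale[0].view(1, -1, 1, 1, 1)) * scale[1].view(1, -1, 1, 1, 1)

def denormalize(z):                      # what decode() does before running the decoder
    return z / scale[1].view(1, -1, 1, 1, 1) + scale[0].view(1, -1, 1, 1, 1)

z = torch.randn(1, 16, 21, 90, 160)
assert torch.allclose(denormalize(normalize(z)), z, atol=1e-5)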
- decode: a frame-by-frame sliding window, causal in the same way

for i in range(iter_):
    out_ = decoder(z[:, :, i:i + 1, ...])   # feed one latent frame at a time
    out = cat([out, out_], 2)

The decoder carries a feat_cache too, so:

- memory usage depends only on the window length, not on the total duration → arbitrarily long generation;
- output frame rate = latent frame rate × the temporal downsampling factor (4×), which lines up with the 6 fps latent mentioned in the report. The bookkeeping is spelled out below.
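The frame-count arithmetic implied by the 1 + 4×n chunking, as a worked example (derived purely from the indexing in the code):

def video_to_latent_frames(T):        # what encode() produces from T input frames
    return 1 + (T - 1) // 4

def latent_to_video_frames(T_lat):    # what decode() produces: 1 frame for the first latent, 4 for each later one
    return 1 + 4 * (T_lat - 1)

assert video_to_latent_frames(81) == 21
assert latent_to_video_frames(21) == 81   # round-trips exactly whenever T = 1 + 4n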
Below is the original VAE code:
class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## split the input x along time into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
On top of this, the authors add another layer, essentially a wrapper. You can also tell the coding style differs across the project; it was clearly written by different people:
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
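A minimal usage sketch of the wrapper (a hypothetical call site; it assumes the checkpoint actually sits at the constructor's default path and that a CUDA device is available):

import torch

vae = WanVAE(vae_pth='cache/vae_step_411000.pth', device='cuda')
videos = [torch.randn(3, 81, 480, 832, device='cuda')]   # one [C, T, H, W] video scaled to [-1, 1]
latents = vae.encode(videos)       # -> list of [16, 21, 60, 104] tensors, roughly N(0, 1)
recons = vae.decode(latents)       # -> list of [3, 81, 480, 832] tensors clamped to [-1, 1]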
Core DiT Architecture
For the diffusion backbone, Wan2.1 adopts a fairly classic Transformer stack rather than the dual-stream / single-stream design that has become so popular in the diffusion community.
- Forward pass: one pipeline, four steps

① Patchify

x = [C, F, H, W]
x = patch_embedding(u.unsqueeze(0))   # [1, 2048, F, H/2, W/2]
x = flatten(2).transpose(1, 2)        # [1, L, 2048], L = F * H/2 * W/2

Each video computes its own L; the results are then concatenated into one batch and zero-padded up to seq_len. The sequence-length arithmetic is made concrete below.
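For concreteness, here is the token count for a single latent (a sketch using patch_size=(1, 2, 2) from the config; the latent shape is the 480×832, 81-frame example from the VAE section):

import math

def dit_seq_len(latent_shape, patch_size=(1, 2, 2)):
    # latent_shape = (C, F, H_lat, W_lat); patch_embedding strides by patch_size over (F, H_lat, W_lat)
    _, F, H, W = latent_shape
    grid = (F // patch_size[0], H // patch_size[1], W // patch_size[2])
    return math.prod(grid), grid

L, grid = dit_seq_len((16, 21, 60, 104))
print(grid, L)                       # (21, 30, 52) -> 32760 tokens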
② Timestep embedding

e = time_embedding(sinusoidal(t))   # [B, 2048]
e0 = time_projection(e)             # [B, 6, 2048], split into 6 chunks for layer-scale + gates

The six chunks provide the shift / scale / gate terms for each WanAttentionBlock's self-attention and FFN branches (cross-attention is left unmodulated), while the Head computes its own scale-shift from e rather than e0. The split is done in one call to unflatten(1, (6, dim)).
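In the spirit of the actual block, here is how the six chunks are consumed (a simplified sketch: the real code also adds a learned per-block modulation table, uses affine-free norms, and runs the modulation in fp32):

import torch
import torch.nn as nn

dim, B, L = 2048, 1, 16
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)   # stand-ins for the block's norms
self_attn, ffn = nn.Identity(), nn.Identity()         # stand-ins for the real attention and FFN

x = torch.randn(B, L, dim)                            # video tokens
e0 = torch.randn(B, 6, dim)                           # time_projection(e).unflatten(1, (6, dim))

shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = e0.chunk(6, dim=1)  # each [B, 1, dim]
y = self_attn(norm1(x) * (1 + scale_sa) + shift_sa)
x = x + y * gate_sa                                   # gated residual around self-attention
# (cross-attention over the text/image context runs here, without modulation)
y = ffn(norm2(x) * (1 + scale_ffn) + shift_ffn)
x = x + y * gate_ffn                                  # gated residual around the FFN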
③ Text / image conditioning

- Text: the T5 embeddings go through Linear→GELU→Linear to reach 2048 dims and are padded to a length of 512.
- Image: the CLIP image tokens go through MLPProj to give 257×2048 tokens, which are prepended to the text, forming an image-first cross-attention sequence.

The task is selected by model_type: t2v uses text only, while i2v and flf2v additionally consume clip_fea (only these two variants build the img_emb module in the code below). A shape-only sketch of this assembly follows.
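Shape-only sketch of the context assembly (the projections are replaced by random tensors; 87 and 33 are arbitrary T5 sequence lengths chosen for illustration):

import torch

text_len, dim = 512, 2048
t5_tokens = [torch.randn(87, dim), torch.randn(33, dim)]          # per-sample T5 features, already projected to dim

context = torch.stack([
    torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))])  # zero-pad every sample to 512 tokens
    for u in t5_tokens
])                                                                # [2, 512, 2048]

clip_tokens = torch.randn(2, 257, dim)                            # what img_emb(clip_fea) would return in i2v/flf2v
context = torch.cat([clip_tokens, context], dim=1)                # [2, 769, 2048], image tokens first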
④ 32 Transformer blocks + Head

Inside each WanAttentionBlock, in order:

- self-attention (windowed or global)
- cross-attention (over the text / image tokens)
- FFN (a plain Linear→GELU→Linear MLP)

all wrapped in pre-norm + residual + the e0 layer-scale/gates (the six chunks above). Finally, the Head applies a linear projection that maps 2048 back to out_dim * patch_size[0] * patch_size[1] * patch_size[2], and unpatchify rearranges the patches back into [C_out, F, H / 8, W / 8].
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to False):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)

        if model_type == 'i2v' or model_type == 'flf2v':
            self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode or first-last-frame-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v' or self.model_type == 'flf2v':
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 (x2) x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
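One detail from the code above worth spelling out: the freqs buffer splits each attention head's dimensions into three groups for 3D RoPE over (frame, row, column). With the default config the arithmetic is:

dim, num_heads = 2048, 16
d = dim // num_heads            # 128 dims per attention head
d_t = d - 4 * (d // 6)          # 44 dims rotate with the frame index
d_h = 2 * (d // 6)              # 42 dims rotate with the row index
d_w = 2 * (d // 6)              # 42 dims rotate with the column index
assert d_t + d_h + d_w == d     # 44 + 42 + 42 = 128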
That is all that fits here; the walkthrough continues in the next post.