高性能 :DeepSeek-V3 inference 推理时反量化实现 fp8_cast_bf16
【代码】高性能 :DeepSeek-V3 inference 推理时反量化实现 fp8_cast_bf16。
·
FP8 (8 bits) & FP16 (16 bits)
- FP8 和 BF16 都是浮点数格式(floating-point formats),
float通过科学计数法表示数据,float = [符号位+指数位+系数位]
| FP8 (8 bits):SEEEMMMM | FP16 (16 bits):SEEEEEMMMMMMMMMM |
|---|---|
| S (1 bit) | S (1 bit) |
| EEE (3 bits) | EEEEE (5 bits) |
| MMMM (4 bits) | MMMMMMMMMM (10 bits) |
- FP8:1位符号位、3位指数位、4位尾数位。
- FP16:1位符号位、5位指数位、10位尾数位。
| 特性 | FP8 | BF16 |
|---|---|---|
| 位数 | 8 位 | 16 位 |
| 存储需求 | 非常低 | 低(但高于 FP8) |
| 精度 | 精度非常低,仅适合低精度计算 | 较低的精度,但比 FP8 精度高 |
| 范围 | 较小的数值范围 | 与 FP32 相似,具有广泛的数值范围 |
| 主要用途 | 主要用于训练中的权重表示 | 主要用于训练和推理,尤其适用于加速机器学习 |
| 优点 | 极大的存储节省和计算加速 | 适用于大规模深度学习模型,精度损失较小 |
fp8_cast_bf16
- FP8到BF16转换: 主要通过
weight_dequant函数将FP8权重转换为BF16格式。
import os # 导入操作系统接口模块,用于文件和目录操作
import json # 导入JSON模块,用于读取和写入JSON格式的数据
from argparse import ArgumentParser # 导入ArgumentParser类,用于命令行参数解析
from glob import glob # 导入glob模块,用于文件路径模式匹配
from tqdm import tqdm # 导入tqdm模块,用于显示进度条
import torch # 导入PyTorch库
from safetensors.torch import load_file, save_file # 从safetensors库导入load_file和save_file函数
from kernel import weight_dequant # 从kernel模块导入weight_dequant函数,用于权重解量化
def main(fp8_path, bf16_path):
"""
将FP8权重转换为BF16并保存转换后的权重。
该函数从指定的目录读取FP8权重,将其转换为BF16格式,
并将转换后的权重保存到另一个指定的目录。它还更新了
模型索引文件,反映出这些更改。
参数:
fp8_path (str): 存放FP8权重和模型索引文件的目录路径。
bf16_path (str): 保存转换后的BF16权重的目录路径。
异常:
KeyError: 如果缺少所需的scale_inv张量,则会引发此异常。
注意:
- 假定FP8权重存储为safetensor文件。
- 该函数缓存已加载的safetensor文件以优化内存使用。
- 函数更新模型索引文件,删除对scale_inv张量的引用。
"""
# 设置默认数据类型为bfloat16
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True) # 如果输出目录不存在,则创建它
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") # 模型索引文件路径
with open(model_index_file, "r") as f:
model_index = json.load(f) # 读取模型索引文件
weight_map = model_index["weight_map"] # 获取权重映射
# 用于缓存已加载的safetensor文件
loaded_files = {}
fp8_weight_names = [] # 用于存储FP8权重的名称
def get_tensor(tensor_name):
"""
从缓存的safetensor文件中检索张量,如果没有缓存则从磁盘加载。
参数:
tensor_name (str): 要检索的张量名称。
返回:
torch.Tensor: 检索到的张量。
异常:
KeyError: 如果在safetensor文件中找不到指定的张量,则引发此异常。
"""
file_name = weight_map[tensor_name] # 获取该张量所在的文件名
if file_name not in loaded_files: # 如果该文件未加载
file_path = os.path.join(fp8_path, file_name) # 构建文件路径
loaded_files[file_name] = load_file(file_path, device="cuda") # 加载文件并缓存
return loaded_files[file_name][tensor_name] # 返回缓存的张量
# 获取所有safetensor文件路径,并按字母排序
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
# 遍历所有的safetensor文件
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file) # 获取文件名
current_state_dict = load_file(safetensor_file, device="cuda") # 加载当前safetensor文件
loaded_files[file_name] = current_state_dict # 将文件缓存起来
new_state_dict = {} # 用于存储转换后的新权重字典
for weight_name, weight in current_state_dict.items(): # 遍历文件中的所有权重
if weight_name.endswith("_scale_inv"): # 如果权重是scale_inv,跳过
continue
elif weight.element_size() == 1: # 如果权重是FP8(即1字节)
scale_inv_name = f"{weight_name}_scale_inv" # 对应的scale_inv张量名称
try:
# 尝试获取对应的scale_inv张量
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name) # 将FP8权重名称记录下来
new_state_dict[weight_name] = weight_dequant(weight, scale_inv) # 转换为BF16
except KeyError:
# 如果没有找到scale_inv张量,则跳过转换
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight # 保留原始权重
else:
new_state_dict[weight_name] = weight # 如果不是FP8,直接保留原始权重
# 保存转换后的权重
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# 内存管理:保持仅2个最近使用的文件
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files)) # 获取最老的文件
del loaded_files[oldest_file] # 删除最老的文件
torch.cuda.empty_cache() # 清理缓存
# 更新模型索引文件
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names: # 遍历所有FP8权重
scale_inv_name = f"{weight_name}_scale_inv" # 对应的scale_inv名称
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name) # 从weight_map中删除scale_inv权重
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) # 保存更新后的索引文件
if __name__ == "__main__":
# 设置命令行参数解析
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True) # 输入FP8权重路径
parser.add_argument("--output-bf16-hf-path", type=str, required=True) # 输出BF16权重路径
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path) # 调用主函数进行转换
weight_dequant
- 引入包,建议先阅读Triton向量相加 的基础示例以理解Triton的工作方式。
from typing import Tuple
import torch
import triton
import triton.language as tl # Triton 语言(Triton Language)允许用户在 GPU 上编写高效的并行计算内核https://github.com/triton-lang/triton
from triton import Config
weight_dequant函数用于将量化的权重张量(x)进行反量化处理,恢复到浮动值。以下是注释的解释:
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M, N).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' # 确保输入张量是连续的(即内存布局连续)
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' # 确保输入张量 x 和 s 都是二维的
M, N = x.size() # 获取输入张量 x 的尺寸 M (行数) 和 N (列数)
# 创建一个和 x 形状相同的新张量 y,用来保存反量化后的结果
y = torch.empty_like(x, dtype=torch.get_default_dtype())
# 定义一个 grid 函数来计算 triton 内核所需的网格大小
# triton.cdiv 是向上取整除法,用来确保我们分配足够的线程处理每个块
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
# 调用 triton 内核 `weight_dequant_kernel` 进行反量化操作
# 将 quantized weight `x` 和 scale `s` 与结果张量 `y` 一起传递给内核
# `M`, `N`, `block_size` 作为额外的参数传递
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
# 返回反量化后的张量 y
return y
- 计算网格大小:
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])): 使用triton.cdiv来计算块的数量。triton.cdiv是向上取整除法,用于确定每个维度需要多少个块来处理M和N大小的数据。meta['BLOCK_SIZE'])是每个块处理的元素数量(默认值为 128)。
weight_dequant_kernel
- Nvidia GPU CUDA使用grid、block、thread进行索引。
- 实现反量化的核函数(模型可能使用的是LSQ(Learned Step Quantization)Quantization,仅有量化步长参数),通过
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)调用:
@triton.jit # 使用 Triton 编译器将此函数编译为高效的 GPU 内核
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
# 获取当前线程在程序中的编号
pid_m = tl.program_id(axis=0) # 获取当前行维度上的线程编号,pid_m 和 pid_n 的范围由矩阵的尺寸 M 和 N,以及线程块的大小 BLOCK_SIZE 决定
pid_n = tl.program_id(axis=1) # 获取当前列维度上的线程编号,pid_m 的值从 0 到 ceil(M / BLOCK_SIZE) - 1
# 计算矩阵列的块数
n = tl.cdiv(N, BLOCK_SIZE) # 使用向上取整除法计算列方向上的块数
# 计算当前线程块在行和列方向上的偏移量
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # 当前块在行方向的偏移量
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # 当前块在列方向的偏移量
# 将行和列的偏移量组合成一个二维的索引数组
offs = offs_m[:, None] * N + offs_n[None, :] # 将行和列偏移量结合,得到每个元素的全局索引,offs_m[:, None]形状会变成 (BLOCK_SIZE, 1),相加广播后变为(BLOCK_SIZE, BLOCK_SIZE)
# 使用掩码保证我们不会超出矩阵的边界
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) # 掩码,确保线程不会访问超出矩阵范围的数据
# 加载量化后的权重数据(量化后的值)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) # 从内存中加载量化后的数据,并转换为 float32 类型
# 加载缩放因子
s = tl.load(s_ptr + pid_m * n + pid_n) # 从内存中加载对应的缩放因子,s_ptr是指向缩放因子数组的指针,pid_m * n + pid_n计算出当前线程块在缩放因子数组中的位置。
# 执行去量化操作:去量化 = 量化值 * 缩放因子
y = x * s # 去量化的计算公式
# 将去量化后的数据存储到输出缓存中
tl.store(y_ptr + offs, y, mask=mask) # 将去量化后的值存储到输出内存中,使用掩码确保数据存储在合法的范围内,`offs` 是索引,`mask=mask` 确保只有合法的元素被存储
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐



所有评论(0)