专栏名称: 吃果冻不吐果冻皮
专注于AI工程化(LLM、MLOps、LLMOps、RAG、Agent)落地。
目录
相关文章推荐
笔吧评测室  ·  英伟达 RTX 50 系列 GPU 放弃对 ... ·  昨天  
笔吧评测室  ·  聊一款「键盘里面塞电脑」的新奇主机 ·  昨天  
笔吧评测室  ·  真・手有余香:华硕推出 MD102 ... ·  2 天前  
笔吧评测室  ·  戴尔全新命名 14/16 Plus ... ·  2 天前  
笔吧评测室  ·  新一代 ROG XG 显卡扩展坞发布:可选 ... ·  3 天前  
51好读  ›  专栏  ›  吃果冻不吐果冻皮

大模型精度(FP16,FP32,BF16)详解与实践

吃果冻不吐果冻皮  · 公众号  ·  · 2024-05-24 12:25

正文

【点击】 加入大模型技术交流群

原文:https://zhuanlan.zhihu.com/p/657886517

本篇文章主要对训练LLM以及部署应用时的精度问题进行了一些探讨和实践,读过后应该会对常用的浮点数FP16,FP32,BF16有一个更好的理解~

浮点数据类型在IEEE 754-2019(2008) [1] 标准中进行了详细的定义,定义了不同精度的浮点数格式,如binary16、binary32和binary64,分别用16位、32位和64位二进制来表示,想要更全方位深入的了解的话,可以点引用查看官方的paper。下面进行一些常用的浮点数介绍。

FP16

FP16也叫做 float16,两种叫法是完全一样的,全称是Half-precision floating-point(半精度浮点数),在IEEE 754标准中是叫做binary16,简单来说是用16位二进制来表示的浮点数,来看一下是怎么表示的(以下图都来源于维基百科 [2] ):

其中:

Sign(符号位): 1 位,0表示整数;1表示负数。

Fraction(尾数位): 10位,简单地来说就是表示小数部分,存储的尾数位数为10位,但其隐含了首位的1,实际的尾数精度为11位,这里的隐含位可能有点难以理解,简单通俗来说,假设尾数部分为1001000000,为默认在其前面加一个1,最后变成1.1001000000然后换成10进制就是:

# 第一种计算方式
1.1001000000 = 1 * 2^0 + 1 * 2^(-1) + 0 * 2^(-2) + 0 * 2^(-3) + 1 * 2^(-4) + 0 * 2^(-5) + 0 * 2^(-6) + 0 * 2^(-7) + 0 * 2^(-8) + 0 * 2^(-9) = 1.5625
# 第二种计算方式
1.1001000000 = 1 + 576
(1001000000变成10进制)/1024 = 1.5625


所以正常情况下计算公式就是:

举一个例子来计算,这个是FP16(float16)能表示的最大的正数:

0 11110 1111111111 = ( 1 ) 0 × 2 30 15 × ( 1 + 1023 1024 ) = 65504


同样,这个是FP16(float16)能表示的最大的负数:

1 11110 1111111111 = ( 1 ) 1 × 2 30 15 × ( 1 + 1023 1024 ) = 65504

这就是FP16(float16)表示的范围[-65504,65504]。

我们来看一些特殊情况,FP16(float16)能表示最小的正数是多少呢?
0 00000 0000000001 = ( 1 ) 0 × 2 1 15 × ( 1 + 1 1024 ) 0.000000059604645

我们就不一一的计算了,贴一个FP16(float16)特殊数值的情况:

上表中,subnormal number是指指数位为全0的特殊情况情况,其他的也是一些常见的特殊情况。

接下来看一下在pytorch中是如何表示的:

torch.finfo(torch.float16)
# 结果
finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)

一些解释:

  1. resolution (分辨率):这个浮点数类型的在十进制上的分辨率,表示两个不同值之间的最小间隔。对于 torch.float16 ,分辨率是 0.001,就是说两个不同的 torch.float16 数值之间的最小间隔是 0.001。

  2. min (最小值):对于 torch.float16 ,最小值是 -65504。

  3. max (最大值):对于 torch.float16 ,最大值是 65504。

  4. eps (机器精度):机器精度表示在给定数据类型下,比 1 大的最小浮点数,对于 torch.float16 ,机器精度是 0.000976562,对应上表中的smallest number larger than one。

  5. smallest_normal (最小正规数):最小正规数是大于零的最小浮点数,对于 torch.float16 ,最小正规数是 6.10352e-05,对应上表中的smallest positive normal number

  6. tiny (最小非零数):最小非零数是大于零的最小浮点数,对于 torch.float16 ,最小非零数也是 6.10352e-05,也是对应上表中的smallest positive normal number

这里要详细的解释一下 resolution (分辨率),这个是我们以十进制来说的两个数之间的最小间隔,我们看一个例子就会明白:

import torch

# 把10进制数转化为 torch.float16
num = 3.141
num_fp16 = torch.tensor(num).half()
print(num_fp16)
# 结果
tensor(3.1406, dtype=torch.float16)

num = 3.1415
num_fp16 = torch.tensor(num).half()
print(num_fp16)
# 结果
tensor(3.1406, dtype=torch.float16)
# 可以看到3.141和3.1415间隔只有0.0005,所以在float16下结果是一样的

num = 3.142
num_fp16 = torch.tensor(num).half()
print(num_fp16)
# 结果
tensor(3.1426, dtype=torch.float16)
# 可以看到结果不一样了

从上面代码可以看到, 十进制中相隔0.001,在float16中才会有变化, 这个时候会有一个疑问,难道精度只有小数点后三位?那怎么之前见了很多参数都是有很多小数点的?那我们来看一下全过程,把float16变成2进制,再把2进制变成16进制:

import struct
def float16_to_bin(num):
# 将float16数打包为2字节16位,使用struct.pack
packed_num = struct.pack('e', num)

# 解包打包后的字节以获取整数表示
int_value = struct.unpack('H', packed_num)[0]

# 将整数表示转换为二进制
binary_representation = bin(int_value)[2:].zfill(16)
return binary_representation

num = 3.141
num_fp16 = torch.tensor(num).half()
print(num_fp16)
binary_representation = float16_to_bin(num_fp16)

print(binary_representation) # 打印二进制表示


# 结果
tensor(3.1406, dtype=torch.float16)
0100001001001000


num = 3.1415
num_fp16 = torch.tensor(num).half()
binary_representation = float16_to_bin(num_fp16)

print(binary_representation) # 打印二进制表示


# 结果
tensor(3.1406, dtype=torch.float16)
0100001001001000 # 还是一样的结果

num = 3.142
num_fp16 = torch.tensor(num).half()
print(num_fp16)
binary_representation = float16_to_bin(num_fp16)

print(binary_representation) # 打印二进制表示


# 结果
tensor(3.1426, dtype=torch.float16)
0100001001001001 # 不一样了

再看一下把2进制变成16进制:

def binary_to_float16(binary_string):
# 检查输入是否是有效的16位二进制字符串
if len(binary_string) != 16:
raise ValueError("输入的二进制字符串必须是16位长")

# 提取组成部分:符号、指数、尾数
sign = int(binary_string[0]) # 符号位
exponent = int(binary_string[1:6], 2) # 指数位
mantissa = int(binary_string[6:], 2) / 1024.0 # 尾数位,除以2的10次方(即1024)以获得10位精度

# 根据符号、指数和尾数计算float16值
value = (-1) ** sign * (1 + mantissa) * 2 ** (exponent - 15)
return value

# 10进制3.141对应float16:3.1406
binary_representation = "0100001001001000"
# 将二进制表示转换为float16
float16_value = binary_to_float16(binary_representation)
print("通过2进制转化后Float16值:", float16_value)
# 结果:
通过2进制转化后Float16值: 3.140625

# 10进制3.1415对应float16:3.1406
binary_representation = "0100001001001000"
# 将二进制表示转换为float16
float16_value = binary_to_float16(binary_representation)
print("通过2进制转化后Float16值:", float16_value)
# 结果:
通过2进制转化后Float16值: 3.140625

# 10进制3.142对应float16:3.1426
binary_representation = "0100001001001001"
# 将二进制表示转换为float16
float16_value = binary_to_float16(binary_representation)
print("通过2进制转化后Float16值:", float16_value)
# 结果:






请到「今天看啥」查看全文