AI学习者 · 2024年07月15日

超大模型加载转换Trick

原文:https://zhuanlan.zhihu.com/p/698950172

在深度学习领域,大模型的训练和推理通常需要消耗大量的计算和内存。如何高效地加载和使用大模型是一个相当关键的问题。在这篇博客中,我将分享一些关于更快加载大模型和减少内存的技巧.

问题分析

假设现在我们有一个236B 超大模型的原始权重的 checkpoint.pth 文件, 比如 DeepSeek Chat V2, 以BF16 格式存储, 一个标准的加载流程如下

import torch

state_dict = torch.load(checkpoint_file)
my_model = BigModelClass(...)
my_model.load_state_dict(state_dict)

在这段代码的中, my_model = BigModelClass(...) 会初始化一个模型, torch.load(checkpoint_file)函数会将模型权重从磁盘加载到内存中。然后,my_model.load_state_dict(state_dict)函数会将权重从内存加载到模型的参数中。这两个步骤都可能会消耗大量的时间和内存。理想情况下, 一个236B BF16格式的模型需要占据 472GB 的内存, 上面的代码会有两个模型副本, 这意味着峰值需要944GB 内存, 接近1T ,这是非常夸张的也是不可接受的.

我们用一段简单的代码来验证上面的推断, 首先初始化一个 1B size 的模型并存下来,

import torch

def count_parameters(model):
    total_params =  sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params / 1e9 

def model_memory_size_in_megabytes(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.numel() * param.element_size()  

    bytes_in_gb = 1024 * 1024 * 1024 
    return param_size / bytes_in_gb

class BigModel(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        return self.linears(x)

size = 10000
model = BigModel(size)

# 打印模型的参数量
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
torch.save(model.state_dict(), 'checkpoint.pth')
The model has 1.0001 B trainable parameters
The model's memory size is approximately 3.73 GB.

然后 按照上面的方式加载模型, 并统计cpu 内存占用, torch 默认是FP32 格式, 1B模型占用约 4GB 内存(实际为3.73GB左右), 下面代码验证后基本符合预期

def print_usage():
    pid = os.getpid()
    py = psutil.Process(pid)
    memory_use = py.memory_info()[0] / 2. ** 30  # memory use in GB...I think
    print(f'memory: {memory_use:.2f} GB')
    print('CPU percent:', psutil.cpu_percent())

print('Before Load the state_dict:')
print_usage()
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 8.5
start_time = time.time()
state_dict = torch.load('checkpoint.pth')
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()
Loading the state_dict took 2.09 seconds
After Load the state_dict:
memory: 4.06 GB
CPU percent: 7.0

4.06 - 0.34 = 3.72基本一致

start_time = time.time()
model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()
Init the model took 7.23 seconds
After Init the model:
memory: 7.79 GB
CPU percent: 7.6

7.79 - 4.06 = 3.73 基本一致

start_time = time.time()
model.load_state_dict(state_dict)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()
Loading the state_dict to model took 2.63 seconds
After Load the state_dict to model:
memory: 7.79 GB
CPU percent: 16.4

问题解决

分析清楚在加载和初始化环节中各个流程的开销, 我们来看看可以如何加速每个过程.

使用torch.load(mmap=True)

首先,让我们考虑一下当我们使用 加载检查点时会发生什么torch.load。当我们使用 保存检查点时torch.save,张量存储会使用保存它们的设备进行标记。使用torch.load,张量存储将加载到它们标记的设备(除非使用标志覆盖此行为 map_location)。为了便于解释,我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不可行的:

  • CPU RAM 小于检查点的大小。
  • 等待整个检查点加载到 RAM 中,然后再执行某些按张量处理等操作。
start_time = time.time()
state_dict = torch.load('checkpoint.pth')
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
print_usage()
loading time without mmap=2.0737619400024414
memory: 4.06 GB
CPU percent: 8.7

torch.load中的mmap参数图解决上述两个问题。顾名思义,mmap关键字参数 totorch.load 使用mmap 调用 ,将磁盘上的文件映射到虚拟内存,并让操作系统自动处理到物理内存的加载和卸载。当这个标志被传递时,张量存储将被内存映射。

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
print_usage()
loading time with mmap=0.003424406051635742
memory: 0.34 GB

通过上面对比,我们可以发现 使用mmap可以加速模型加载并减少内存占用, 对于236B的模型, 我们实际上并不需要 1TB的 CPU内存来完成转换

使用 torch.device('meta')

当模型size 巨大时, 模型初始化也需要巨大时间, 我们扩大一下模型size到25B, 初始化一个模型就需要接近3分钟.

size = 50000
start_time = time.time()
model = BigModel(size)
end_time = time.time()
print(f"init time={end_time - start_time}")
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
init time=184.56671452522278
The model has 25.0005 B trainable parameters
The model's memory size is approximately 93.13 GB.

但在load 模型时, 初始化这一步是多余的, 我们实际上只需要知道模型的所有 key 和 对应的 shape,

这个时候, torch.device('meta') 这个 上下文就可以发挥作用了, torch.device() 上下文管理器确保工厂调用将像它们被传递了指定的"device"作为参数一样执行。在 torch.device('meta') 上的张量不携带数据。然而,它们具有张量所具有的所有其他元数据,例如.size().stride().requires_grad等。

with torch.device('meta'):
   model = BigModel(size)
model.load_state_dict(state_dict, assign=True)

for n, p in model.named_parameters():
    assert p.device.type != "meta", f"{n} has not been loaded!"
    

注意, 在使用 torch.device('meta')后, 我们需要加上 assign=True参数来让参数被加载. 最后一段代码可以check 所有参数被正确加载了, 加载后的参数的 device应该不再是 meta 了.

实验结果

最后, 我们直接上一个100B size大小的大模型来对比, 是否使用 torch.load(mmap=True) 和torch.device('meta') 速度差别.

size = 100000
model = BigModel(size)

# 打印模型的参数量
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
torch.save(model.state_dict(), 'checkpoint.pth')
The model has 100.001 B trainable parameters
The model's memory size is approximately 186.27 GB.

加速前

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()

start_time = time.time()
model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()

start_time = time.time()
model.load_state_dict(state_dict)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()

start_time = time.time()
input = torch.randn(1, size)
output = model(input)
print(output)
print(f'One time forward {time.time() - start_time:.2f} seconds')
print_usage()
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 852.06 seconds
After Load the state_dict:
memory: 372.87 GB
CPU percent: 5.0
Init the model took 518.15 seconds
After Init the model:
memory: 745.41 GB
CPU percent: 4.9
Loading the state_dict to model took 125.63 seconds
After Load the state_dict to model:
memory: 745.41 GB
CPU percent: 11.7
tensor([[-0.0015, 0.0017, -0.0009, ..., -0.0036, 0.0041, 0.0052]],
grad_fn=\)
One time forward 6.95 seconds
memory: 745.42 GB
CPU percent: 11.4

加速后

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()

start_time = time.time()
with torch.device('meta'):
  model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()

start_time = time.time()
model.load_state_dict(state_dict, assign=True)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()

for i in range(2):
    start_time = time.time()
    input = torch.randn(1, size)
    output = model(input)
    print(output)
    print(f'One time forward {time.time() - start_time:.2f} seconds')
    print_usage()
    
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 0.11 seconds
After Load the state_dict:
memory: 0.34 GB
CPU percent: 6.1
Init the model took 0.00 seconds
After Init the model:
memory: 0.34 GB
CPU percent: 4.3
Loading the state_dict to model took 0.00 seconds
After Load the state_dict to model:
memory: 0.34 GB
CPU percent: 10.0
tensor([[ 0.0080, -0.0017, -0.0027, ..., -0.0011, 0.0097, -0.0048]],
grad_fn=\)
One time forward 48.37 seconds
memory: 372.85 GB
CPU percent: 5.2
tensor([[ 0.0038, 0.0014, -0.0076, ..., -0.0016, 0.0004, -0.0018]],
grad_fn=\)
One time forward 3.28 seconds
memory: 372.86 GB
CPU percent: 13.4

通过上面的对比, 加速前100B模型加载时间为

852.06 + 518.15 + 125.63 = 1495(s) = 25 (min)

而使用 mmap + meta device 加载几乎没有时间开销, 只有模型真正运行时才会从硬盘拷贝权重到CPU RAM。

作者:Fazzie
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18849
内容数
1389
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息