JetStream 是用于在 XLA 设备 (TPU) 上的大型语言模型 (LLM) 推断的吞吐量和内存优化引擎。
准备工作
按照设置 Cloud TPU 环境中的步骤创建一个 Google Cloud 项目、激活 TPU API、安装 TPU CLI 并申请 TPU 配额。
按照使用 CreateNode API 创建 Cloud TPU 中的步骤创建一个 TPU 虚拟机,并将 --accelerator-type
设置为 v5litepod-8
。
克隆 JetStream 代码库并安装依赖项
使用 SSH 连接到您的 TPU 虚拟机
- 将 ${TPU_NAME} 设置为您的 TPU 名称。
- 将 ${PROJECT} 设为您的 Google Cloud 项目
- 将 ${ZONE} 设置为要在其中创建 TPU 的 Google Cloud 可用区
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
克隆 JetStream 代码库
git clone https://github.com/google/jetstream-pytorch.git
(可选)使用
venv
或conda
创建一个虚拟 Python 环境并将其激活。运行安装脚本
cd jetstream-pytorch source install_everything.sh
下载并转换权重
- 从 GitHub 下载官方 Llama 权重。
转换权重。
- 将 ${IN_CKPOINT} 设置为包含 Llama 权重的位置
- 将 ${OUT_CKPOINT} 设置为位置写入检查点
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
在本地运行 JetStream PyTorch 引擎
如需在本地运行 JetStream PyTorch 引擎,请设置标记生成器路径:
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
使用 Llama 7B 运行 JetStream PyTorch 引擎
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
使用 Llama 13b 运行 JetStream PyTorch 引擎
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
运行 JetStream 服务器
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
注意:--platform=tpu=
参数需要指定 TPU 设备的数量(v4-8
为 4,v5lite-8
为 8)。例如 --platform=tpu=8
。
运行 run_server.py
后,JetStream PyTorch 引擎即可接收 gRPC 调用。
运行基准测试
切换到您运行 install_everything.sh
时下载的 deps/JetStream
文件夹。
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
如需了解详情,请参阅 deps/JetStream/benchmarks/README.md
。
典型错误
如果您收到 Unexpected keyword argument 'device'
错误,请尝试以下操作:
- 卸载
jax
和jaxlib
依赖项 - 使用
source install_everything.sh
重新安装
如果您收到 Out of memory
错误,请尝试以下操作:
- 使用较小的批次大小
- 使用量化
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
清理 GitHub 代码库
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
清理 Python 虚拟环境
rm -rf .env
删除 TPU 资源
如需了解详情,请参阅删除 TPU 资源。
文章来源:google cloud
推荐阅读
- 在 v5e Cloud TPU 虚拟机上进行 JetStream MaxText 推理
- 使用 Pax 在单主机 TPU 上训练
- 使用 Cloud TPU 进行 BERT 微调:句子和句对分类任务 (TF 2.x)
更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。