AI学习者 · 2024年08月08日

在 v5e Cloud TPU 虚拟机上进行 JetStream MaxText 推理

JetStream 是适用于大语言模型的吞吐量和内存优化引擎 XLA 设备 (TPU) 上的 (LLM) 推断。

准备工作

按照管理 TPU 资源中的步骤进行操作, 创建一个将 --accelerator-type 设置为 v5litepod-8 的 TPU 虚拟机,并连接到 TPU 虚拟机。

设置 JetStream 和 MaxText

  1. 下载 JetStream 和 MaxText GitHub 代码库

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
  2. 设置 MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh

转换模型检查点

您可以使用 Gemma 或 Llama2 模型运行 JetStream MaxText 服务器。这个 部分介绍了如何使用各种大小的 这些模型。

使用 Gemma 模型检查点

  1. 从 Kaggle 下载 Gemma 检查点
  2. 将检查点复制到 Cloud Storage 存储桶

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive

    如需查看包含 ${YOUR_CKPT_PATH}${CHKPT_BUCKET} 值的示例,请参阅转化脚本

  3. 将 Gemma 检查点转换为与 MaxText 兼容的未扫描检查点。

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}

使用 Llama2 模型检查点

  1. 开源社区下载 Llama2 检查点, 或使用您已生成的代码。
  2. 将检查点复制到 Cloud Storage 存储桶。

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive

    如需查看包含 ${YOUR_CKPT_PATH}${CHKPT_BUCKET} 值的示例, 请参阅转化脚本

  3. 将 Llama2 检查点转换为与 MaxText 兼容的未扫描检查点。

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}

运行 JetStream MaxText 服务器

本部分介绍了如何使用与 MaxText 兼容的 检查点。

为 MaxText 服务器配置环境变量

根据您使用的模型导出以下环境变量。 使用 model_ckpt_conversion.shUNSCANNED_CKPT_PATH 的值 输出。

为服务器标志创建 Gemma-7b 环境变量

配置 JetStream MaxText 服务器标志

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

为服务器标志创建 Llama2-7b 环境变量

配置 JetStream MaxText 服务器标志

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

为服务器标志创建 Llama2-13b 环境变量

配置 JetStream MaxText 服务器标志

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

启动 JetStream MaxText 服务器

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

JetStream MaxText Server 标志说明

tokenizer_path

标记生成器的路径(应与您的模型匹配)。

load_parameters_path

从特定目录加载参数(无优化器状态)

per_device_batch_size

每个设备的批次大小(1 个 TPU 芯片 = 1 个设备)

max_prefill_predict_length

执行自动回归时的预填充长度上限

max_target_length

序列长度上限

model_name

模型名称

ici_fsdp_parallelism

实现 FSDP 并行处理的分片数

ici_autoregressive_parallelism

自动回归并行处理的分片数

ici_tensor_parallelism

并行张量处理的分片数

weight_dtype

权重数据类型(例如 bfloat16)

scan_layers

扫描图层布尔值标志(设置为“false”以进行推理)

向 JetStream MaxText 服务器发送测试请求

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

输出将如下所示:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

使用 JetStream MaxText 服务器运行基准测试

若要获得最佳基准测试结果,请启用量化(使用经过 AQT 训练或微调的 调整检查点以确保准确性)。启用 量化,请设置量化标记:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

基准化分析 Gemma-7b

如需对 Gemma-7b 进行基准测试,请执行以下操作:

  1. 下载 ShareGPT 数据集。
  2. 在运行 Gemma 7b 时,请务必使用 Gemma 标记生成器 (tokenizer.gemma)。
  3. 为您的第 1 次运行添加 --warmup-first 标志,以预热服务器。
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

对大型 Llama2 进行基准化分析

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env
文章来源:google cloud

推荐阅读

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
10834
内容数
80
搭载基于安谋科技自研“周易”NPU的芯擎科技工业级“龍鹰一号”SE1000-I处理器
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息