聊聊GLM-4-9B开源模型的微调loss计算

概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

微调

微调格式:

代码语言:javascript
复制
[
  {
    "messages": [
      {
        "role": "system",
        "content": "<system prompt text>",
        "tools": [
          {
            "name": "<tool name>",
            "args": {
              "<arg name>": "<arg value>"
            }
          }
        ]
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "observation",
        "content": "<observation prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response observation>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }
    ]
  }
]

微调源码地址:finetune.py Loss计算代码:

代码语言:javascript
复制
def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    # batched_conv 是一个数组
    # conv 是数组内的单个 message
    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        # conv 是数组内的单个 message
        # message 是 单个role json对象
        for message in conv:
            message = process_message(message)
            # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            # 获取 input 文本的数字表示(ids)
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            # 计算整句的 mask
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            # 拼接message中的每段json
            input_ids += new_input_ids
            # 拼接message中每段json对应的mask
            loss_masks += new_loss_masks
        # 追加结尾的 token id
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                # 添加到label,计算loss
                labels.append(input_id)
            else:
                # -100 不处理,即ignore_index
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        # 截断
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

代码语言:javascript
复制
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)
print('train_dataset:', train_dataset)

Loss计算如下图所示: