基于CMExam数据集的ChatGLM4模型微调方法
帮人代工的一个项目,正好深入地学习一下真正的大模型方案,并记录下学习LLM做微调的方法。项目源码已上传github。项目源代码:SFT_GLM4_on_CMExam
一、研究目标
构建并微调一个基于LLM的深度学习模型,使其能够适应医学考试问题的场景。具体来说,本项目的主要目标包括:
1. 利用提示工程和基于LoRA的模型微调技术优化模型,使其能够生成准确且解释性强的答案;
2. 通过实验评估模型在医学考试问题上的表现,并探索不同超参数对模型性能的影响;
二、数据集介绍
CMExam数据集是从中国国家医学执业资格考试中提取的一个综合性中文医学考试数据集,旨在为大规模语言模型(LLM)的医学领域评估提供标准化和全面的基准。该数据集包含超过60,000道多项选择题,每道题目附有详细的解答说明,便于模型进行开放式的推理评估。
CMExam还邀请医学专业人员对每道题目进行了五个额外的注释,包括疾病组别、临床科室、医学学科、能力领域和问题难度级别。这些注释基于权威的医学资源和客观指标,为模型性能的多维度评估提供了丰富的信息。 疾病组别的分类参考了国际疾病分类第11版(ICD-11),涵盖27个类别;临床科室的分类依据《医疗机构诊疗科目名录》,包含36个科室;医学学科的分类采用《学位授予和人才培养学科目录(2022年)》,分为7个学科;能力领域的划分由医学专业人员定义,共有4个领域;问题难度级别则根据人类考生的表现分为5个等级。
CMExam的数据文件主要以CSV格式存储,分为训练集(train.csv)、验证集(val.csv)和带有注释的测试集(test_with_annotations.csv)。每个文件包含题目、选项、答案、解答说明以及上述五个注释信息。数据集的总大小约为49.29MB,适合用于模型训练和评估。
在CMExam数据集论文中的基准测试评估了多种大型语言模型(LLMs)的性能,旨在了解它们在中文医学考试领域的表现。测试的模型包括GPT-4、GPT-3.5、ChatGLM等。下图展示了各基准模型在CMExam数据集的性能表现。
三、训练方法
3.1 prompt工程
提示工程(Prompt Engineering)在模型效果优化中扮演了重要角色。提示词设计的合理性对模型的生成结果有直接影响,尤其在医学领域的问答任务中,精确的提示词能够有效引导模型输出更符合预期的答案。我们探索了多个不同的提示词模式,通过对比实验选择出最佳的提示词配置,以提高模型的准确性和解释性。
在初期测试中,我们尝试了多种不同的提示词设计,以观察其对模型生成效果的影响。设计的提示词不仅包含问题本身,还包括预设的答案和解释的格式。这种方法旨在帮助大语言模型(LLM)理解并生成符合医学考试要求的答案。在提示词的设计过程中,我们通过不断调整提示词的内容和结构,观察模型在输出格式和准确性方面的变化。
提示格式1:
训练集格式化数据:
问题:{填入训练集中的问题}
答案:{填入训练集中的答案}
理由:{填入训练集中的解释}
验证集PROMPT提示:
问题:{填入验证集中的问题}
答案:后续由LLM生成
提示格式2:
训练集格式化数据:
问题:{填入训练集中的问题}
分析:{填入训练集中的解释}
正确选项:{填入训练集中的答案}
验证集PROMPT提示:
问题:{填入验证集中的问题}
分析:后续由LLM生成
提示格式3:
训练集格式化数据:
问题:{填入训练集中的问题}
分析:{填入训练集中的解释}
正确选项:{填入训练集中的答案}
验证集PROMPT提示:
{加入固定提示模版}
问题:{填入验证集中的问题}
分析:后续由LLM生成
最初的提示设计(如提示格式1所示)较为简单,直接以“问题”“答案”和“理由”三项为主。该提示词虽能让模型理解问题和答案的基本结构,但在复杂的医学问答情境下往往缺乏足够的指导性,导致模型生成的解释质量较低,准确率不高。为解决这一问题,我们引入了结构更为复杂的提示词格式(如提示格式2和提示格式3所示),在问题和答案的基础上增加了“分析”字段,用于提供对答案的简要解释,帮助模型更好地推理和回答。
在进一步的测试中,我们发现提示词的详细程度和引导性直接影响模型的准确率和生成内容的合理性。例如,在格式2和格式3中,模型能够更好地理解题目并生成带有合理解释的答案。特别是在格式3的提示下,模型生成的答案解释更为完整,且符合预期的格式要求。我们还测试了不同提示词的引导效果,通过设置固定模板的提示内容,引导模型生成更符合医学考试答题风格的答案。
不同提示词模式的实验结果显示,随着提示词结构的复杂化,模型的准确率和生成答案的解释质量逐步提高。这一结果表明,通过为模型提供明确的格式和回答引导,能够有效减少模型的生成偏差,提高回答的准确性。我们观察到,在固定模板的提示模式下(如最终最佳提示词所示),模型的准确率达到了最佳水平。这种提示方式不仅为模型提供了回答的结构性指引,还鼓励其采用类似思维链的思维方式,逐步推理出答案。
最终,我们选定了结构性最强的提示模式作为最佳提示词配置:
对于后续问题,你需按照以下格式,向分析和正确选项字段填入答案。
格式:
问题:问题内容
分析:{填入分析,50字以内}
正确选项:{填入正确的选项和选项文本}
问题:问题内容
分析:解释问题的要点和推理,字数限制在50字以内
正确选项:提供正确答案及其对应选项文本
3.2 AdaLoRA微调
在本项目中,我们选择了AdaLoRA(Adaptive LoRA)作为LoRA(Low-Rank Adaptation)的改进版本,来进行模型的高效微调。LoRA技术的核心思想是在模型的权重矩阵中添加低秩的矩阵分解,减少了模型微调时需要更新的参数数量,使得微调过程既高效又节省内存资源。而AdaLoRA在此基础上进行了进一步优化,能够自适应地调整低秩矩阵的秩,从而在性能和计算成本之间实现了更好的平衡。选择AdaLoRA作为微调方法的原因在于其高效性和灵活性,适合处理医学领域的数据特点和任务需求。
传统的LoRA方法通过在模型的权重矩阵中引入低秩矩阵,使得微调过程不再需要大规模地更新模型的所有权重,而是只更新引入的低秩矩阵,从而大幅度减少了计算开销。然而,LoRA在处理复杂任务时面临一个局限,即固定的低秩矩阵秩可能无法适应不同任务的需求。对于一些难度较大的问题,模型可能需要更高的表达能力,而低秩限制则可能导致模型生成的答案质量下降。
为了解决这一问题,AdaLoRA在LoRA的基础上引入了动态调整秩的机制。具体来说,AdaLoRA会在训练过程中根据不同层次的梯度变化自动调整秩的大小,使得模型能够在处理复杂问题时增加秩来提升表达能力,而在简单问题上减少秩以节省计算资源。这种自适应调整机制使得AdaLoRA能够更灵活地应对医学考试中的多样化问题,保证了模型在高效性和性能上的平衡。
四、代码分析
4.1 数据预处理
预处理部分直接使用别人处理过的json格式数据,省去了一个对csv的解析步骤。数据集来源:https://huggingface.co/datasets/fzkuji/CMExam
预处理部分均写在项目makeDataset.py中。主要分为两个类,分别是GLM4QADataset和GLM4QATestDataset,用于处理微调测试集和训练集。
通过以下部分代码,将json文件格式化为之前prompt工程最优的提示词:
with open(json_path, "r", encoding='utf-8') as f:
for line in f:
if not line or line == "":
continue
json_line = json.loads(line)
question = json_line["Question"]
# 生成选项内容
options_text = "\n".join([f"{option['key']}: {option['value']}" for option in json_line["Options"]])
# 获取正确答案的选项值
correct_option = json_line["Answer"]
correct_answer = next(
(option["value"] for option in json_line["Options"] if option["key"] == json_line["Answer"]), "")
explanation = json_line["Explanation"]
if explanation and len(explanation) > char_limit:
continue
else :
# 按照要求格式化文本
full_text = f"对于后续问题,你需按照以下格式,向分析和正确选项字段填入答案。\n格式:\n问题:问题内容\n分析:{{填入分析,50字以内}}\n正确选项:{{填入正确的选项和选项文本}}\n\n问题:{question}\n{options_text}\n分析:{explanation}\n正确选项:{correct_option}.{correct_answer}{tokenizer.eos_token}"
self.data.append(full_text)
同时,在GLM4QATestDataset中,加入了额外的return值:保存了”correct_answer”,也就是正确答案的文本,用于后续对比分析。但是这个字段不会放入”input_ids”,也就是输入文本之中。因此额外编写了custom_collate_fn函数,用于自定义dataloader的行为。
另外,在代码中额外加入了HuatuoQADataset功能,用于使用其他数据集的导入,但实际微调过程中并没有实际使用,仅作为一个历史函数保留。
4.2 LoRA微调
接下来就是实际调用AdaLoRA微调了。所有的微调训练代码均在main.py中。主要分为了三个函数:train(),validate(),test(),也就是训练,验证和测试。另外还预留了pretrain()函数,也就是利用huatuo数据集进行预训练,但是同样地没有实际使用,仅作为一个备用功能保留。
首先通过以下代码定义好我们所使用的模型,数据集以及所设定超参数:
#定义所使用的模型
@dataclass
class ModelArguments:
#使用本地文件
model_name_or_path: str = field(default="GLM4-9B", metadata={"help": "Path to the model."})
#使用远程文件
#model_name_or_path: str = field(default="THUDM/glm-4-9b", metadata={"help": "Path to the model."})
#定义数据集的路径(JSON格式)
@dataclass
class DataArguments:
train_data_path: str = field(default="./fzkuji/train.json", metadata={"help": "Path to the training data."})
val_data_path: str = field(default="./fzkuji/valid.json", metadata={"help": "Path to the validation data."})
test_data_path: str = field(default="./fzkuji/test.json", metadata={"help": "Path to the test data."})
# 实例化参数对象
model_args = ModelArguments()
data_args = DataArguments()
training_args = CustomTrainingArguments(
output_dir="./model",
per_device_train_batch_size=8,
per_device_test_batch_size=64,
num_train_epochs=4,
logging_dir="./logs",
model_max_length=256,
char_limit=200,
use_lora=True,
learning_rate=1e-4
)
然后我们再来实际看看各个函数。首先是train()部分。
首先,我们需要引用from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling
确定我们使用的tokenizer分词器。当然,在ChatGLM中这个已经自带了,因此直接使用AutoTokenizer.from_pretrained(
指定模型文件夹下的分词器即可。
model_args.model_name_or_path, use_fast=False, trust_remote_code=True)
模型则是通过AutoModelForCausalLM
方法调用,这是生成类模型的一个普遍调用方式。
在微调训练的过程中,peft库是一个非常有效的工具,其内置了AdaLora微调方法以及其他,直接通过from peft import AdaLoraConfig调用即可。
然后通过以下配置文件设置lora配置:
if training_args.use_lora:
peft_config = AdaLoraConfig(
init_r=128,
lora_alpha=256,
target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], # 现存问题只微调部分演示即可
lora_dropout=0.1,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)
值得注意的是这里的target_modules部分。其他的lora设置,例如init_r和lora_alpha都是较为通用的部分,是Adalora的设置内容。而target_modules则是指定了要对模型的哪些层进行微调。这里使用的["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
是跟着ChatGLM4官方的代码设置的,但是假设后续换了一个模型,由于模型的设计不同,因此模型层的这些设置也需要对应的更改。例如我测试使用Baichuan2模型时,使用的设置则是target_modules=["q_proj", "v_proj"]
接下来就是比较常规的训练动作,设置数据加载器,以及准备模型、优化器等工作,这里没啥特别的,与常规的深度学习方法一致。
主要的训练循环代码如下:
# 开始训练循环
for epoch in range(training_args.num_train_epochs):
torch.cuda.empty_cache() # 清理显存缓存
model.train()
epoch_loss = 0
with tqdm(train_dataloader, unit="batch") as tepoch:
for step, batch in enumerate(tepoch):
tepoch.set_description(f"Epoch {epoch}")
outputs = model(**batch)
loss = outputs.loss
epoch_loss += loss.item()
accelerator.backward(loss)
if step % 10 == 0:
tepoch.set_postfix(loss=loss.item(),
step_time=f"{tepoch.format_dict['elapsed'] / (step + 1):.2f}s/it")
optimizer.step()
optimizer.zero_grad()
print(f"Epoch {epoch} completed. Average Training Loss: {epoch_loss / len(train_dataloader)}")
# 验证循环
torch.cuda.empty_cache() # 清理显存缓存
validate(model, tokenizer, val_dataloader, training_args, epoch, accelerator)
在这段代码里,我加上了tqdm进度条,精简一下输出。需要注意的是,由于LLM模型占用的显存较大,而且每次训练后显存资源不会自动释放,因此多轮epoch训练之后会导致爆显存,因此在每次训练之后都使用了torch.cuda.empty_cache()清理显存缓存。
每轮epoch训练完成后,都会调用validate(),也就是用此时的模型测试验证集,看看每轮epoch的结果如何。当然,因为每次验证都会比较耗时,因此在手动调参的时候验证结果的时候也可以禁用validate(),提高训练速度。
validate()和test()的总体逻辑一致,都是在验证集或者测试集上让模型自主输出回答。因此核心代码在于
outputs = model.generate(
input_ids=batch['input_ids'].to(accelerator.device),
attention_mask=batch['attention_mask'].to(accelerator.device),
max_new_tokens=256,
do_sample=False,
num_beams=1,
)
# 获取生成的预测文本
test_predictions.extend([tokenizer.decode(pred, skip_special_tokens=True) for pred in outputs])
也就是使用model.generate函数生成文本,并通过tokenizer解码输出至test_predictions的list中。这里有一些超参数设置,其中num_beams指定了beam搜索数量,越大文本质量越高,但是消耗的显存也越高。为了兼顾训练效率,因此这里设置为1。
另外,代码中还使用了accelerator作为训练加速器,直接使用 accelerator = Accelerator()初始化即可,虽然效果感觉不是很明显。
在训练完成后,会将所有生成的文本保存至test_predictions.csv文件中。还需要进一步分析,以确定模型的得分情况。
4.3 模型得分分析
分析的部分我全部写在了check.py中,可以很方便地指定需要分析的文件路径。
由于我们输出保存的CSV只有模型的输出结果,以及通过之前所说的json文件格式化后的答案文本,这些数据不够Cmexam数据集的列表中这么精确,因此还是需要将模型的输出匹配回原本的Cmexam数据集,找到模型的回答对应的是原始模型的哪一个提问。
首先对原始测试集csv文件做处理,使其格式匹配生成文本的形式(因为测试集中问题,选项,解析等部分都是按不同列分开的),也就是问题+选项,后面分析部分则不予处理:
# 定义函数,将 Options 列转换为与 Generated Text 一致的格式,并去除每个选项文本后的多余空格
def format_options(options_text):
options = options_text.split('\n')
formatted_options = [re.sub(r"([A-E]) (.+)", r"\1: \2", opt).strip() for opt in options if opt]
return '\n'.join(formatted_options)
# 对测试集的 Options 列进行格式化处理
test_data['Formatted Options'] = test_data['Options'].apply(format_options)
由于之前应用了prompt工程,因此第一步就是把之前的prompt提示都给去除,留下问题和模型生成的文本,方便后面的处理。去除的函数代码如下:
# 去除 Generated Text 列开头的提示文本
def remove_prompt(text):
prompt_pattern = r"^对于后续问题,你需按照以下格式,向分析和正确选项字段填入答案。\s+格式:\s+问题:问题内容\s+分析:\{填入分析,50字以内\}\s+正确选项:\{填入正确的选项和选项文本\}"
return re.sub(prompt_pattern, "", text).strip()
# 应用函数去除生成数据集中 Generated Text 列的开头提示文本
generated_data['Generated Text'] = generated_data['Generated Text'].apply(remove_prompt)
接下来,就是根据格式化成与生成文本相同形式的Formatted Options,通过是否包含对应的问题+选项,来确定生成的回答对应的问题。这一步主要是在实际测试中发现,有些问题题干很短,诸如:“下列选项正确的是:”。这样的题干很容易匹配错误,因此使用问题+选项的形式进行匹配。
# 使用包含关系来匹配,并同时检查 Question 和 Formatted Options,禁用正则表达式匹配
for index, row in test_data.iterrows():
question = row['Question']
formatted_options = row['Formatted Options']
matched_text = generated_data[
generated_data['Generated Text'].str.contains(question, na=False, regex=False) &
generated_data['Generated Text'].str.contains(formatted_options, na=False, regex=False)
]
if not matched_text.empty:
test_data.at[index, 'Generated Text'] = matched_text.iloc[0]['Generated Text']
匹配完成后,就可以计算正确率和各项评分了。正确率计算较为简单,使用正则表达式提取出模型回答的正确选项,并与正确答案比对即可。bleu和rouge评分则是需要使用正则表达式提取出模型生成的分析部分,并与原始测试集中的解析部分进行比较。具体的计算方法倒是较为简单,直接调库即可。
五、实验结果与性能评估
本次实验使用的是Autodl租用的单卡A800 80G服务器,配套14核CPU与100G内存。在这样的硬件组合下,可以设置训练batch_size=8左右。我尝试了使用自己的RTX4070 12G显卡训练,即使是batch_size=1也会爆显存无法训练。当然在这里没有使用半精度训练或者量化等方法(因为性能够也没有必要)。训练一轮epoch的时间约为6小时左右,我习惯设置只训练2轮epoch,加上验证和测试的时间,一次训练总共耗时需要15小时左右,还是比较花时间的。
5.1 评估指标
- 准确率(Accuracy)
准确率是衡量分类任务最常用的指标之一,表示模型预测正确的样本数量与总样本数量的比值。在本项目中,准确率用于评估模型在多选题或判断题上的表现。如果模型预测的答案与实际正确答案一致,则计为一个正确的样本。 - BLEU(Bilingual Evaluation Understudy)
BLEU是一种常用于机器翻译和文本生成任务的指标,衡量生成文本与参考文本的相似度。BLEU通过计算生成文本与参考文本之间的词汇匹配程度来得分,得分范围为0到1,其中得分越高,表示生成文本越接近参考文本。BLEU的计算方式为按不同n-gram(词组)等级进行匹配,通常包括1-gram、2-gram、3-gram和4-gram,分别表示生成文本中单词、二元组、三元组和四元组的匹配程度。 - ROUGE(Recall-Oriented Understudy for Gisting Evaluation)
ROUGE是一组用于文本生成任务的指标,特别适用于摘要生成和回答生成。ROUGE指标通过计算生成文本与参考文本之间的重合率来衡量其质量,主要包括ROUGE-N、ROUGE-L和ROUGE-W等。常用的ROUGE指标有:
·ROUGE-N:基于n-gram的重合率,计算生成文本和参考文本之间的n-gram匹配比例。例如,ROUGE-1表示单词匹配,ROUGE-2表示二元组匹配。
·ROUGE-L:基于最长公共子序列的匹配度,适用于衡量生成文本和参考文本的句子结构相似度。
·ROUGE-W:对LCS加权,重点衡量连续的匹配词序列。
5.2 微调结果
感叹一下ChatGLM4在中文的理解能力上确实牛逼。在微调之前,只使用prompt工程,就可以将准确率提高至72%左右,大幅领先于论文中ChatGPT4的数据。而在通过lora微调之后,准确率达到了75.6%,大约实现了5%的性能提升,证明了Lora微调的有效性。
Models | ChatGPT4 (-) | ChatGLM4-Base+SFT (9B) |
---|---|---|
Prediction_Acc | 61.60 | 75.60 |
Reasoning_BLUE-1 | 0.17 | 40.24 |
Reasoning_BLUE-4 | 0.06 | 12.14 |
Reasoning_ROUGE-1 | 29.74 | 61.19 |
Reasoning_ROUGE-2 | 14.84 | 33.42 |
参考链接:
Benchmarking Large Language Models on CMExam – A Comprehensive Chinese Medical Exam Dataset
fzkuji/CMExam
CMExam 中文医学考试数据集介绍
LLM大模型推理输出生成方式总结
LoRA、QLoRA与AdaLoRA的低秩适配:如何让AI语言模型瘦身不减智?
[大模型微调技术] LoRA、QLoRA、QA-LoRA 原理笔记