文本分类模型微调实践
原本计划是利用LLama-factory框架基于chinese-bert-wwm这个模型去做微调,但是在实践过程中发现用LLama-factory框架去训练chinese-bert-wwm模型死活跑不通,于是选择换成Qwen-1.8B来进行微调。后面查了很多资料后才了解到BERT属于Transformer的Encoder结构,LLama-factory主要面向Decoder结构的模型,并不直接支持BERT模型,如果想微调BERT,那就需要使用Hugging Face Transformers去进行训练了。
一、安装cuda
1、确认GPU支持的CUDA版本
安装显卡驱动后控制台输入nvidia-smi命令查看
显卡最高支持cuda 12.8,可以安装不高于cuda 12.8的版本
2、安装支持的CUDA版本
Nvidia官网下载可以使用的cuda版本安装即可
3、验证安装
控制台输入nvcc --version
二、安装pytroch和部署LLaMA-Factory
1、创建python虚拟环境
python -m venv pyenv
2、安装gpu版本的pytorch
pytorch官方文档,获取到对应cuda12.8版本的pytroch安装命令
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
3、验证pytorch是否可以调用cuda
python终端,返回True表示pytorch可以调用cuda
import torch
torch.cuda.is_available()
4、安装LLaMA-Factory
LLaMA-Factory官方文档,安装命令如下:
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
# pip install -e ".[torch,metrics]"
# 上面是官方文档里的命令,重新安装了torch,但是这个上一步已经安装过了,去掉就可以了
pip install -e ".[metrics]"
5、验证LLaMA-Factory安装
llamafactory-cli version
成功看到类似下面的界面,就说明安装成功了
6、启动LLaMA-Factory webui
llamafactory-cli webui
三、处理数据
1、数据集选择
这个数据集本身是阿里达摩院创建用于测试文本领域分类模型能力的测试集,这里我们拿来做微调实践也是ok的,下载下来的数据集是一个csv格式的文件,一共有两列,第一例是文本,第二列是分类
2、将数据集处理成LLaMA-Factory支持的格式
这一次的任务是指令监督微调,根据LLaMA-Factory官方文档里面的内容,可以选择Alpaca格式的指令监督微调数据集
指令监督微调数据集 格式要求 如下:
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
根据我们下载的csv文件的内容,处理为如下格式:
[
{
"instruction": "请分类以下文本",
"input": "", //这里填写csv的text列数据
"output": "" //这里填写csv的label列数据
}
]
处理数据的脚本如下
import pandas as pd
import json
# 读取 CSV 文件
csv_file = r"zh.test.csv"
data = pd.read_csv(csv_file, comment='/', encoding='utf-8') # 忽略注释行
# 清理数据
data = data.dropna() # 去掉空值行
data.columns = ['text', 'label'] # 重命名列名
# 转换为 JSON 格式
json_data = []
for _, row in data.iterrows():
instruction = '请分类以下文本'
input = row['text']
output = row['label']
json_data.append({"instruction": instruction, "input":input, "output": output})
# 保存为 JSON 文件
output_file = r"zh.train.json"
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)
print(f"JSON 数据已保存到 {output_file}")
3.为LLaMA-Factory添加数据集描述
将上一步生成的数据集放到LLaMA-Factory的data目录下
根据LLaMA-Factory官方文档,在data目录下的dataset_info.json中添加数据集描述如下:
"zh_train": {
"file_name": "zh.train.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output"
}
}
四、开始训练
打开webui,模型选择部分选择Qwen/Qwen-1_8B模型,微调方法选择lora
数据集部分选择zh_train
预览数据集内容,确认内容正确
其余训练配置保持不变,点击开始按钮开始训练,等待训练完成。
五、测试微调后的模型
训练完成后,在检查点路径选择刚刚训练的模型
点击下方的chat标签,加载模型,加载完成后如下图所示
测试一下能否获得想要的输出,网上随便找了一段关于教育的新闻,在输入的部分输入如下内容,提交
请分类以下文本
2025年3月22日,第九届高等学校外语教育改革与发展论坛期间,外研社、外研在线举办了“重塑教育的力量:UNIPUS教育创新发布会”,近30,000名数字教育领域知名学者、探索者、实践者通过线上线下形式参与,见证了UNIPUS将教育与AI技术深度融合所带来的革命性教育创新,共话新知、瞭望前沿。
模型回复如下:
获得了想要的输出
五、验证微调后的模型能力
先准备一点测试数据集,这个可以从训练集里面随机取一些数据出来做测试集,或者自己去网上找一些新闻分好类来做测试集,同样的需要将数据集放在LLaMA-Factory的data目录下,再在dataset_info.json中添加数据集描述
"zh_test": {
"file_name": "zh.test.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output"
}
}
webui检查点路径选择刚刚训练的模型,点击下方的evaluate&predict 标签,数据集选择刚刚准备好的测试用数据集zh_test
点击开始进行预测,预测结果如下:
{
"predict_bleu-4": 44.38081697341513,
"predict_model_preparation_time": 0.0035,
"predict_rouge-1": 68.50715746421268,
"predict_rouge-2": 37.73006134969325,
"predict_rouge-l": 68.50715746421268,
"predict_runtime": 47.5357,
"predict_samples_per_second": 20.574,
"predict_steps_per_second": 10.287
}
使用chatgpt分析预测结果如下: