一、Bert类模型分类方法
在文本分类任务中,BERT模型会接收到输入文本,并通过其双向编码器捕捉文本中的关键特征。这些特征反映了文本的语义内容和对分类任务的重要性。具体来说:
[CLS]标记:在输入序列的开始添加特殊的[CLS]标记,模型用这个标记来捕捉整个序列的整体语义信息。
多头自注意力机制:BERT模型的注意力机制允许模型在处理一个单词时同时考虑其在序列中的前后上下文。
前馈神经网络和残差连接:这些结构增强了模型的表示能力,使得模型能够学习到更加复杂的模式。
层归一化:通过对输出进行归一化,使得模型层的输出能够在不同的任务中稳定地更新。
在分类时,通常使用[CLS]标记的输出作为整个输入文本的代表向量,然后将其输入到分类层进行最终的分类预测。使用[CLS]输入到线性分类层中,将结果映射到所有分类类别上,对应位置的数值经过softmax函数后可以得到对应类别的概率(置信度)
二、生成式大模型分类方法
利用大模型进行分类通常有微调和prompt工程的方法。
1.prompt工程
针对分类任务编写prompt,利用大模型的自身能力对文本进行分类。在分类过程中可以采用选择、填空等方式要求模型生成对应的类别。如果需要生成对应的分类概率,可以在prompt中要求模型生成对应的类别概率。
例如:
你是一个分类模型。以下是各个类别的定义:
类别1:
类别2:
...
请根据类别的定义给出输入文本对应的类别的得分,得分为1-100,所有类别得分相加等于100。
以下是一些示例:【few-shot】
请对以下文本进行判断:
【待判断文本】
2.指令微调
采用few-shot的方式可以应对一些简单任务,但对于一些复杂任务则需要指令微调的方式。
为了使得可以使用token生成概率作为分类概率,在指令设置的时候要求模型直接输出对应的分类编号。
例如 【是】和【否】或者在prompt中对分类进行编号,使用英文字母代替分类类别:A:类别1;B:类别2...,在输出时要求模型输出分类的标识符和对应的类别名称。经过微调后,模型输出的第一个字符就被限制在【是】和【否】之间或者【A, B, C...】之中,此时获取模型输出第一个token对应的概率既可以代替分类的概率。
实践代码:
# transformers
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to('npu')
outputs = model.generate(
model_inputs.input_ids,
do_sample=False, # 设置为False,严格按照概率输出
return_dict_in_generate=True, # 按照词典的格式返回结果
output_scores=True, # 返回token对应的分数
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, outputs.sequences)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# 获得每个token对应的概率
probs = torch.stack(outputs.scores, dim=1).softmax(-1)
token_prob = probs[0][0][generated_ids[0][0]].item()
# vllm
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
logprobs=1, # 返回token概率
stop=STOP_SEQUENCES,
)
results_generator = self.model.generate(
inputs=inputs,
sampling_params=sampling_params,
request_id=self.task_id + str(self.index),
)
token_ids = output.outputs[0].token_ids[-1]
text_output = output.outputs[0].logprobs[-1][token_ids].decoded_token
logprobs = math.exp(output.outputs[0].logprobs[-1][token_ids].logprob)