240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

240707_昇思学习打卡-Day19-基于MindSpore通过GPT实现情感分类

今天基于GPT实现一个情感分类的功能,假设已经安装好了MindSpore环境。

# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com

导包导包

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# 加载IMDb数据集
imdb_ds = load_dataset('imdb', split=['train', 'test'])
# 获取训练集
imdb_train = imdb_ds['train']
# 获取测试集
imdb_test = imdb_ds['test']

# 调用get_dataset_size方法来获取训练集的大小
imdb_train.get_dataset_size()
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    """
    处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。
    
    参数:
    - dataset: 需要处理的数据集,包含文本和标签。
    - tokenizer: 用于将文本转换为token序列的tokenizer。
    - max_seq_len: 最大序列长度,超过该长度的序列将被截断。
    - batch_size: 打包数据的批次大小。
    - shuffle: 是否在处理数据集前对其进行洗牌。
    
    返回:
    - 经过tokenization和batch处理后的数据集。
    """
    # 判断是否在Ascend设备上运行
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    
    def tokenize(text):
        """
        对文本进行tokenization,并返回input_ids和attention_mask。
        
        参数:
        - text: 需要被tokenize的文本。
        
        返回:
        - tokenize后的input_ids和attention_mask。
        """
        # 根据设备类型选择合适的tokenization方法
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    # 如果需要洗牌,对数据集进行洗牌操作
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # 对数据集进行tokenization操作
    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    # 将标签转换为int32类型
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # 根据设备类型选择合适的批次处理方法
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    """
    处理数据集,使用tokenizer对文本进行编码,并根据指定的batch大小和序列长度组织数据。
    
    参数:
    - dataset: 需要处理的数据集,包含文本和标签。
    - tokenizer: 用于将文本转换为token序列的tokenizer。
    - max_seq_len: 最大序列长度,超过该长度的序列将被截断。
    - batch_size: 打包数据的批次大小。
    - shuffle: 是否在处理数据集前对其进行洗牌。
    
    返回:
    - 经过tokenization和batch处理后的数据集。
    """
    # 判断是否在Ascend设备上运行
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    
    def tokenize(text):
        """
        对文本进行tokenization,并返回input_ids和attention_mask。
        
        参数:
        - text: 需要被tokenize的文本。
        
        返回:
        - tokenize后的input_ids和attention_mask。
        """
        # 根据设备类型选择合适的tokenization方法
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    # 如果需要洗牌,对数据集进行洗牌操作
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # 对数据集进行tokenization操作
    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    # 将标签转换为int32类型
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # 根据设备类型选择合适的批次处理方法
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset

# 导入来自mindnlp库transformers模块中的GPTTokenizer类
from mindnlp.transformers import GPTTokenizer

# 初始化GPT分词器,使用预训练的'openai-gpt'模型
# 分词器
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# 定义一个特殊token字典,包括开始、结束和填充token
special_tokens_dict = {
    "bos_token": "<bos>",  # 开始符号
    "eos_token": "<eos>",  # 结束符号
    "pad_token": "<pad>",  # 填充符号
}

# 向分词器中添加特殊token,并返回添加的token数量
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

# 将训练数据集imdb_train分割成训练集和验证集
# 按照70%训练集和30%验证集的比例进行划分
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 调用create_tuple_iterator方法创建一个迭代器,并通过next函数获取迭代器的第一个元素
# 这里的目的是为了展示或测试迭代器是否能正常生成数据
# 对于参数和返回值的详细说明,需要查看create_tuple_iterator方法的文档或实现
next(dataset_train.create_tuple_iterator())

# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)

# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

# 初始化准确度指标来评估模型性能
metric = Accuracy()

# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)
# 导入GPT序列分类模型与Adam优化器
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# 初始化GPT模型用于序列分类任务,设置标签数量为2(二分类任务)
# 设置模型配置并定义训练参数
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
# 配置模型的填充标记ID以匹配分词器设置
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# 调整令牌嵌入层大小以适应新增词汇量
model.resize_token_embeddings(model.config.vocab_size + 3)

# 使用2e-5的学习率初始化Adam优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

# 初始化准确度指标来评估模型性能
metric = Accuracy()

# 定义回调函数以在训练过程中保存检查点
# 定义保存检查点的回调函数
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
# 初始化最佳模型回调函数以保存表现最优的模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

# 初始化训练器,包括模型、训练数据集、评估数据集、性能指标、优化器以及回调函数
trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)

# 执行模型训练
trainer.run(tgt_columns="labels")

# 初始化Evaluator对象,用于评估模型性能
# 参数说明:
# network: 待评估的模型
# eval_dataset: 用于评估的测试数据集
# metrics: 评估指标
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)

# 执行模型评估,指定目标列作为评估标签
# 该步骤将计算模型在测试数据集上的指定评估指标
evaluator.run(tgt_columns="labels")

打卡图片:

image-20240707192022631

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/782785.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Git 查看、新建、删除、切换分支

Git 是一个版本控制系统&#xff0c;软件开发者用它来跟踪应用程序的变化并进行项目协作。 分支的诞生便于开发人员在彼此独立的环境中进行开发工作。主分支&#xff08;通常是 main 或 master&#xff09;可以保持稳定&#xff0c;而新的功能或修复可以在单独的分支中进行开发…

STM32智能无人机控制系统教程

目录 引言环境准备智能无人机控制系统基础代码实现&#xff1a;实现智能无人机控制系统 4.1 数据采集模块 4.2 数据处理与飞行控制 4.3 通信与导航系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;无人机应用与优化问题解决方案与优化收尾与总结 1. 引言 智能无人机控…

AI工具杂谈

AI是在帮助开发者还是取代他们&#xff1f; 在软件开发领域&#xff0c;生成式人工智能&#xff08;AIGC&#xff09;正在改变开发者的工作方式。无论是代码生成、错误检测还是自动化测试&#xff0c;AI工具正在成为开发者的得力助手。然而&#xff0c;这也引发了对开发者职业…

哪个牌子的护眼大路灯质量好呢?性价比高的五款护眼大路灯分享

护眼大路灯可以说是是每个有娃家庭必不可少的照明神器。但面对市场上琳琅满目的护眼落地灯品牌和型号&#xff0c;很多消费者在选购时都会犯难&#xff1a;究竟哪个牌子的护眼大路灯质量好呢&#xff1f;哪个性价比高呢&#xff1f;本文将根据市场反馈以及性价比等各方面&#…

Table 表格--分页序号自增

代码&#xff1a; import { Space, Table, Tag } from antd; import type { ColumnsType } from antd/es/table; import React, { useState } from react;interface DataType {key: string;name: string;age: number;address: string;tags: string[]; }const data: DataType[]…

6K star! 部署本地运行LLM的AI助手,零基础入门到精通超详细

AI套壳千千万万&#xff0c;你最喜欢哪一款&#xff1f;现在各种ChatGPT替代品层出不穷&#xff0c;但是大部分都是使用OpenAI的API&#xff0c;也就说离不开网络。 今天我们推荐的开源项目它就是要帮你100%在本地运行大模型&#xff0c;进而构建一个属于自己的ChatGPT&#x…

使用 Docker 部署一个文档管理系统,让宝贵文档不在丢失!

大家好,我是CodeQi! 一位热衷于技术分享的码仔。 BookStack 是一个开源的文档管理系统,非常适合用来创建和组织文档。 通过 Docker,我们可以轻松地将 BookStack 部署到本地或服务器上。 本文将详细介绍如何使用 Docker 搭建 BookStack。 项目预览 登录页面

element-plus 的form表单组件之el-radio(单选按钮组件)

单选按钮组件适用于同一组类型的选项只能互斥选择的场景&#xff0c;就是支持单选。单选组件包含以下3个组件 组件名作用el-radio-group单选组组件&#xff0c;子元素可以是el-radio或el-radio-button&#xff0c;v-mode绑定单选组的响应式属性el-radio单选组件&#xff0c;la…

如何确保工业展厅设计既专业又吸引?三原则详解!

工业是民族发展的基石&#xff0c;它为我们带来了无数的便利和进步&#xff0c;而为了让更多人了解这个至关重要的产业&#xff0c;以及其背后的技术和产品&#xff0c;许多工业性质的企业都致力于通过互动投影、虚拟现实、全息投影等多媒体技术&#xff0c;来打造独具特色的工…

起底:Three.js和Cesium.js,二者异同点,好比全科和专科.

Three.js和Cesium.js是两个常用的webGL引擎&#xff0c;很多小伙伴容易把它们搞混淆了&#xff0c;今天威斯数据来详细介绍一下&#xff0c;他们的起源、不同点和共同点&#xff0c;阅读后你就发现二者就像全科医院和专科医院的关系&#xff0c;很好识别。 一、二者的起源 Th…

LiveNVR监控流媒体Onvif/RTSP用户手册-录像回看:录像通道、查看录像、设备录像、云端录像查、时间轴视图、录像分享

LiveNVR监控流媒体Onvif/RTSP用户手册-录像回看:录像通道、查看录像、设备录像、云端录像查、时间轴视图、录像分享 1、录像回看1.1、查看录像1.1.1、时间轴视图1.1.2、列表视图 2、如何分享时间轴录像回看&#xff1f;3、iframe集成示例4、RTSP/HLS/FLV/RTMP拉流Onvif流媒体服…

RabbitMQ(集群相关部署)

RabbitMQ 集群部署 环境准备&#xff1a;阿里云centos8 服务器&#xff0c;3台服务器&#xff0c;分别进行安装&#xff1b; 下载Erlang Erlang和RabbitMQ版本对照&#xff1a;https://www.rabbitmq.com/which-erlang.html 创建yum库配置文件 vim /etc/yum.repos.d/rabbi…

Soong 构建系统

背景 Soong 构建系统在Android 7.0开始引入&#xff0c;目的是取代Make。它利用Kati GNU Make 和Ninja构建系统组件来构建Android Soong是用Go语言写的&#xff0c;go环境在prebuilts/go环境下&#xff0c;Soong在编译时&#xff0c;解析bp文件&#xff0c;转化成Ninja文件&am…

互联网留给网站建设的,也就一个门缝了,抓紧往高端进发吧。

高端定制网站具有以下价值&#xff1a; 独特性&#xff1a;高端定制网站能够根据企业的品牌形象和定位进行设计&#xff0c;呈现独特的风格和用户体验。这有助于提升企业的品牌形象和差异化竞争力&#xff0c;使企业在竞争激烈的市场中脱颖而出。用户体验&#xff1a;高端定制…

vue-使用Worker实现多标签页共享一个WebSocket

文章目录 前言一、SharedWorker 是什么SharedWorker 是什么SharedWorker 的使用方式SharedWorker 标识与独占 二、Demo使用三、使用SharedWorker实现WebSocket共享 前言 最近有一个需求&#xff0c;需要实现用户系统消息时时提醒功能。第一时间就是想用WebSocket进行长连接。但…

14-47 剑和诗人21 - 2024年如何打造AI创业公司

​​​​​ 2024 年&#xff0c;随着人工智能继续快速发展并融入几乎所有行业&#xff0c;创建一家人工智能初创公司将带来巨大的机遇。然而&#xff0c;在吸引资金、招聘人才、开发专有技术以及将产品推向市场方面&#xff0c;人工智能初创公司也面临着相当大的挑战。 让我来…

下一代 CLI 工具,使用Go语言用于构建令人惊叹的网络应用程序

大家好&#xff0c;今天给大家分享一个创新的命令行工具Gowebly CLI&#xff0c;它专注于使用Go语言来快速构建现代Web应用程序。 Gowebly CLI 是一款免费开源软件&#xff0c;有助于在后端使用 Go、在前端使用 htmx 和 hyperscript 以及最流行的 CSS 框架轻松构建令人惊叹的 W…

Maven Nexus3 私服搭建、配置、项目发布指南

maven nexus私服搭建 访问nexus3官方镜像库,选择需要的版本下载:Docker Nexus docker pull sonatype/nexus3:3.49.0 创建数据目录并赋权 sudo mkdir /nexus-data && sudo chown -R 200 /nexus-data 运行(数据目录选择硬盘大的卷进行挂载) docker run -d -p 808…

AI集成工具平台一站式体验,零门槛使用国内外主流大模型

目录 0 写在前面1 AI艺术大师1.1 绘画制图1.2 智能作曲 2 AI科研助理2.1 学术搜索2.2 自动代码 3 AI智能对话3.1 聊天机器人3.2 模型竞技场 4 特别福利 0 写在前面 人工智能大模型浪潮滚滚&#xff0c;正推动着千行百业的数智化进程。随着技术演进&#xff0c;2024年被视为是大…

数据库开发:mysql基础一

文章目录 数据库开发Day15&#xff1a;MySQL基础&#xff08;一&#xff09;一、MySQL介绍与安装【1】MySQL介绍&#xff08;5&#xff09;启动MySQL服务&#xff08;6&#xff09;修改root登陆密码 二、SQL简介三、数据库操作四、数据表操作4.1、数据库数据类型4.2、创建数据表…