Coggle 30 Days of ML(23年7月)任务七:训练TextCNN模型

news/2024/7/5 21:00:21

Coggle 30 Days of ML(23年7月)任务七:训练TextCNN模型

任务七:使用Word2Vec词向量,搭建TextCNN模型进行训练和预测

  • 说明:在这个任务中,你将使用Word2Vec词向量,搭建TextCNN模型进行文本分类的训练和预测,通过卷积神经网络来进行文本分类。
  • 实践步骤:
    1. 准备Word2Vec词向量模型和相应的训练数据集。
    2. 构建TextCNN模型,包括卷积层、池化层、全连接层等。
    3. 将Word2Vec词向量应用到模型中,作为词特征的输入。
    4. 使用训练数据集对TextCNN模型进行训练。
    5. 使用训练好的TextCNN模型对测试数据集进行预测

导入训练好的Word2Vec模型

由于上一部分我们已经训练好了我们的模型,所以这一部分我们直接导入即可

# 准备Word2Vec词向量模型和训练数据集
word2vec_model = Word2Vec.load("word2vec.model")

在数据分析的时候,我们已经发现,词语的数量是不一的,所以首先我们先对数据进行处理,将文本序列转化为词向量表示,并且填充为长度为200

# 获取Word2Vec词向量的维度
embedding_dim = word2vec_model.vector_size

# 转换训练数据集的文本序列为词向量表示,并进行填充
train_sequences = []
for text in train_data:
    sequence = [word2vec_model.wv[word] for word in text if word in word2vec_model.wv]
    padded_sequence = pad_sequences([sequence], maxlen=max_length, padding='post', truncating='post')[0]
    train_sequences.append(padded_sequence)

# 转换测试数据集的文本序列为词向量表示,并进行填充
test_sequences = []
for text in test_data:
    sequence = [word2vec_model.wv[word] for word in text if word in word2vec_model.wv]
    padded_sequence = pad_sequences([sequence], maxlen=max_length, padding='post', truncating='post')[0]
    test_sequences.append(padded_sequence)

构建TextCNN模型

在这里插入图片描述

接下来我们就开始构建一下TextCNN模型,包括卷积层、池化层、全连接层等,这样我们就初步得到一个非常简单的模型了

# 构建TextCNN模型
model = tf.keras.Sequential()
model.add(layers.Conv1D(128, 5, activation='relu', input_shape=(max_length, embedding_dim)))
model.add(layers.GlobalMaxPooling1D())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(num_classes, activation='softmax'))

训练模型

接下来我们就可以开始训练我们的模型了,这里使用了SGD优化器进行操作,在训练之前,我们还需要把训练数据集的标签转换为one-hot编码

# 设置优化器和学习率
optimizer = optimizers.SGD(learning_rate=0.1)  # 使用SGD优化器,并设置学习率为0.1

# 编译模型
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])


# 转换训练数据集的标签为one-hot编码
train_labels = tf.keras.utils.to_categorical(train_labels)

# 训练模型
model.fit(np.array(train_sequences), train_labels, epochs=5, batch_size=32)
Epoch 1/5
438/438 [==============================] - 5s 6ms/step - loss: 0.4385 - accuracy: 0.8454
Epoch 2/5
438/438 [==============================] - 2s 5ms/step - loss: 0.4309 - accuracy: 0.8454
Epoch 3/5
438/438 [==============================] - 3s 6ms/step - loss: 0.4308 - accuracy: 0.8454
Epoch 4/5
438/438 [==============================] - 2s 6ms/step - loss: 0.4307 - accuracy: 0.8454
Epoch 5/5
438/438 [==============================] - 2s 6ms/step - loss: 0.4309 - accuracy: 0.8454

可能是模型太简单了,所以可以看到,通过训练以后,准确率也没有较大的提升,还可以继续改进

预测与提交

最后使用训练好的TextCNN模型对测试数据集进行预测,得到csv数据以后进行提交

# 预测测试数据集的分类结果
predictions = model.predict(np.array(test_sequences))
predicted_labels = predictions.argmax(axis=1)

# 读取提交样例文件
submit = pd.read_csv('ChatGPT/sample_submit.csv')
submit = submit.sort_values(by='name')

# 将预测结果赋值给提交文件的label列
submit['label'] = predicted_labels

# 保存提交文件
submit.to_csv('ChatGPT/textcnn.csv', index=None)

总结

这个TextCNN模型太过于简单,所以以至于可能没有学习到很多的数据,接下来可以进行调参和设置合理的模型结构,以期得到更好的结果,再接再厉,加油!!!


http://www.niftyadmin.cn/n/4394722.html

相关文章

鸿蒙生态电饭煲,新日XC3成为华为鸿蒙生态的首辆电动车

导语:5月18日,华为在上海召开“HarmonyOS Connect伙伴峰会”,与各场景合作伙伴共同探讨鸿蒙系统全新生态带来的商业价值和未来趋势,前期已经加入鸿蒙生态的华为各场景合作伙伴,携创新产品、解决方案在峰会亮相。早在3月的天津发布会上,新日就已经宣布了与华为的深度…

html中的时钟如何自动,HTML+CSS入门 如何实现模拟时钟

本篇教程介绍了HTMLCSS入门 如何实现模拟时钟&#xff0c;希望阅读本篇文章以后大家有所收获&#xff0c;帮助大家HTMLCSS入门。<1 html>2 3 4 5 模拟时钟6 7 body {8 margin: 0;9 padding: 0;10 }1112 #bl…

html图片爆炸效果,HTML5 粒子爆炸效果

JavaScript语言&#xff1a;JaveScriptBabelCoffeeScript确定var size 55; // 25 is the sweet spot for me, how high can you get?var rot 0;var interval 1000;var explosions 0;var maxExplosions 100; // the total number of fireworks youll get. Kindof like at …

html选择文件框选择wood文件,CSS结构性伪类选择器—nth-of-type实现自定义导航菜单案例解析(代码实例)...

本文目标&#xff1a;1、掌握CSS中结构性伪类选择器—nth-of-type的用法问题&#xff1a;1、实现以下自定义导航菜单&#xff0c;且使用纯DIVCSS&#xff0c;必须使用结构性伪类选择器—nth-of-type附加说明&#xff1a;1、导航宽800px&#xff0c;高90px&#xff0c;居中显示2…

Redis学习手册(事务)(转)

2019独角兽企业重金招聘Python工程师标准>>> 一、概述&#xff1a; 和众多其它数据库一样&#xff0c;Redis作为NoSQL数据库也同样提供了事务机制。在Redis中&#xff0c;MULTI/EXEC/DISCARD/WATCH这四个命令是我们实现事务的基石。相信对有关系型数据库开发经…

ant压缩html,ant+yuicompressor压缩js/css

antyuicompressor在前端部署和服务器发布中用处极大。我在项目中用到的是antyuicompressor&#xff0c;ant可以压缩CSS&#xff0c;压缩javascript,并可以传输信息和远程服务器发布&#xff01;目前yuicompressor最新版本为yuicompressor-2.4.8.jar&#xff0c;请使用该版本的j…

rmd文件怎么转换html文件,RMD文件扩展名 - 什么是.rmd以及如何打开? - ReviverSoft...

你在这里因为你有&#xff0c;有一个文件扩展名结尾的​​文件 .rmd. 文件与文件扩展名 .rmd 只能通过特定的应用程序推出。这有可能是 .rmd 文件是数据文件&#xff0c;而不是文件或媒体&#xff0c;这意味着他们并不是在所有观看。什么是一&nbsp.rmd&nbsp文件&#x…

SnapKit自动布局(二)

也许你在写OC的时候已经用过了Masonry这个第三方库来写自动布局&#xff0c;今天我们来说说Swift版本的Masonry第三方库SnapKit SnapKit 今天我们来做一个稍稍复杂的东西。 snp_updateConstraints 效果图如下。 Show Your Code var button: UIButton! var scacle:CGFloat 1.0o…