- VisualStudio2022插件的安装及使用-编程手把手系列文章
- pprof-在现网场景怎么用
- C#实现的下拉多选框,下拉多选树,多级节点
- 【学习笔记】基础数据结构:猫树
为了AIGC的学习,我做了一个基于Transformer Models模型完成GPT2的学生AIGC学习训练模型,指在训练模型中学习编程AI.
在编程之前需要准备一些文件:
首先,先win+R打开运行框,输入:PowerShell后 。
输入:
pip install -U huggingface_hub 。
下载完成后,指定我们的环境变量:
$env:HF_ENDPOINT = "https://hf-mirror.com" 。
然后下载模型:
huggingface-cli download --resume-download gpt2 --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2" 。
这边我的目录是我要下载的工程目录地址 。
然后下载数据量:
huggingface-cli download --repo-type dataset --resume-download wikitext --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2" 。
这边我的目录是我要下载的工程目录地址 。
所以两个地址记得更改成自己的工程目录下(建议放在创建一个名为gpt-2的文件夹) 。
在PowerShell中下载完这些后,可以开始我们的代码啦 。
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AdamW,
get_linear_schedule_with_warmup,
set_seed,
)
from torch.optim import AdamW
# 设置随机种子以确保结果可复现
set_seed(42)
class TextDataset(Dataset):
def __init__(self, tokenizer, texts, block_size=128):
self.tokenizer = tokenizer
self.examples = [
self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=block_size) for
text
in texts]
# 在tokenizer初始化后,确保unk_token已设置
print(f"Tokenizer's unk_token: {self.tokenizer.unk_token}, unk_token_id: {self.tokenizer.unk_token_id}")
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
item = self.examples[i]
# 替换所有不在vocab中的token为unk_token_id
for key in item.keys():
item[key] = torch.where(item[key] >= self.tokenizer.vocab_size, self.tokenizer.unk_token_id, item[key])
return item
def train(model, dataloader, optimizer, scheduler, de, tokenizer):
model.train()
for batch in dataloader:
input_ids = batch['input_ids'].to(de)
# 添加日志输出检查input_ids
if torch.any(input_ids >= model.config.vocab_size):
print("Warning: Some input IDs are outside the model's vocabulary.")
print(f"Max input ID: {input_ids.max()}, Vocabulary Size: {model.config.vocab_size}")
attention_mask = batch['attention_mask'].to(de)
labels = input_ids.clone()
labels[labels[:, :] == tokenizer.pad_token_id] = -100
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
def main():
local_model_path = "D:/Pythonxiangmu/PythonandAI/Transformer Models/gpt-2"
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
# 确保pad_token已经存在于tokenizer中,对于GPT-2,它通常自带pad_token
if tokenizer.pad_token is None:
special_tokens_dict = {'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens_dict)
model = AutoModelForCausalLM.from_pretrained(local_model_path, pad_token_id=tokenizer.pad_token_id)
else:
model = AutoModelForCausalLM.from_pretrained(local_model_path)
model.to(device)
train_texts = [
"The quick brown fox jumps over the lazy dog.",
"In the midst of chaos, there is also opportunity.",
"To be or not to be, that is the question.",
"Artificial intelligence will reshape our future.",
"Every day is a new opportunity to learn something.",
"Python programming enhances problem-solving skills.",
"The night sky sparkles with countless stars.",
"Music is the universal language of mankind.",
"Exploring the depths of the ocean reveals hidden wonders.",
"A healthy mind resides in a healthy body.",
"Sustainability is key for our planet's survival.",
"Laughter is the shortest distance between two people.",
"Virtual reality opens doors to immersive experiences.",
"The early morning sun brings hope and vitality.",
"Books are portals to different worlds and minds.",
"Innovation distinguishes between a leader and a follower.",
"Nature's beauty can be found in the simplest things.",
"Continuous learning fuels personal growth.",
"The internet connects the world like never before."
# 更多训练文本...
]
dataset = TextDataset(tokenizer, train_texts, block_size=128)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
optimizer = AdamW(model.parameters(), lr=5e-5)
total_steps = len(dataloader) * 5 # 假设训练5个epoch
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
for epoch in range(5): # 训练5个epoch
train(model, dataloader, optimizer, scheduler, device, tokenizer) # 使用正确的变量名dataloader并传递tokenizer
# 保存微调后的模型
model.save_pretrained("path/to/save/fine-tuned_model")
tokenizer.save_pretrained("path/to/save/fine-tuned_tokenizer")
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
main()
这个代码只训练了5个epoch,有一些实例文本,记得调成直接的路径后,运行即可啦.
如果有什么问题可以随时在评论区或者是发个人邮箱:linyuanda@linyuanda.com 。
最后此篇关于[Python急救站]基于TransformerModels模型完成GPT2的学生AIGC学习训练模型的文章就讲到这里了,如果你想了解更多关于[Python急救站]基于TransformerModels模型完成GPT2的学生AIGC学习训练模型的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
关闭。这个问题是opinion-based .它目前不接受答案。 想要改进这个问题? 更新问题,以便 editing this post 可以用事实和引用来回答它. 关闭 9 年前。 Improve
介绍篇 什么是MiniApis? MiniApis的特点和优势 MiniApis的应用场景 环境搭建 系统要求 安装MiniApis 配置开发环境 基础概念 MiniApis架构概述
我正在从“JavaScript 圣经”一书中学习 javascript,但我遇到了一些困难。我试图理解这段代码: function checkIt(evt) { evt = (evt) ? e
package com.fastone.www.javademo.stringintern; /** * * String.intern()是一个Native方法, * 它的作用是:如果字
您会推荐哪些资源来学习 AppleScript。我使用具有 Objective-C 背景的传统 C/C++。 我也在寻找有关如何更好地开发和从脚本编辑器获取更快文档的技巧。示例提示是“查找要编写脚本的
关闭。这个问题不满足Stack Overflow guidelines .它目前不接受答案。 想改善这个问题吗?更新问题,使其成为 on-topic对于堆栈溢出。 4年前关闭。 Improve thi
关闭。这个问题不满足Stack Overflow guidelines .它目前不接受答案。 想改善这个问题吗?更新问题,使其成为 on-topic对于堆栈溢出。 7年前关闭。 Improve thi
关闭。这个问题不符合 Stack Overflow guidelines 。它目前不接受答案。 想改善这个问题吗?更新问题,以便堆栈溢出为 on-topic。 6年前关闭。 Improve this
我是塞内加尔的阿里。我今年60岁(也许这是我真正的问题-笑脸!!!)。 我正在学习Flutter和Dart。今天,我想使用给定数据模型的列表(它的名称是Mortalite,请参见下面的代码)。 我尝试
关闭。这个问题是off-topic .它目前不接受答案。 想改进这个问题? Update the question所以它是on-topic对于堆栈溢出。 9年前关闭。 Improve this que
学习 Cappuccino 的最佳来源是什么?我从事“传统”网络开发,但我对这个新框架非常感兴趣。请注意,我对 Objective-C 毫无了解。 最佳答案 如上所述,该网站是一个好地方,但还有一些其
我正在学习如何使用 hashMap,有人可以检查我编写的这段代码并告诉我它是否正确吗?这个想法是有一个在公司工作的员工列表,我想从 hashMap 添加和删除员工。 public class Staf
我正在尝试将 jQuery 与 CoffeScript 一起使用。我按照博客中的说明操作,指示使用 $ -> 或 jQuery -> 而不是 .ready() 。我玩了一下代码,但我似乎无法理解我出错
还在学习,还有很多问题,所以这里有一些。我正在进行 javascript -> PHP 转换,并希望确保这些做法是正确的。是$dailyparams->$calories = $calories;一条
我目前正在学习 SQL,以便从我们的 Magento 数据库制作一个简单的 RFM 报告,我目前可以通过导出两个查询并将它们粘贴到 Excel 模板中来完成此操作,我想摆脱 Excel 模板。 我认为
我知道我很可能会因为这个问题而受到抨击,但没有人问,我求助于你。这是否是一个正确的 javascript > php 转换 - 在我开始不良做法之前,我想知道这是否是解决此问题的正确方法。 JavaS
除了 Ruby-Doc 之外,哪些来源最适合获取一些示例和教程,尤其是关于 Ruby 中的 Tk/Tile?我发现自己更正常了 http://www.tutorialspoint.com/ruby/r
我只在第一次收到警告。这正常吗? >>> cv=LassoCV(cv=10).fit(x,y) C:\Python27\lib\site-packages\scikit_learn-0.14.1-py
按照目前的情况,这个问题不适合我们的问答形式。我们希望答案得到事实、引用或专业知识的支持,但这个问题可能会引发辩论、争论、投票或扩展讨论。如果您觉得这个问题可以改进并可能重新打开,visit the
As it currently stands, this question is not a good fit for our Q&A format. We expect answers to be
我是一名优秀的程序员,十分优秀!