- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章pyTorch深度学习softmax实现解析由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
num_inputs = 2 #feature numbernum_examples = 1000 #训练样本个数true_w = torch.tensor([[2],[-3.4]]) #真实的权重值true_b = torch.tensor(4.2) #真实的biassamples = torch.normal(0,1,(num_examples,num_inputs))noise = torch.normal(0,0.01,(num_examples,1))labels = samples.matmul(true_w) + true_b + noise
class LinearNet(nn.Module): def __init__(self,in_features): super().__init__() self.fc = nn.Linear(in_features=2,out_features=1) def forward(self,t): t = self.fc(t) return t
import torch.utils.data as Datadataset = Data.TensorDataset(samples,labels)#类似于zip,把两个张量打包data_loader = Data.DataLoader(dataset,batch_size=100,shuffle=True)
network = LinearNet(2)optimizer = optim.SGD(network.paramters(),lr=0.05)
for epoch in range(10): total_loss = 0 for data,label in data_loader: predict = network(data) loss = F.mse_loss(predict,label) total_loss += loss.item() optimizer.zero_grad() loss.backward() optimizer.step() print( 'epoch',epoch, 'loss',total_loss, 'weight',network.weight, 'bias',network.bias )
。
sotfmax主要用于分类任务。regression最终得到的是一个scalar,根据input中的feature线性相加得到一个output。分类任务的结果是一个类别,是离散的。 假设现在有一批图片是2 * 2大小的灰度图片,这样图片中的每隔二像素用一个标量表示就行了。这批图片一种是三类小动物,第一类是小狗,第二类是小猫,第三类是小兔子。 每张图片总共4个像素点,我们可以看作是4个feature,假设这三类小动物的图片线性可分,每一类对应一组weight和一个bias.
可以根据输出值较大的来决定哪一类,可这样有个问题,首先输出值没有明确的意义,且可能是实数范围。其次,不好衡量输出值与真实值之间的差距。所以采用softmax操作,将三个输出值转化成概率值,这样输出结果满足概率分布。label采用one-hot编码,相当于对应类别的概率是1,这样就可以用cross_entropy来计算loss.
本次学习softmax模型采用torchvision.datasets中的Fashion-MNIST.
import torchvisionimport torchvision.transforms as transformstrain_set = torchvision.datasets.FashionMNIST( root='./data', train=True, download=True, transform=transforms.ToTensor())
transforms.ToTensor()将尺寸为(H x W x C)且数据位于(0,255)的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为C x H x W且数据类型为torch.float32且位于(0.0,1.0)的Tensor 。
len(train_set),len(test_set)> (60000,10000)
展示一下数据集中的图片 。
import matplotlib.pyplot as pltplt.figure(figsize=(10,10))for i,(image,lable) in enumerate(train_set,start=1): plt.subplot(1,10,i) plt.imshow(image.squeeze()) plt.title(train_set.classes[lable]) plt.axis('off') if i == 10: breakplt.show()
train_loader = torch.utils.data.DataLoader(train_set,batch_size=100,shuffle=True,num_workers=4)test_loader = torch.utils.data.DataLoader(test_set,batch_size=100,shuffle=False,num_workers=1)
def net(samples,w,b): samples = samples.flatten(start_dim=1) #将c,h,w三个轴展成一个feature轴,长度为28 * 28 samples = torch.exp(samples)#全体元素取以e为底的指数 partial_sum = samples.sum(dim=1,keepdim=True) samples = samples / partial_sum #归一化,得概率,这里还应用了广播机制 return samples.matmul(w) + b
i表示label对应的种类,pi为真实种类的预测概率,log是以e为底的对数 这里gather函数的作用,就是在predict上取到对应label的概率值,注意负号不能丢,pytorch中的cross_entropy对输入先进行一次softmax操作,以保证输入都是正的.
def net(samples,w,b): samples = samples.flatten(start_dim=1) #将c,h,w三个轴展成一个feature轴,长度为28 * 28 samples = torch.exp(samples)#全体元素取以e为底的指数 partial_sum = samples.sum(dim=1,keepdim=True) samples = samples / partial_sum #归一化,得概率,这里还应用了广播机制 return samples.matmul(w) + b
import torchimport torchvisionimport torch.nn as nnimport torch.nn.functional as Fimport torch.utils.data as Dataimport torchvision.transforms as transformsimport torch.optim as optimimport torch.nn.init as initclass SoftmaxNet(nn.Module): def __init__(self,in_features,out_features): super().__init__() self.fc = nn.Linear(in_features=in_features,out_features=out_features) def forward(self,t): t = t.flatten(start_dim=1) t = self.fc(t) return ttrain_set = torchvision.datasets.FashionMNIST( root='E:\project\python\jupyterbook\data', train=True, download=True, transform=transforms.ToTensor())test_set = torchvision.datasets.FashionMNIST( root='E:\project\python\jupyterbook\data', train=False, download=True, transform=transforms.ToTensor())train_loader = Data.DataLoader( train_set, batch_size=100, shuffle=True, #num_workers=2)test_loader = Data.DataLoader( test_set, batch_size=100, shuffle=False, #num_workers=2)@torch.no_grad()def get_correct_nums(predict,labels): return predict.argmax(dim=1).eq(labels).sum().item()@torch.no_grad()def evaluate(test_loader,net,total_num): correct = 0 for image,label in test_loader: predict = net(image) correct += get_correct_nums(predict,label) pass return correct / total_numnetwork = SoftmaxNet()optimizer = optim.SGD(network.parameters(),lr=0.05)for epoch in range(10): total_loss = 0 total_correct = 0 for image,label in train_loader: predict = network(image) loss = F.cross_entropy(predict,label) total_loss += loss.item() total_correct += get_correct_nums(predict,label) optimizer.zero_grad() loss.backward() optimizer.step() pass print( 'epoch',epoch, 'loss',total_loss, 'train_acc',total_correct / len(train_set), 'test_acc',evaluate(test_loader,network,len(test_set)) )
以上就是pytorch深度学习softmax实现解析的详细内容,更多关于pytorch深度学习的资料请关注我其它相关文章! 。
原文链接:https://blog.csdn.net/qq_43152622/article/details/116850268 。
最后此篇关于pyTorch深度学习softmax实现解析的文章就讲到这里了,如果你想了解更多关于pyTorch深度学习softmax实现解析的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我一直在使用 AJAX 从我正在创建的网络服务中解析 JSON 数组时遇到问题。我的前端是一个简单的 ajax 和 jquery 组合,用于显示从我正在创建的网络服务返回的结果。 尽管知道我的数据库查
很难说出这里要问什么。这个问题模棱两可、含糊不清、不完整、过于宽泛或夸夸其谈,无法以目前的形式得到合理的回答。如需帮助澄清此问题以便重新打开,visit the help center . 关闭 1
我在尝试运行 Android 应用程序时遇到问题并收到以下错误 java.lang.NoClassDefFoundError: com.parse.Parse 当我尝试运行该应用时。 最佳答案 在这
有什么办法可以防止etree在解析HTML内容时解析HTML实体吗? html = etree.HTML('&') html.find('.//body').text 这给了我 '&' 但我想
我有一个有点疯狂的例子,但对于那些 JavaScript 函数作用域专家来说,它看起来是一个很好的练习: (function (global) { // our module number one
关闭。此题需要details or clarity 。目前不接受答案。 想要改进这个问题吗?通过 editing this post 添加详细信息并澄清问题. 已关闭 8 年前。 Improve th
我需要编写一个脚本来获取链接并解析链接页面的 HTML 以提取标题和其他一些数据,例如可能是简短的描述,就像您链接到 Facebook 上的内容一样。 当用户向站点添加链接时将调用它,因此在客户端启动
在 VS Code 中本地开发时,包解析为 C:/Users//AppData/Local/Microsoft/TypeScript/3.5/node_modules/@types//index而不是
我在将 json 从 php 解析为 javascript 时遇到问题 这是我的示例代码: //function MethodAjax = function (wsFile, param) {
我在将 json 从 php 解析为 javascript 时遇到问题 这是我的示例代码: //function MethodAjax = function (wsFile, param) {
我被赋予了将一种语言“翻译”成另一种语言的工作。对于使用正则表达式的简单逐行方法来说,源代码过于灵活(复杂)。我在哪里可以了解更多关于词法分析和解析器的信息? 最佳答案 如果你想对这个主题产生“情绪化
您好,我在解析此文本时遇到问题 { { { {[system1];1;1;0.612509325}; {[system2];1;
我正在为 adobe after effects 在 extendscript 中编写一些代码,最终变成了 javascript。 我有一个数组,我想只搜索单词“assemble”并返回整个 jc3_
我有这段代码: $(document).ready(function() { // }); 问题:FB_RequireFeatures block 外部的代码先于其内部的代码执行。因此 who
背景: netcore项目中有些服务是在通过中间件来通信的,比如orleans组件。它里面服务和客户端会指定网关和端口,我们只需要开放客户端给外界,服务端关闭端口。相当于去掉host,这样省掉了些
1.首先贴上我试验成功的代码 复制代码 代码如下: protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec)
什么是 XML? XML 指可扩展标记语言(eXtensible Markup Language),标准通用标记语言的子集,是一种用于标记电子文件使其具有结构性的标记语言。 你可以通过本站学习 X
【PHP代码】 复制代码 代码如下: $stmt = mssql_init('P__Global_Test', $conn) or die("initialize sto
在SQL查询分析器执行以下代码就可以了。 复制代码代码如下: declare @t varchar(255),@c varchar(255) declare table_cursor curs
前言 最近练习了一些前端算法题,现在做个总结,以下题目都是个人写法,并不是标准答案,如有错误欢迎指出,有对某道题有新的想法的友友也可以在评论区发表想法,互相学习🤭 题目 题目一: 二维数组中的
我是一名优秀的程序员,十分优秀!