gpt4 book ai didi

matlab - 多标签多类数据的平均精度

转载 作者:太空宇宙 更新时间:2023-11-03 20:32:29 25 4
gpt4 key购买 nike

我正在尝试编写代码来计算多标签数据的平均精度 ( MAP )。为了更直观的理解请看下面

enter image description here

我已经在 MATLAB 中编写了 MAP 计算的代码,但速度很慢。本质上它是,因为要为每个r 值计算变量Lrx

我想让我的代码更快。

function [map] = map_at_R(sim_x,L_tr,L_te)

%sim_x(i,j) denote the sim bewteen query j and database i
tn = size(sim_x,2);
APx = zeros(tn,1);
R = 100;

for i = 1 : tn
Px = zeros(R,1);
deltax = zeros(R,1);
label = L_te(i,:);
[~,inxx] = sort(sim_x(:,i),'descend');

% compute Lx - the denominator in the map calculation
% Lx = 1 if the retrieved item has the same label with the query or
% shares atleast one label else Lx = 0
search_set = L_tr(inxx(1:R),:);

for r = 1 : R
%% FAST COMPUTATION
Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);

%% SLOW COMPUTATION
% Lrx = 0;
% for j=1:r
% if sum(label*(search_set(j,:)).')>0
% Lrx = Lrx+1;
% end
% end

if sum(label*(search_set(r,:)).')>0
deltax(r) = 1;
end

Px(r) = Lrx/r;
end
Lx = sum(deltax);
if Lx ~=0
APx(i) = sum(Px.*deltax)/Lx;
end
end
map = mean(APx);

代码的输入是这样的:

% sim_x = similarity score matrix or distance matrix
sim_x = gallery_data_size X probe_data_size

% L_tr = labels of the gallery set
L_tr = gallery_data_size X c

% L_te = labels of the probe set
L_te = probe_data_size X c

% where c is the number of classes
% please note that the data is multi-label

是否可以使代码更快?我自己无法弄清楚。

最佳答案

使用 delta 函数 APx(i) = sum(Px.*deltax)/Lx 您将丢弃一部分 r = 1:R 迭代。既然 delta 可以在循环之前定义,为什么不只遍历 r where deltax(r) == 1

% r_range is equivalent to find(deltax(r) == 1);
%Edit 1/4 %Previously :: r_range = find(sum(label*(search_set(1:R,:)).')>0);
% Multiply each row by label
mult = bsxfun(@times,(search_set(1:R,:)),label);
% Sum each row
r_range = find(sum(mult,2)>0);
% r_range @ i should equal find(deltax) @ i

Px = zeros(numel(r_range,1);

for r = r_range
Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
Px(r == r_range) = Lrx/r;
end

Lx = numel(r_range);
if Lx ~=0
APx(i) = sum(Px)/Lx;
end

关于matlab - 多标签多类数据的平均精度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48003041/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com