- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
对不起,我不知道,我不知道在哪里可以找到解决方案。我正在使用两个网络来构造两个嵌入,我有二进制目标来指示 embeddingA 和 embeddingB 是否“匹配”(1 或 -1)。像这样的数据集:
embA0 embB0 1.0
embA1 embB1 -1.0
embA2 embB2 1.0
...
希望利用余弦相似度来得到分类结果。但是在选择损失函数的时候感觉很迷茫,生成embeddings的两个网络是分开训练的,现在可以想到两个方案如下:
方案一:
构建第3个网络,将embeddingA和embeddingB作为nn.cosinesimilarity()的输入计算最终结果(应该是[-1,1]中的概率),然后选择一个二分类损失函数。
(抱歉,我不知道该选择哪个损失函数。)
class cos_Similarity(nn.Module):
def __init__(self):
super(cos_Similarity,self).__init__()
cos=nn.CosineSimilarity(dim=2)
embA=generator_A()
embB=generator_B()
def forward(self,a,b):
output_a=embA(a)
output_b=embB(b)
return cos(output_a,output_b)
loss_func=nn.CrossEntropyLoss()
y=cos_Similarity(a,b)
loss=loss_func(y,target)
acc=np.int64(y>0)
方案二:这两个Embeddings作为输出,然后使用nn.CosineEmbeddingLoss()作为损失函数,当我计算准确率时,我使用nn.Cosinesimilarity()输出结果([-1,1]中的概率)。
output_a=embA(a)
output_b=embB(b)
cos=nn.CosineSimilarity(dim=2)
loss_function = torch.nn.CosineEmbeddingLoss()
loss=loss_function(output_a,output_b,target)
acc=cos(output_a,output_b)
我真的需要帮助。我该如何选择?为什么?或者只能通过实验结果为我做出选择。非常感谢!
###############################添加
def train_func(train_loss_list):
train_data=load_data('train')
trainloader = DataLoader(train_data, batch_size=BATCH_SIZE)
cos_smi=nn.CosineSimilarity(dim=2)
train_loss = 0
for step,(a,b,target) in enumerate(trainloader):
try:
optimizer.zero_grad()
output_a = model_A(a) #generate embA
output_b = model_B(b) #generate embB
acc=cos_smi(output_a,output_b)
loss = loss_fn(output_a,output_b, target.unsqueeze(dim=1))
train_loss += loss.item()
loss.backward()
optimizer.step()
train_loss_list.append(loss)
if step%10==0:
print('train:',step,'step','loss:',loss,'acc',acc)
except Exception as e:
print('train:',step,'step')
print(repr(e))
return train_loss_list,train_loss/len(trainloader)
最佳答案
回应评论线程。
目标或管道似乎是:
我能想到的是以下几点。如果我误解了什么,请纠正我。免责声明是,我几乎是在不知道任何细节的情况下根据我的直觉编写代码,因此如果您尝试运行它可能会充满错误。让我们仍然尝试获得更高层次的理解。
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, num_emb, emb_dim): # I'm assuming the embedding matrices are same sizes.
self.embedding1 = nn.Embedding(num_embeddings=num_emb, embedding_dim=emb_dim)
self.embedding2 = nn.Embedding(num_embeddings=num_emb, embedding_dim=emb_dim)
self.cosine = nn.CosineSimilarity()
self.sigmoid = nn.Sigmoid()
def forward(self, a, b):
output1 = self.embedding1(a)
output2 = self.embedding2(b)
similarity = self.cosine(output1, output2)
output = self.sigmoid(similarity)
return output
model = Model(num_emb, emb_dim)
if torch.cuda.is_available():
model = model.to('cuda')
model.train()
criterion = loss_function()
optimizer = some_optimizer()
for epoch in range(num_epochs):
epoch_loss = 0
for batch in train_loader:
optimizer.zero_grad()
a, b, label = batch
if torch.cuda.is_available():
a = a.to('cuda')
b = b.to('cuda')
label = label.to('cuda')
output = model(a, b)
loss = criterion(output, label)
loss.backward()
optimizer.step()
epoch_loss += loss.cpu().item()
print("Epoch %d \t Loss %.6f" % epoch, epoch_loss)
我省略了一些细节(例如,超参数值、损失函数和优化器等)。这个整体过程是否与您正在寻找的 OP 相似?
关于python - 关于余弦相似度,损失函数和网络如何选择(我有两个方案),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63750215/
C语言sscanf()函数:从字符串中读取指定格式的数据 头文件: ?
最近,我有一个关于工作预评估的问题,即使查询了每个功能的工作原理,我也不知道如何解决。这是一个伪代码。 下面是一个名为foo()的函数,该函数将被传递一个值并返回一个值。如果将以下值传递给foo函数,
CStr 函数 返回表达式,该表达式已被转换为 String 子类型的 Variant。 CStr(expression) expression 参数是任意有效的表达式。 说明 通常,可以
CSng 函数 返回表达式,该表达式已被转换为 Single 子类型的 Variant。 CSng(expression) expression 参数是任意有效的表达式。 说明 通常,可
CreateObject 函数 创建并返回对 Automation 对象的引用。 CreateObject(servername.typename [, location]) 参数 serv
Cos 函数 返回某个角的余弦值。 Cos(number) number 参数可以是任何将某个角表示为弧度的有效数值表达式。 说明 Cos 函数取某个角并返回直角三角形两边的比值。此比值是
CLng 函数 返回表达式,此表达式已被转换为 Long 子类型的 Variant。 CLng(expression) expression 参数是任意有效的表达式。 说明 通常,您可以使
CInt 函数 返回表达式,此表达式已被转换为 Integer 子类型的 Variant。 CInt(expression) expression 参数是任意有效的表达式。 说明 通常,可
Chr 函数 返回与指定的 ANSI 字符代码相对应的字符。 Chr(charcode) charcode 参数是可以标识字符的数字。 说明 从 0 到 31 的数字表示标准的不可打印的
CDbl 函数 返回表达式,此表达式已被转换为 Double 子类型的 Variant。 CDbl(expression) expression 参数是任意有效的表达式。 说明 通常,您可
CDate 函数 返回表达式,此表达式已被转换为 Date 子类型的 Variant。 CDate(date) date 参数是任意有效的日期表达式。 说明 IsDate 函数用于判断 d
CCur 函数 返回表达式,此表达式已被转换为 Currency 子类型的 Variant。 CCur(expression) expression 参数是任意有效的表达式。 说明 通常,
CByte 函数 返回表达式,此表达式已被转换为 Byte 子类型的 Variant。 CByte(expression) expression 参数是任意有效的表达式。 说明 通常,可以
CBool 函数 返回表达式,此表达式已转换为 Boolean 子类型的 Variant。 CBool(expression) expression 是任意有效的表达式。 说明 如果 ex
Atn 函数 返回数值的反正切值。 Atn(number) number 参数可以是任意有效的数值表达式。 说明 Atn 函数计算直角三角形两个边的比值 (number) 并返回对应角的弧
Asc 函数 返回与字符串的第一个字母对应的 ANSI 字符代码。 Asc(string) string 参数是任意有效的字符串表达式。如果 string 参数未包含字符,则将发生运行时错误。
Array 函数 返回包含数组的 Variant。 Array(arglist) arglist 参数是赋给包含在 Variant 中的数组元素的值的列表(用逗号分隔)。如果没有指定此参数,则
Abs 函数 返回数字的绝对值。 Abs(number) number 参数可以是任意有效的数值表达式。如果 number 包含 Null,则返回 Null;如果是未初始化变量,则返回 0。
FormatPercent 函数 返回表达式,此表达式已被格式化为尾随有 % 符号的百分比(乘以 100 )。 FormatPercent(expression[,NumDigitsAfterD
FormatNumber 函数 返回表达式,此表达式已被格式化为数值。 FormatNumber( expression [,NumDigitsAfterDecimal [,Inc
我是一名优秀的程序员,十分优秀!