gpt4 book ai didi

Tensorflow Recommenders : what is features in computed_loss method in tfrs. 模型类(来自检索教程)

转载 作者:行者123 更新时间:2023-12-05 05:58:17 24 4
gpt4 key购买 nike

我正在关注 Retrieval tutorial来自 TFRS(TensorFlow Recommenders)库,我对这部分感到困惑:

class MovielensModel(tfrs.Model):

def __init__(self, user_model, movie_model):
super().__init__()
self.movie_model: tf.keras.Model = movie_model
self.user_model: tf.keras.Model = user_model
self.task: tf.keras.layers.Layer = task

def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
# We pick out the user features and pass them into the user model.
user_embeddings = self.user_model(features["user_id"])
# And pick out the movie features and pass them into the movie model,
# getting embeddings back.
positive_movie_embeddings = self.movie_model(features["movie_title"])

# The task computes the loss and the metrics.
return self.task(user_embeddings, positive_movie_embeddings)

接下来是:

model = MovielensModel(user_model, movie_model)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

我对这些代码块有疑问:

  • 当它说 user_embeddings = self.user_model(features["user_id"])(还有 (self.movi​​e_model(features["movie_title"]))时好像features是一个字典,但是在教程之前的任何部分都没有定义。我还检查了compute_loss源代码here,看它是否是一个属性那个方法之类的,但我也没有找到任何东西......所以我的问题是,什么是 features?代码如何运行良好,运行之前未定义的代码?我在类外尝试了它,只运行了这个:user_model(features["user_id"]) 当然,说 features 没有定义是行不通的。但是,为什么它会在类被实例化和稍后被编译时起作用?(上面的第二段代码)。

非常感谢!

最佳答案

Features 是用字典构造的 tf.data.Dataset 的一行。您可以通过在字典上使用 tf.data.Dataset.from_tensor_slices 创建相同的数据集。

例如:

dict_ = {'user_id' : [1, 2], 'movie_title' : ['foo', 'bar']}
dataset = tf.data.Dataset.from_tensor_slices(dict_)

for features in dataset :
print(features)

将返回:

{'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=1>, 'movie_title': <tf.Tensor: shape=(), dtype=string, numpy=b'foo'>}
{'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'movie_title': <tf.Tensor: shape=(), dtype=string, numpy=b'bar'>}

请注意,dataset['user_id'] 将返回错误,因为 tf.data.Dataset 无法以这种方式调用。 MovielensModel(tfrs.Model)compute_loss 函数正在接收 tf.data.Dataset 的项目,但它需要张量才能正常工作.因此,该函数仅保留与“user_id”部分(或“movie_title”)关联的张量。

这个:

for features in dataset:
print(features['user_id'])

将返回一个张量:

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

关于Tensorflow Recommenders : what is features in computed_loss method in tfrs. 模型类(来自检索教程),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68564382/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com