gpt4 book ai didi

python - 类方法返回迭代器

转载 作者:行者123 更新时间:2023-11-30 22:19:16 26 4
gpt4 key购买 nike

我实现了一个迭代器类,如下所示:

import numpy as np
import time


class Data:

def __init__(self, filepath):
# Computationaly expensive
print("Computationally expensive")
time.sleep(10)
print("Done!")

def __iter__(self):
return self

def __next__(self):
return np.zeros((2,2)), np.zeros((2,2))


count = 0
for batch_x, batch_y in Data("hello.csv"):
print(batch_x, batch_y)
count = count + 1

if count > 5:
break


count = 0
for batch_x, batch_y in Data("hello.csv"):
print(batch_x, batch_y)
count = count + 1

if count > 5:
break

但是构造函数的计算量很大,并且 for 循环可能会被调用多次。例如,在上面的代码中,构造函数被调用两次(每次 for 循环创建一个新的 Data 对象)。

如何分离构造函数和迭代器?我希望有以下代码,其中构造函数仅被调用一次:

data = Data(filepath)

for batch_x, batch_y in data.get_iterator():
print(batch_x, batch_y)

for batch_x, batch_y in data.get_iterator():
print(batch_x, batch_y)

最佳答案

你可以直接迭代一个可迭代对象,for..in不需要任何其他东西:

data = Data(filepath)

for batch_x, batch_y in data:
print(batch_x, batch_y)

for batch_x, batch_y in data:
print(batch_x, batch_y)

也就是说,取决于您如何实现 __iter__() ,这可能有问题。

例如:

不好

class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
self._i = 0
def __iter__(self): return self
def __next__(self):
if self._i >= len(self._items): # Or however you check if data is available
raise StopIteration
result = self._items[self._i]
self._i += 1
return result

因为这样你就不能迭代同一个对象两次,如self._i仍然会指向循环的末尾。

不错

class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
self._i = 0
return self
def __next__(self):
if self._i >= len(self._items):
raise StopIteration
result = self._items[self._i]
self._i += 1
return result

每次您要迭代时,这都会重置索引,从而修复上述问题。如果您在同一对象上嵌套迭代,则这将不起作用。

更好

要解决这个问题,请将迭代状态保留在单独的迭代器对象中:

class Data:
class Iter:
def __init__(self, data):
self._data = data
self._i = 0
def __next__(self):
if self._i >= len(self._data._items): # check for available data
raise StopIteration
result = self._data._items[self._i]
self._i = self._i + 1
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
return self.Iter(self)

这是最灵活的方法,但如果您可以使用以下任一方法,则不必要的冗长。

简单,使用yield

如果您使用 Python 的生成器,该语言将负责为您跟踪迭代状态,并且即使在嵌套循环时它也应该正确地执行此操作:

class Data:
def __init__(self, filepath):
self._items= load_items(filepath)
def __iter__(self):
for it in self._items: # Or whatever is appropriate
yield return it

简单,传递到底层可迭代

如果“计算成本较高”的部分是将所有数据加载到内存中,则可以直接使用缓存的数据。

class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
return iter(self._items)

关于python - 类方法返回迭代器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49178906/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com