Let PagedList work with iterators.
This commit is contained in:
parent
4742d152c8
commit
c9376e2bd3
2 changed files with 14 additions and 5 deletions
|
|
@ -3,11 +3,18 @@
|
||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
def batch(iter_, page_size):
|
||||||
|
for _, batch in itertools.groupby(
|
||||||
|
enumerate(iter_), lambda tuple_: tuple_[0] // page_size):
|
||||||
|
yield [value for index, value in batch]
|
||||||
|
|
||||||
|
|
||||||
class PagedList:
|
class PagedList:
|
||||||
|
|
||||||
def __init__(self, list_, pages_dir, page_size, cache_size, exist_ok=False,
|
def __init__(self, list_, pages_dir, page_size, cache_size, exist_ok=False,
|
||||||
|
|
@ -16,19 +23,17 @@ class PagedList:
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.cache_size = cache_size
|
self.cache_size = cache_size
|
||||||
self.open_func = open_func
|
self.open_func = open_func
|
||||||
self._len = len(list_)
|
self._len = 0
|
||||||
tmp_dir = pages_dir + ".tmp"
|
tmp_dir = pages_dir + ".tmp"
|
||||||
if exist_ok:
|
if exist_ok:
|
||||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||||
shutil.rmtree(pages_dir, ignore_errors=True)
|
shutil.rmtree(pages_dir, ignore_errors=True)
|
||||||
os.makedirs(tmp_dir)
|
os.makedirs(tmp_dir)
|
||||||
pages = ([[]] if len(list_) == 0 else
|
for index, page in enumerate(batch(list_, page_size)):
|
||||||
(list_[start:start+self.page_size]
|
|
||||||
for start in range(0, len(list_), self.page_size)))
|
|
||||||
for index, page in enumerate(pages):
|
|
||||||
pickle_path = os.path.join(tmp_dir, str(index))
|
pickle_path = os.path.join(tmp_dir, str(index))
|
||||||
with self.open_func(pickle_path, "wb") as file_:
|
with self.open_func(pickle_path, "wb") as file_:
|
||||||
pickle.dump(page, file_, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(page, file_, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
self._len += len(page)
|
||||||
self.page_count = index + 1
|
self.page_count = index + 1
|
||||||
os.rename(tmp_dir, self.pages_dir)
|
os.rename(tmp_dir, self.pages_dir)
|
||||||
self._setup_page_cache()
|
self._setup_page_cache()
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,10 @@ import eris.paged_list as paged_list
|
||||||
|
|
||||||
class PagedListTestCase(unittest.TestCase):
|
class PagedListTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_batch(self):
|
||||||
|
self.assertEqual(list(paged_list.batch(iter([3,4,5,6,7]), 2)),
|
||||||
|
[[3, 4], [5, 6], [7]])
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
list_ = paged_list.PagedList([3, 4, 5, 6], temp_dir, 4, 2)
|
list_ = paged_list.PagedList([3, 4, 5, 6], temp_dir, 4, 2)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue