sdk/python: introduce an ffrecord dataloader example with the Python SDK (#5998)

This commit is contained in:
Edric Mo
2025-04-18 11:33:26 +08:00
committed by GitHub
parent 3b9f79f696
commit d84a291121
9 changed files with 687 additions and 3 deletions

View File: sdk/python/examples/ffrecord/dataloader.py

@@ -0,0 +1,94 @@
# encoding: utf-8
# JuiceFS, Copyright 2025 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

import torch

from dataset import FFRecordDataset


class FFRecordDataLoader(torch.utils.data.DataLoader):
    def __init__(self,
                 dataset: FFRecordDataset,
                 batch_size=1,
                 shuffle: bool = False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers: int = 0,
                 collate_fn=None,
                 pin_memory: bool = False,
                 drop_last: bool = False,
                 timeout: float = 0,
                 worker_init_fn=None,
                 generator=None,
                 *,
                 prefetch_factor: int = 2,
                 persistent_workers: bool = False,
                 skippable: bool = True):
        # Use fork to create the worker subprocesses; the dataset's file
        # descriptors stay closed until each worker reopens its own copy.
        if num_workers == 0:
            multiprocessing_context = None
            dataset.initialize()
        else:
            multiprocessing_context = 'fork'
        self.skippable = skippable
        super().__init__(dataset=dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn,
                         multiprocessing_context=multiprocessing_context,
                         generator=generator,
                         prefetch_factor=prefetch_factor,
                         persistent_workers=persistent_workers)


if __name__ == "__main__":
    fnames = ["/demo.ffr"]
    dataset = FFRecordDataset(fnames, check_data=True)

    def worker_init_fn(worker_id):
        # Runs in each worker after fork: reopen the reader's fds there.
        worker_info = torch.utils.data.get_worker_info()
        print(f"Worker initialized pid: {os.getpid()}, worker_info: {worker_info}")
        worker_info.dataset.initialize(worker_id=worker_id)

    def collate_fn(batch):
        return batch

    begin_time = time.time()
    dataloader = FFRecordDataLoader(dataset, batch_size=1, shuffle=True,
                                    num_workers=10, worker_init_fn=worker_init_fn,
                                    prefetch_factor=None, collate_fn=collate_fn)
    i = 0
    for batch in dataloader:
        # print(i, ": ", batch[0]["index"], "----", time.time() - begin_time)
        i += 1
        if i > 1000:
            break
    end_time = time.time()
    print(f"takes: {end_time - begin_time}")

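The dataloader above encodes one pattern: keep file descriptors closed until the workers are forked, then have each worker reopen its own through `worker_init_fn`. A minimal standalone sketch of the same pattern (the `LazyDataset` class and the `/dev/zero` stand-in file are illustrative, not part of this example or the JuiceFS API):

```python
import os
import torch
from torch.utils.data import DataLoader, Dataset

class LazyDataset(Dataset):
    def __init__(self):
        self.fd = None  # opened per worker, never in the parent

    def initialize(self):
        self.fd = os.open("/dev/zero", os.O_RDONLY)  # stand-in for a data file

    def __len__(self):
        return 8

    def __getitem__(self, i):
        assert self.fd is not None, "initialize() must run before any read"
        return os.pread(self.fd, 4, 0)

def worker_init_fn(worker_id):
    # get_worker_info().dataset is this worker's own copy after fork
    torch.utils.data.get_worker_info().dataset.initialize()

if __name__ == "__main__":
    loader = DataLoader(LazyDataset(), num_workers=2, worker_init_fn=worker_init_fn)
    for batch in loader:
        pass  # every read runs in a worker that owns its fd
```

Sharing one live descriptor across forked processes would interleave seek/read pairs; per-worker descriptors (or positional reads like `os.pread`) avoid that.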
View File: sdk/python/examples/ffrecord/dataset.py

@@ -0,0 +1,65 @@
# encoding: utf-8
# JuiceFS, Copyright 2025 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union

import numpy as np
import torch

from filereader import FileReader
# from filereader_dio import FileReader


class FFRecordDataset(torch.utils.data.Dataset):
    def __init__(self, fnames: Union[str, List[str]], check_data: bool = True):
        if isinstance(fnames, str):
            fnames = [fnames]
        self.reader = FileReader(fnames, check_data=check_data)
        self.n = self.reader.n
        # Close the fds so the dataset can be forked safely; each worker
        # reopens them via initialize().
        self.reader.close_fd()

    def initialize(self, worker_id=0, num_workers=1):
        self.reader.open_fd()
        self.n = self.reader.n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: Union[int, List[int]]) -> Union[np.ndarray, List[np.ndarray]]:
        if isinstance(index, int):
            return self.reader.read_one(index)
        elif isinstance(index, list):
            return self.reader.read_batch(index)
        else:
            raise TypeError(f"Index must be int or list, got {type(index)}")

    def close(self):
        self.reader.close_fd()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


if __name__ == "__main__":
    fnames = ["/demo.ffr"]
    with FFRecordDataset(fnames, check_data=True) as dataset:
        dataset.initialize()  # reopen the fds that __init__ closed
        sample = dataset[0]
        print("Sample 0:", sample)
        batch = dataset[[1, 2, 3]]
        print(batch)
        print("Dataset length:", len(dataset))

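Since `__getitem__` accepts either an `int` or a list of indices, the dataset can also serve whole batches in one call. A sketch of one way to exercise the list branch with the stock torch samplers (single-process use, so the fds are opened up front; `/demo.ffr` as in the demo above):

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = FFRecordDataset(["/demo.ffr"], check_data=True)
dataset.initialize()  # single-process use: open the fds here

# With batch_size=None, each element the sampler yields (a list of ints)
# is handed to __getitem__ as-is, so read_batch() fetches the whole batch.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None,
                    collate_fn=lambda samples: samples)
for batch in loader:
    pass  # batch is a list of 8 decoded samples
```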
View File: sdk/python/examples/ffrecord/filereader.py

@@ -0,0 +1,167 @@
# encoding: utf-8
# JuiceFS, Copyright 2025 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import struct
import sys
import zlib
from typing import List, Tuple, Union

sys.path.append('.')
from sdk.python.juicefs.juicefs import juicefs
# import juicefs

MAX_SIZE = 512 * (1 << 20)  # 512 MB


def ffcrc32(code: int, data: Union[bytes, bytearray], length: int) -> int:
    start = 0
    while start < length:
        chunk_size = min(MAX_SIZE, length - start)
        code = zlib.crc32(data[start:start + chunk_size], code)
        start += chunk_size
    return code


class FileHeader:
    def __init__(self, jfscli: juicefs.Client, fname: str, check_data: bool = True):
        self.fname = fname
        self.fd = jfscli.open(fname, mode='rb')
        self.fd.seek(0)
        self.checksum_meta = self._read_uint32()
        self.n = self._read_uint64()
        self.checksums = [self._read_uint32() for _ in range(self.n)]
        self.fd.seek(4 + 8 + 4 * self.n)
        self.offsets = [self._read_uint64() for _ in range(self.n + 1)]
        # The end offset of the last sample is the actual file size.
        self.offsets[self.n] = jfscli.stat(fname).st_size
        if check_data:
            self.validate()
        # Reopen unbuffered for the per-sample reads.
        self.fd.close()
        self.fd = jfscli.open(fname, mode='rb', buffering=0)
        self.aiofd = self.fd

    def _read_uint32(self) -> int:
        return struct.unpack('<I', self.fd.read(4))[0]

    def _read_uint64(self) -> int:
        return struct.unpack('<Q', self.fd.read(8))[0]

    def close_fd(self):
        if self.fd:
            self.fd.close()
            self.fd = None

    def validate(self):
        if self.checksum_meta == 0:
            print("Warning: you are using an old version ffrecord file, please update the file")
            return
        checksum = 0
        checksum = ffcrc32(checksum, struct.pack('<Q', self.n), 8)
        checksum = ffcrc32(checksum, struct.pack(f'<{len(self.checksums)}I', *self.checksums), 4 * len(self.checksums))
        checksum = ffcrc32(checksum, struct.pack(f'<{len(self.offsets)}Q', *self.offsets), 8 * len(self.offsets) - 8)
        assert checksum == self.checksum_meta, f"{self.fname}: checksum of metadata mismatched!"

    def access(self, index: int, use_aio: bool = False) -> Tuple[int, int, int, int]:
        fd = self.aiofd if use_aio else self.fd
        offset = self.offsets[index]
        length = self.offsets[index + 1] - self.offsets[index]
        checksum = self.checksums[index]
        return fd, offset, length, checksum


class FileReader:
    def __init__(self, fnames: List[str], check_data: bool = True):
        self.fnames = fnames
        self.check_data = check_data
        self.nfiles = len(fnames)
        # Placeholder so len() works before open_fd() runs in a worker;
        # the demo dataset has 1000 samples.
        self.n = 1000
        self.nsamples = [0]
        self.headers = []

    def close_fd(self):
        for header in self.headers:
            header.close_fd()
        self.headers = []
        self.n = 0
        self.nsamples = [0]

    def open_fd(self):
        # Reset the placeholder count before scanning the real headers.
        self.n = 0
        self.nsamples = [0]
        self.headers = []
        self.client = juicefs.Client("myjfs", "redis://localhost",
                                     cache_dir="/tmp/data", cache_size="0",
                                     debug=False)
        for fname in self.fnames:
            header = FileHeader(self.client, fname, self.check_data)
            self.headers.append(header)
            self.n += header.n
            self.nsamples.append(self.n)

    def validate(self):
        for header in self.headers:
            header.validate()

    def validate_sample(self, index: int, buf: bytes, checksum: int):
        if self.check_data:
            checksum2 = ffcrc32(0, buf, len(buf))
            assert checksum2 == checksum, f"Sample {index}: checksum mismatched!"

    def read(self, indices: List[int]):
        return self.read_batch(indices)

    def read_batch(self, indices: List[int]):
        assert not any(index >= self.n for index in indices), "Index out of range"
        return [self.read_one(index) for index in indices]

    def read_one(self, index: int):
        assert index < self.n, "Index out of range"
        # Find which file this global index falls into.
        fid = 0
        while index >= self.nsamples[fid + 1]:
            fid += 1
        header = self.headers[fid]
        fd, offset, length, checksum = header.access(index - self.nsamples[fid], use_aio=False)
        fd.seek(offset)
        buf = fd.read(length)
        self.validate_sample(index, buf, checksum)
        return pickle.loads(buf)

    def close(self):
        self.close_fd()


if __name__ == "__main__":
    fnames = ["/demo.ffr"]
    reader = FileReader(fnames, check_data=True)
    reader.open_fd()
    data = reader.read_one(0)  # read_one already unpickles the sample
    print(data)
    print(data["index"])
    print(data["txt"])

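For reference, the on-disk layout implied by the reads above, little-endian throughout (inferred from this parser, not from an authoritative ffrecord spec):

```python
import struct

# | uint32 checksum_meta | uint64 n | uint32 checksums[n] | uint64 offsets[n+1] | samples |
def header_size(n: int) -> int:
    return struct.calcsize('<IQ') + 4 * n + 8 * (n + 1)

# matches the seek to 4 + 8 + 4*n that FileHeader does before the offsets table
assert header_size(0) == 4 + 8 + 8
```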
View File: sdk/python/examples/ffrecord/filereader_dio.py

@@ -0,0 +1,166 @@
# encoding: utf-8
# JuiceFS, Copyright 2025 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import struct
import zlib
from typing import List, Tuple, Union

import numpy as np

MAX_SIZE = 512 * (1 << 20)  # 512 MB
DIRECTIO_BLOCK_SIZE = 1 * (1 << 20)  # 1 MB


def ffcrc32(code: int, data: Union[bytes, bytearray], length: int) -> int:
    start = 0
    while start < length:
        chunk_size = min(MAX_SIZE, length - start)
        code = zlib.crc32(data[start:start + chunk_size], code)
        start += chunk_size
    return code


class FileHeader:
    def __init__(self, fname: str, check_data: bool = True):
        print(f"__init__ self: {hex(id(self))} pid: {os.getpid()}")
        self.fname = fname
        self.fd = os.open(fname, os.O_RDONLY | os.O_DIRECT)
        self.aiofd = self.fd
        self.file_obj = os.fdopen(self.fd, 'rb', buffering=0)
        self.checksum_meta = self._read_uint32()
        self.n = self._read_uint64()
        # Read the checksum and offset tables in one call.
        checksums_size = 4 * self.n
        offsets_size = 8 * (self.n + 1)
        combined_data = self.file_obj.read(checksums_size + offsets_size)
        self.checksums = list(struct.unpack(f'<{self.n}I', combined_data[:checksums_size]))
        self.offsets = list(struct.unpack(f'<{self.n + 1}Q', combined_data[checksums_size:checksums_size + offsets_size]))
        # The end offset of the last sample is the actual file size.
        self.offsets[self.n] = os.path.getsize(fname)
        if check_data:
            self.validate()
        print("FileHeader initialized for:", fname, "fd:", self.fd)

    def _read_uint32(self) -> int:
        return struct.unpack('<I', self.file_obj.read(4))[0]

    def _read_uint64(self) -> int:
        return struct.unpack('<Q', self.file_obj.read(8))[0]

    def close_fd(self):
        print("close fd:", self.fd)
        if self.fd != -1:
            if self.file_obj is not None:
                self.file_obj.close()  # also closes the underlying fd
            else:
                os.close(self.fd)
            self.fd = -1
            self.file_obj = None

    def open_fd(self):
        if self.fd == -1:
            self.fd = os.open(self.fname, os.O_RDONLY | os.O_DIRECT)
            self.aiofd = self.fd
            print(f"header.open_fd: {self.fd} address: {hex(id(self))} pid: {os.getpid()}")

    def validate(self):
        if self.checksum_meta == 0:
            print("Warning: you are using an old version ffrecord file, please update the file")
            return
        checksum = 0
        checksum = ffcrc32(checksum, struct.pack('<Q', self.n), 8)
        checksum = ffcrc32(checksum, struct.pack(f'<{len(self.checksums)}I', *self.checksums), 4 * len(self.checksums))
        checksum = ffcrc32(checksum, struct.pack(f'<{len(self.offsets)}Q', *self.offsets), 8 * len(self.offsets) - 8)
        assert checksum == self.checksum_meta, f"{self.fname}: checksum of metadata mismatched!"

    def access(self, index: int, use_aio: bool = False) -> Tuple[int, int, int, int]:
        fd = self.aiofd if use_aio else self.fd
        offset = self.offsets[index]
        length = self.offsets[index + 1] - self.offsets[index]
        checksum = self.checksums[index]
        return fd, offset, length, checksum


class FileReader:
    def __init__(self, fnames: List[str], check_data: bool = True):
        self.fnames = fnames
        self.check_data = check_data
        self.nfiles = len(fnames)
        self.n = 0
        self.nsamples = [0]
        self.headers = []
        for fname in fnames:
            header = FileHeader(fname, check_data)
            self.headers.append(header)
            self.n += header.n
            self.nsamples.append(self.n)

    def close_fd(self):
        for header in self.headers:
            header.close_fd()

    def open_fd(self):
        print(f"open_fd address: {hex(id(self))} pid: {os.getpid()}")
        for header in self.headers:
            header.open_fd()

    def validate(self):
        for header in self.headers:
            header.validate()

    def validate_sample(self, index: int, buf: bytes, checksum: int):
        if self.check_data:
            checksum2 = ffcrc32(0, buf, len(buf))
            assert checksum2 == checksum, f"Sample {index}: checksum mismatched!"

    def read_batch(self, indices: List[int]) -> List[np.ndarray]:
        assert not any(index >= self.n for index in indices), "Index out of range"
        return [self.read_one(index) for index in indices]

    def read_one(self, index: int) -> np.ndarray:
        assert index < self.n, "Index out of range"
        fid = 0
        while index >= self.nsamples[fid + 1]:
            fid += 1
        header = self.headers[fid]
        fd, offset, length, checksum = header.access(index - self.nsamples[fid], use_aio=False)
        # Read the sample in 1 MB chunks with positional reads.
        buf = bytearray(length)
        start = 0
        while start < length:
            chunk_size = min(DIRECTIO_BLOCK_SIZE, length - start)
            buf[start:start + chunk_size] = os.pread(fd, chunk_size, offset + start)
            start += chunk_size
        self.validate_sample(index, buf, checksum)
        return np.frombuffer(buf, dtype=np.uint8)


if __name__ == "__main__":
    fnames = ["/demo.ffr"]
    reader = FileReader(fnames, check_data=True)
    data = reader.read_one(0)
    print(data)

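One caveat on this variant: on most local filesystems, `O_DIRECT` requires the read offset, length, and buffer address to be block-aligned, which the arbitrary-offset `os.pread` calls above do not guarantee; FUSE mounts such as JuiceFS are typically more forgiving. A hedged sketch of a fully aligned positional read, assuming a 4 KiB block size and Python 3.7+ for `os.preadv`:

```python
import mmap
import os

BLOCK = 4096  # assumed logical block size; query the device in practice

def pread_aligned(fd: int, offset: int, length: int) -> bytes:
    # Widen [offset, offset + length) to block boundaries, read into a
    # page-aligned buffer, then slice out the span actually requested.
    start = offset - offset % BLOCK
    span = (offset - start) + length
    span += (-span) % BLOCK
    buf = mmap.mmap(-1, span)  # anonymous maps are page-aligned
    os.preadv(fd, [buf], start)
    return buf[offset - start:offset - start + length]
```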
View File: sdk/python/examples/ffrecord/main.py

@@ -0,0 +1,173 @@
# encoding: utf-8
# JuiceFS, Copyright 2025 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
import random
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import loguru
import numpy as np
from faker import Faker
from PIL import Image
from tqdm import tqdm

from ffrecord import FileReader, FileWriter
from ffrecord.torch import DataLoader, Dataset

logger = loguru.logger
fake = Faker()

ffrecord_file = "/tmp/jfs/demo.ffr"
num_samples = 1000
num_proc = 4


def serialize(sample):
    return pickle.dumps(sample)


def deserialize(sample):
    return pickle.loads(sample)


def generate_random_image_np(
    width=256,
    height=256,
    format="JPEG",  # JPEG, PNG, WEBP
    quality=90,     # only for JPEG/WEBP
):
    # randint's upper bound is exclusive, so use 256 for the full uint8 range
    image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    img = Image.fromarray(image_np)
    img_bytes = io.BytesIO()
    img.save(img_bytes, format=format, quality=quality)
    return img_bytes.getvalue()


def generate_data_entry(
    idx,
    text=None,
    avg_width=1024,
    avg_height=1024,
    variance=50,
    possible_formats=["PNG"],
    # possible_formats=["JPEG", "PNG", "WEBP"],
):
    """Generate one sample: a random image of avg_width/avg_height ± variance
    pixels plus a fake sentence."""
    image_format = random.choice(possible_formats).lower()
    width = random.randint(avg_width - variance, avg_width + variance)
    height = random.randint(avg_height - variance, avg_height + variance)
    width, height = max(width, 32), max(height, 32)
    img_bytes = generate_random_image_np(
        width=width,
        height=height,
        format=image_format.upper(),
    )
    if text is None:
        text = fake.sentence()
    return {
        "index": idx,
        "txt": text,
        image_format: img_bytes,
    }


def write_ffrecord():
    ffr_output = Path(ffrecord_file)
    if ffr_output.exists():
        logger.warning(f"Output {ffr_output} exists, removing")
        ffr_output.unlink()
    logger.info(f"Generating {num_samples} samples")
    with Pool(num_proc) as pool:
        data_to_write = list(
            tqdm(
                pool.imap_unordered(generate_data_entry, range(num_samples), chunksize=10),
                total=num_samples,
                desc="Generating data"
            )
        )
    begin_time = time.time()
    writer = FileWriter(ffr_output, len(data_to_write))
    for i, data in enumerate(data_to_write):
        writer.write_one(serialize(data))
    writer.close()
    end_time = time.time()
    ffr_size = ffr_output.stat().st_size
    logger.info(f"FFRecord size: {ffr_size / 1024 ** 3:.2f} GB")
    logger.info(f"Time taken to write: {end_time - begin_time:.2f} seconds")


def read_ffrecord(batch_size: int):
    reader = FileReader([ffrecord_file], check_data=True)
    sample_indices = list(range(num_samples))
    random.Random(0).shuffle(sample_indices)
    sample_batches = [sample_indices[i: i + batch_size] for i in range(0, len(sample_indices), batch_size)]
    logger.info(f'Number of samples to read: {reader.n}, batch_size = {batch_size}, num_batches = {len(sample_batches)}')
    read_indices = set()
    begin_time = time.time()
    index_iter = tqdm(sample_batches, desc="Reading data in batches", total=len(sample_batches))
    for indices in index_iter:
        all_data = reader.read(indices)
        for data in all_data:
            data = deserialize(data)
            read_indices.add(data["index"])
    end_time = time.time()
    reader.close()
    assert read_indices == set(range(num_samples))
    logger.info(f"Read {len(read_indices)} samples in {end_time - begin_time:.2f} s: "
                f"{len(read_indices) / (end_time - begin_time):.2f} samples/s")


class MyDataset(Dataset):
    def __init__(self, fnames, check_data=True):
        self.reader = FileReader(fnames, check_data=check_data)

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        # ffrecord's Dataset passes a list of indices per batch
        data = self.reader.read(indices)
        return [pickle.loads(bytes_) for bytes_ in data]


if __name__ == "__main__":
    if len(sys.argv) > 1:
        if sys.argv[1] == "write":
            write_ffrecord()
        elif sys.argv[1] == "read":
            read_ffrecord(batch_size=1)
    else:
        begin_time = time.time()
        dataset = MyDataset([ffrecord_file], check_data=True)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True,
                                num_workers=10, prefetch_factor=None)
        i = 0
        for batch in dataloader:
            i += 1
            if i > 1000:
                break
        end_time = time.time()
        print(f"takes: {end_time - begin_time}")

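Each stored sample is a pickled dict of the form `{"index": int, "txt": str, "png": bytes}` (the image key follows the chosen format). A short sketch of decoding one sample back into a PIL image, reusing the keys `generate_data_entry` produces:

```python
import io
import pickle

from PIL import Image

def decode_sample(buf: bytes):
    sample = pickle.loads(buf)
    img = Image.open(io.BytesIO(sample["png"]))  # "png" per possible_formats above
    return sample["index"], sample["txt"], img
```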
View File

@@ -0,0 +1,19 @@
```bash
# This is an ffrecord dataloader example.
# Prepare: install ffrecord from https://github.com/HFAiLab/ffrecord
# Mount JuiceFS
juicefs mount redis://localhost /tmp/jfs -d
# Generate the dataset
python3 sdk/python/examples/ffrecord/main.py write
# Read the dataset directly
python3 sdk/python/examples/ffrecord/main.py read
# Read the dataset with the ffrecord dataloader (takes 39.55s)
python3 sdk/python/examples/ffrecord/main.py
# Build the Python SDK
make -C sdk/python libjfs.so
# Read the dataset with the JuiceFS Python SDK dataloader (takes 10.02s)
python3 sdk/python/examples/ffrecord/dataloader.py
```
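Note: main.py writes /tmp/jfs/demo.ffr through the FUSE mount, while the SDK-based scripts open "/demo.ffr" relative to the volume root; both refer to the same file.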

View File

@@ -1,5 +1,5 @@
```bash
-# This example demonstrates how to use the fsspec library to read a CSV file from a URL.
+# This example demonstrates how to use the fsspec library to read a CSV file.
juicefs mount redis://localhost /tmp/jfs -d
# Download the data file
wget https://gender-pay-gap.service.gov.uk/viewing/download-data/2021 -O /tmp/jfs/ray_demo_data.csv

View File

@@ -1,5 +1,5 @@
# encoding: utf-8
-# JuiceFS, Copyright 2020 Juicedata, Inc.
+# JuiceFS, Copyright 2024 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
# encoding: utf-8
-# JuiceFS, Copyright 2020 Juicedata, Inc.
+# JuiceFS, Copyright 2024 Juicedata, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.