# Source code for torchjpeg.data.unlabeled_image_folder

from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Union

import torch
from PIL import Image


class UnlabeledImageFolder(torch.utils.data.Dataset):
    """
    Dataset loading a folder of unlabeled images recursively.

    The images are loaded using PIL and are otherwise unchanged; add a
    transform to turn them into Tensors.

    The folder is scanned once, when the dataset is constructed. Only regular
    files are kept, and they are stored in sorted order so that a given index
    refers to the same image regardless of the order in which the filesystem
    happens to list directory entries.

    Args:
        path (Union[str, Path]): The directory to load recursively from.
        extensions (Sequence[str]): The file suffixes to look for, each
            including the leading dot (e.g. ``".png"``). Matching is
            case-sensitive. A single string is treated as one extension.
        transform (Optional[Callable]): Any transform to apply to the images
            after loading them.
    """

    def __init__(
        self,
        path: Union[str, Path],
        extensions: Sequence[str] = (".bmp", ".png", ".jpg", ".ppm", ".pgm"),
        transform: Optional[Callable] = None,
    ) -> None:
        path = Path(path)

        # A bare string must count as ONE extension. Testing `suffix in ".png"`
        # would be a substring check, which also accepts files with no suffix.
        if isinstance(extensions, str):
            extensions = (extensions,)
        allowed = frozenset(extensions)

        self.path = path
        # `is_file()` drops directories that merely look like images (e.g. a
        # folder named "scans.png"), which would otherwise fail in `open()`.
        # Sorting makes the index -> image mapping reproducible.
        self.images = sorted(p for p in path.glob("**/*") if p.suffix in allowed and p.is_file())
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of image files found under the root folder."""
        return len(self.images)

    def __getitem__(self, idx) -> Any:
        """
        Load the image at position ``idx``.

        Returns:
            The PIL image, or the result of ``transform`` applied to it when a
            transform was given.
        """
        im_path = self.images[idx]

        with open(im_path, "rb") as f:
            im = Image.open(f)
            # Force the pixel data to be read while the file is still open.
            im.load()

        if self.transform is not None:
            im = self.transform(im)

        return im