Source code for sconce.datasets.csv_image_folder

from PIL import Image
from torchvision.datasets import folder
from collections import defaultdict

import csv
import glob
import numpy as np
import os
import os.path
import torch
import torch.utils.data as data


class CsvImageFolder(data.Dataset):
    """
    A Dataset that reads images from a folder and labels from a csv file.

    Every csv row names one image (by its filename without extension) and
    carries that image's labels in a single delimited field.  The classes of
    the dataset are the sorted set of all labels found in the csv file, and
    each target is a multi-hot float tensor over those classes.

    Arguments:
        root (string): directory where the images can be found.
        csv_path (string): the path to the csv file containing image
            filenames and labels.
        filename_key (string, optional): the column header of the csv for
            the column that contains image filenames (without extensions).
        labels_key (string, optional): the column header of the csv for the
            column that contains labels for each image.
        csv_delimiter (string, optional): the character(s) used to separate
            fields in the csv file.
        labels_delimiter (string, optional): the character(s) used to
            separate the individual labels inside the labels column.  Empty
            fragments (from an empty field or repeated delimiters) are
            ignored.
        loader (callable, optional): a function to load a sample given its
            path.
        extensions (list[string], optional): a list of allowed extensions,
            tried in order.  E.g, ``['.jpg', '.tif']``
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
    """

    def __init__(self, root, csv_path, filename_key='image_name',
                 labels_key='tags', csv_delimiter=',', labels_delimiter=' ',
                 loader=folder.default_loader,
                 extensions=folder.IMG_EXTENSIONS,
                 transform=None, target_transform=None):
        self.root = root
        self.csv_path = csv_path
        self.filename_key = filename_key
        self.labels_key = labels_key
        self.csv_delimiter = csv_delimiter
        self.labels_delimiter = labels_delimiter
        self.loader = loader
        self.extensions = extensions
        self.transform = transform
        self.target_transform = target_transform

        # List the image folder once, up front (so a bad ``root`` fails at
        # construction time), and keep the result instead of rescanning on
        # the first lookup.
        self._found_extensions = self._load_found_extensions()

        self._labels = []
        self._load()

    def _load(self):
        """Read the csv file and build the row list and the label index."""
        rows = []
        all_labels = set()
        # newline='' is what the csv module asks for on files it reads, so
        # that newlines embedded in quoted fields are handled correctly.
        with open(self.csv_path, newline='') as csv_file:
            reader = csv.DictReader(csv_file, delimiter=self.csv_delimiter)
            for row in reader:
                rows.append(row)
                all_labels.update(self._get_labels(row))

        self._rows = rows
        # Sorting makes the label -> index mapping deterministic.
        self._labels = sorted(all_labels)
        self._label_idxs = {label: i for i, label in enumerate(self._labels)}

    def _get_path(self, base_filename):
        """
        Return the path of the image whose name (minus extension) is
        ``base_filename``, using the first allowed extension present in
        ``root``.

        Raises:
            RuntimeError: if no file with that base filename and an allowed
                extension was found in ``root``.
        """
        # Only the extensions recorded for this exact base filename count; an
        # unknown base filename has none and must fall through to the error.
        extensions_on_disk = self.found_extensions.get(base_filename, ())
        for extension in self.extensions:
            if extension in extensions_on_disk:
                return os.path.join(self.root,
                                    '%s%s' % (base_filename, extension))

        raise RuntimeError(f"No image file with base filename ({base_filename}) "
                f"found in folder ({self.root}), valid extensions are: {self.extensions}")

    @property
    def found_extensions(self):
        """Mapping of base filename -> list of extensions present in ``root``."""
        # Normally filled in by __init__; rescan lazily if it was cleared.
        if self._found_extensions is None:
            self._found_extensions = self._load_found_extensions()
        return self._found_extensions

    def _load_found_extensions(self):
        """List ``root`` and group the extensions found by base filename."""
        found_extensions = defaultdict(list)
        for filename in os.listdir(self.root):
            base, ext = os.path.splitext(filename)
            found_extensions[base].append(ext)
        return dict(found_extensions)

    def _get_labels(self, row):
        """Return the set of labels stored in the labels column of ``row``."""
        fragments = row[self.labels_key].split(self.labels_delimiter)
        # str.split with an explicit separator yields '' for an empty field
        # and between repeated delimiters; '' is not a real label.
        return {label for label in fragments if label}

    @property
    def num_classes(self):
        """Number of distinct labels found in the csv file."""
        return len(self._labels)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the loaded image and
            target is a tensor of length ``num_classes`` holding 1 at the
            index of every label of the image and 0 elsewhere.
        """
        row = self._rows[index]
        path = self._get_path(row[self.filename_key])
        sample = self.loader(path)

        # Explicit integer dtype: with no labels the list is empty and numpy
        # would otherwise build a float array, which is not a valid index.
        target_indices = np.array(
            [self._label_idxs[label] for label in self._get_labels(row)],
            dtype=np.int64)
        target = np.zeros(self.num_classes)
        target[target_indices] = 1
        target = torch.Tensor(target)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Number of rows (images) listed in the csv file."""
        return len(self._rows)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of images: {}\n'.format(len(self))
        fmt_str += ' Root Location: {}\n'.format(self.root)
        fmt_str += ' Number of tags: {}\n'.format(len(self._labels))
        # Continuation lines of a multi-line transform repr are indented to
        # line up under the first one.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp,
            repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str