
data

class trojanzoo.utils.data.TensorListDataset(data=None, targets=None, **kwargs)

A dataset class that takes a torch.Tensor as inputs and a list[int] as labels. It inherits from torch.utils.data.Dataset.

Parameters:
  • data (torch.Tensor) – The inputs.

  • targets (list[int]) – The labels.

Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset
>>>
>>> data = torch.ones(10, 3, 32, 32)
>>> targets = list(range(10))
>>> dataset = TensorListDataset(data, targets)
>>> x, y = dataset[3]
>>> x.shape
torch.Size([3, 32, 32])
>>> y
3
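TensorListDataset follows PyTorch's map-style dataset protocol: indexing returns one (input, label) pair and len() gives the number of samples. The sketch below mimics that protocol in plain Python so it runs without PyTorch. PairListDataset is a hypothetical name, and its length check is an assumption rather than documented behavior of TensorListDataset.

```python
from typing import Any, Sequence


class PairListDataset:
    """Hypothetical pure-Python stand-in for TensorListDataset.

    Real code would subclass torch.utils.data.Dataset; this sketch only
    implements the two methods a map-style dataset needs.
    """

    def __init__(self, data: Sequence[Any], targets: Sequence[int]) -> None:
        # Assumption: inputs and labels must line up one-to-one.
        if len(data) != len(targets):
            raise ValueError("data and targets must have the same length")
        self.data = data
        self.targets = list(targets)

    def __getitem__(self, index: int) -> tuple[Any, int]:
        # One sample is the input at `index` paired with its label.
        return self.data[index], self.targets[index]

    def __len__(self) -> int:
        return len(self.targets)


# Ten "images", each a flat list of 4 ones, labelled 0..9.
dataset = PairListDataset([[1.0] * 4 for _ in range(10)], list(range(10)))
x, y = dataset[3]
print(len(x), y)  # 4 3
```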

trojanzoo.utils.data.dataset_to_tensor(dataset)

Transform a torch.utils.data.Dataset into a (data, targets) tuple of tensors by traversing all of its elements.

Parameters:

dataset (torch.utils.data.Dataset) – The dataset.

Returns:

(torch.Tensor, torch.Tensor) – The tuple of (data, targets).

Example:
>>> import torch
>>> from torchvision.datasets import MNIST
>>> import torchvision.transforms as transforms
>>> from trojanzoo.utils.data import dataset_to_tensor
>>>
>>> transform = transforms.Compose([
...     transforms.PILToTensor(),
...     transforms.ConvertImageDtype(torch.float)])
>>> dataset = MNIST('./', train=False, download=True,
...                 transform=transform)
>>> data, targets = dataset_to_tensor(dataset)
>>> data.shape
torch.Size([10000, 1, 28, 28])
>>> targets.shape
torch.Size([10000])
>>> targets.dtype
torch.int64
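The traversal that dataset_to_tensor() performs can be sketched as a loop over every index of a map-style dataset that unzips each (input, label) pair. This is a hypothetical pure-Python version (dataset_to_lists is not a TrojanZoo function) that returns lists instead of tensors.

```python
from typing import Any


def dataset_to_lists(dataset: Any) -> tuple[list[Any], list[int]]:
    """Hypothetical sketch: traverse every (input, label) pair of a
    map-style dataset and unzip them into two parallel lists."""
    data: list[Any] = []
    targets: list[int] = []
    for i in range(len(dataset)):
        x, y = dataset[i]  # each element is an (input, label) pair
        data.append(x)
        targets.append(int(y))
    return data, targets


# Any object with __len__ and __getitem__ works; here a plain list of pairs.
pairs = [([i, i], i % 3) for i in range(5)]
data, targets = dataset_to_lists(pairs)
print(len(data), targets)  # 5 [0, 1, 2, 0, 1]
```

With PyTorch, the two lists could then be converted to the documented return types with torch.stack(data) and torch.tensor(targets); torch.tensor on a list of Python ints yields int64, which matches the targets.dtype shown in the example. Whether the library uses exactly these calls is an assumption.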

trojanzoo.utils.data.sample_batch(dataset, batch_size=None, idx=[])

Sample a batch from dataset by calling dataset_to_tensor(torch.utils.data.Subset(dataset, idx)).

Parameters:
  • dataset (torch.utils.data.Dataset) – The dataset to sample.

  • batch_size (int) – The batch size to sample when idx is empty. Defaults to None.

  • idx (Sequence[int]) – The indices of the samples to take from dataset. If empty, a batch of batch_size samples is drawn at random instead. Defaults to [].

Returns:

(torch.Tensor, torch.Tensor) – The tuple of sampled batch (data, targets).

Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset, sample_batch
>>>
>>> data = torch.ones(10, 3, 32, 32)
>>> targets = list(range(10))
>>> dataset = TensorListDataset(data, targets)
>>> x, y = sample_batch(dataset, idx=[1, 2])
>>> x.shape
torch.Size([2, 3, 32, 32])
>>> y
tensor([1, 2])
>>> x, y = sample_batch(dataset, batch_size=4)
>>> y  # randomly sampled; the output varies between calls
tensor([6, 3, 2, 5])
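The selection logic described above (use idx when given, otherwise draw batch_size random indices) can be sketched in plain Python. sample_batch_sketch and its rng parameter are hypothetical, and sampling without replacement via random.sample is an assumption; the library may draw indices differently.

```python
import random
from typing import Any, Optional, Sequence


def sample_batch_sketch(
    dataset: Sequence[Any],
    batch_size: Optional[int] = None,
    idx: Optional[Sequence[int]] = None,
    rng: Optional[random.Random] = None,
) -> tuple[list[Any], list[int]]:
    """Hypothetical sketch of sample_batch's index selection.

    Uses None instead of the mutable default [] from the documented signature.
    """
    if not idx:
        if batch_size is None:
            raise ValueError("batch_size is required when idx is empty")
        rng = rng or random.Random()
        # Assumption: draw batch_size distinct indices (no replacement).
        idx = rng.sample(range(len(dataset)), batch_size)
    # Same effect as traversing Subset(dataset, idx): visit only chosen indices.
    pairs = [dataset[i] for i in idx]
    data = [x for x, _ in pairs]
    targets = [y for _, y in pairs]
    return data, targets


dataset = [(f"x{i}", i) for i in range(10)]
x, y = sample_batch_sketch(dataset, idx=[1, 2])
print(x, y)  # ['x1', 'x2'] [1, 2]
x, y = sample_batch_sketch(dataset, batch_size=4, rng=random.Random(0))
print(len(y), len(set(y)))  # 4 4
```

Passing an explicit random.Random makes the random branch reproducible in tests, while the default keeps calls independent of each other.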

trojanzoo.utils.data.split_dataset(dataset, length=None, percent=None, shuffle=True, seed=None)

Split a dataset into two subsets.

Parameters:
  • dataset (torch.utils.data.Dataset) – The dataset to split.

  • length (int) – The length of the first subset. This argument cannot be used together with percent. If None, use percent to calculate length instead. Defaults to None.

  • percent (float) – The split ratio for the first subset. This argument cannot be used together with length. If given, length = int(percent * len(dataset)). Defaults to None.

  • shuffle (bool) – Whether to shuffle the dataset. Defaults to True.

  • seed (int) – The random seed used to shuffle the dataset with numpy.random.shuffle. Defaults to None.

Returns:

(torch.utils.data.Subset, torch.utils.data.Subset) – The two split subsets.

Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset, split_dataset
>>>
>>> data = torch.ones(11, 3, 32, 32)
>>> targets = list(range(11))
>>> dataset = TensorListDataset(data, targets)
>>> set1, set2 = split_dataset(dataset, length=3)
>>> len(set1), len(set2)
(3, 8)
>>> set3, set4 = split_dataset(dataset, percent=0.5)
>>> len(set3), len(set4)
(5, 6)

Note

This function implements trojanzoo.datasets.Dataset.split_dataset(). The difference is that, when seed is None, this function does NOT default it to env['data_seed'].
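The splitting rule of split_dataset() can be sketched as an index computation: derive length (truncating percent * len(dataset), which matches the (5, 6) split of 11 samples in the example), optionally shuffle the indices, and cut the list in two. split_indices is a hypothetical helper, and it swaps numpy.random.shuffle for Python's random.Random(seed).shuffle, so the same seed will not reproduce the library's exact permutation.

```python
import random
from typing import Optional


def split_indices(
    n: int,
    length: Optional[int] = None,
    percent: Optional[float] = None,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Hypothetical sketch: choose the indices of the two subsets of a
    dataset with n samples."""
    if length is not None and percent is not None:
        raise ValueError("length and percent cannot be used together")
    if length is None:
        if percent is None:
            raise ValueError("either length or percent is required")
        # int() truncates: int(0.5 * 11) == 5.
        length = int(percent * n)
    indices = list(range(n))
    if shuffle:
        # Stand-in for numpy.random.shuffle; seed=None uses fresh entropy.
        random.Random(seed).shuffle(indices)
    return indices[:length], indices[length:]


first, second = split_indices(11, length=3, seed=0)
print(len(first), len(second))  # 3 8
first, second = split_indices(11, percent=0.5, seed=0)
print(len(first), len(second))  # 5 6
```

Passing each index list to torch.utils.data.Subset(dataset, indices) would give the two subsets. In this sketch a fixed seed makes the split reproducible, while seed=None seeds from system randomness, so each call shuffles differently.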

trojanzoo.utils.data.get_class_subset(dataset, class_list)

Get the subset of dataset that contains only the samples whose labels are in class_list.

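Parameters:
  • dataset (torch.utils.data.Dataset) – The dataset.

  • class_list (list[int]) – The classes to keep.

Returns: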

torch.utils.data.Subset – The subset with labels in class_list.

Example:
>>> import torch
>>> from trojanzoo.utils.data import get_class_subset, TensorListDataset
>>>
>>> data = torch.ones(11, 3, 32, 32)
>>> targets = list(range(11))
>>> dataset = TensorListDataset(data, targets)
>>> subset = get_class_subset(dataset, class_list=[2, 3])
>>> len(subset)
2
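Selecting the samples of certain classes reduces to collecting the indices whose label is in class_list; wrapping those indices in torch.utils.data.Subset would then yield a subset like the one returned. class_subset_indices is a hypothetical helper that assumes the labels are already available as a sequence of ints (for example, the targets passed to a TensorListDataset).

```python
from typing import Sequence


def class_subset_indices(targets: Sequence[int], class_list: Sequence[int]) -> list[int]:
    """Hypothetical sketch: indices of samples whose label is in class_list."""
    wanted = set(class_list)  # set gives O(1) membership tests
    return [i for i, label in enumerate(targets) if label in wanted]


targets = [0, 1, 2, 3, 2, 5, 3]
print(class_subset_indices(targets, [2, 3]))  # [2, 3, 4, 6]
```

With the 11 distinct labels from the example above, class_list=[2, 3] selects exactly two indices, matching len(subset) == 2.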
