data
- class trojanzoo.utils.data.TensorListDataset(data=None, targets=None, **kwargs)
The dataset class that has a torch.Tensor as inputs and list[int] as labels. It inherits torch.utils.data.Dataset.
- Parameters:
data (torch.Tensor) – The inputs.
targets (list[int]) – The labels.
**kwargs – Keyword arguments passed to torch.utils.data.Dataset.
- Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset
>>>
>>> data = torch.ones(10, 3, 32, 32)
>>> targets = list(range(10))
>>> dataset = TensorListDataset(data, targets)
>>> x, y = dataset[3]
>>> x.shape
torch.Size([3, 32, 32])
>>> y
3
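To show what such a class amounts to, here is a minimal sketch with the same behaviour as the example above. It assumes that indexing simply returns (data[index], targets[index]); the class name MinimalTensorListDataset is hypothetical and this is not the trojanzoo source.

```python
import torch
from torch.utils.data import Dataset


class MinimalTensorListDataset(Dataset):
    """Hypothetical stand-in for TensorListDataset: a tensor of inputs
    paired with a plain Python list of integer labels."""

    def __init__(self, data: torch.Tensor, targets: list):
        super().__init__()
        # Both sequences must describe the same number of samples.
        assert len(data) == len(targets)
        self.data = data
        self.targets = targets

    def __getitem__(self, index: int):
        # Return the input tensor and its integer label at the same position.
        return self.data[index], self.targets[index]

    def __len__(self) -> int:
        return len(self.targets)


dataset = MinimalTensorListDataset(torch.ones(10, 3, 32, 32), list(range(10)))
x, y = dataset[3]
print(x.shape, y)  # torch.Size([3, 32, 32]) 3
```

Because the labels stay a Python list, an item's label comes back as a plain int rather than a tensor.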
- trojanzoo.utils.data.dataset_to_tensor(dataset)
Transform a torch.utils.data.Dataset into a (data, targets) tensor tuple by traversing all elements.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset.
- Returns:
(torch.Tensor, torch.Tensor) – The tuple of (data, targets).
- Example:
>>> import torch
>>> from torchvision.datasets import MNIST
>>> import torchvision.transforms as transforms
>>> from trojanzoo.utils.data import dataset_to_tensor
>>>
>>> transform = transforms.Compose([
...     transforms.PILToTensor(),
...     transforms.ConvertImageDtype(torch.float)])
>>> dataset = MNIST('./', train=False, download=True, transform=transform)
>>> data, targets = dataset_to_tensor(dataset)
>>> data.shape
torch.Size([10000, 1, 28, 28])
>>> targets.shape
torch.Size([10000])
>>> targets.dtype
torch.int64
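The traversal described above can be sketched in plain PyTorch. This is an illustration under assumptions, not the library code: the helper name stack_dataset is hypothetical, every item is assumed to be an (input tensor, integer-like label) pair with inputs of equal shape, and torch.utils.data.TensorDataset stands in for MNIST so that nothing needs downloading.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


def stack_dataset(dataset: Dataset):
    """Hypothetical helper: visit every (input, label) pair and stack them."""
    inputs, labels = [], []
    for i in range(len(dataset)):
        x, y = dataset[i]
        inputs.append(x)
        # int() accepts both Python ints and 0-d label tensors.
        labels.append(int(y))
    # torch.stack adds a leading sample dimension; labels become int64.
    data = torch.stack(inputs)
    targets = torch.as_tensor(labels, dtype=torch.long)
    return data, targets


dataset = TensorDataset(torch.rand(5, 1, 28, 28), torch.arange(5))
data, targets = stack_dataset(dataset)
print(data.shape, targets.dtype)  # torch.Size([5, 1, 28, 28]) torch.int64
```

Traversing by index keeps the sketch independent of whether the dataset defines __iter__; the cost is that every sample is loaded into memory at once.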
- trojanzoo.utils.data.sample_batch(dataset, batch_size=None, idx=[])
Sample a batch from dataset by calling dataset_to_tensor(torch.utils.data.Subset(dataset, idx)).
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to sample.
batch_size (int) – The batch size to sample when idx is empty. Defaults to None.
idx (Sequence[int]) – The index list of each sample in dataset. If empty, randomly sample a batch with the given batch_size. Defaults to [].
- Returns:
(torch.Tensor, torch.Tensor) – The tuple of the sampled batch (data, targets).
- Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset, sample_batch
>>>
>>> data = torch.ones(10, 3, 32, 32)
>>> targets = list(range(10))
>>> dataset = TensorListDataset(data, targets)
>>> x, y = sample_batch(dataset, idx=[1, 2])
>>> x.shape
torch.Size([2, 3, 32, 32])
>>> y
tensor([1, 2])
>>> x, y = sample_batch(dataset, batch_size=4)  # indices drawn at random
>>> y  # output varies between runs
tensor([6, 3, 2, 5])
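A rough sketch of the documented sampling rule follows. The function name sample_batch_sketch is hypothetical; drawing distinct indices with torch.randperm is an assumption about how "randomly sample" could be done, and the stacking step inlines a simple equivalent of dataset_to_tensor(). The default idx=() replaces the mutable [] default.

```python
from typing import Optional, Sequence

import torch
from torch.utils.data import Dataset, Subset, TensorDataset


def sample_batch_sketch(dataset: Dataset, batch_size: Optional[int] = None,
                        idx: Sequence[int] = ()):
    """Hypothetical re-implementation of the documented behaviour."""
    if len(idx) == 0:
        # No explicit indices: draw `batch_size` distinct random indices.
        if batch_size is None:
            raise ValueError("either batch_size or idx must be given")
        idx = torch.randperm(len(dataset))[:batch_size].tolist()
    subset = Subset(dataset, idx)
    # Simple stand-in for dataset_to_tensor(subset): traverse and stack.
    xs, ys = zip(*(subset[i] for i in range(len(subset))))
    return torch.stack(xs), torch.as_tensor([int(y) for y in ys])


dataset = TensorDataset(torch.ones(10, 3, 32, 32), torch.arange(10))
x, y = sample_batch_sketch(dataset, idx=[1, 2])
print(x.shape, y)  # torch.Size([2, 3, 32, 32]) tensor([1, 2])
x, y = sample_batch_sketch(dataset, batch_size=4)
print(y)  # four distinct labels, different on each run
```

Explicit indices always take precedence, so batch_size only matters when idx is empty.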
- trojanzoo.utils.data.split_dataset(dataset, length=None, percent=None, shuffle=True, seed=None)
Split a dataset into two subsets.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to split.
length (int) – The length of the first subset. This argument cannot be used together with percent. If None, percent is used to calculate length instead. Defaults to None.
percent (float) – The split ratio for the first subset. This argument cannot be used together with length. If given, length = int(percent * len(dataset)). Defaults to None.
shuffle (bool) – Whether to shuffle the dataset. Defaults to True.
seed (int) – The random seed used to split the dataset with numpy.random.shuffle. Defaults to None.
- Returns:
(torch.utils.data.Subset, torch.utils.data.Subset) – The two split subsets.
- Example:
>>> import torch
>>> from trojanzoo.utils.data import TensorListDataset, split_dataset
>>>
>>> data = torch.ones(11, 3, 32, 32)
>>> targets = list(range(11))
>>> dataset = TensorListDataset(data, targets)
>>> set1, set2 = split_dataset(dataset, length=3)
>>> len(set1), len(set2)
(3, 8)
>>> set3, set4 = split_dataset(dataset, percent=0.5)
>>> len(set3), len(set4)
(5, 6)
Note
This is the underlying implementation of trojanzoo.datasets.Dataset.split_dataset(). The difference is that this function does NOT set seed to env['data_seed'] when seed is None.
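The splitting rule can be illustrated with a short sketch. The function name split_dataset_sketch is hypothetical. It assumes the first subset takes the first length (shuffled) indices and the second subset takes the rest, and it raises an error unless exactly one of length and percent is given. It swaps the documented numpy.random.shuffle for a per-call numpy.random.RandomState(seed).shuffle, so the global NumPy seed is left untouched.

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset


def split_dataset_sketch(dataset, length=None, percent=None, shuffle=True, seed=None):
    """Hypothetical re-implementation of the documented splitting rule."""
    # In this sketch, exactly one of `length` and `percent` must be given.
    if (length is None) == (percent is None):
        raise ValueError("pass exactly one of length or percent")
    if length is None:
        length = int(percent * len(dataset))  # truncates: int(0.5 * 11) == 5
    indices = np.arange(len(dataset))
    if shuffle:
        # A dedicated RandomState leaves the global NumPy random state alone.
        np.random.RandomState(seed).shuffle(indices)
    first = Subset(dataset, indices[:length].tolist())
    second = Subset(dataset, indices[length:].tolist())
    return first, second


dataset = TensorDataset(torch.ones(11, 3, 32, 32), torch.arange(11))
set1, set2 = split_dataset_sketch(dataset, length=3)
print(len(set1), len(set2))  # 3 8
set3, set4 = split_dataset_sketch(dataset, percent=0.5, seed=0)
print(len(set3), len(set4))  # 5 6
```

Passing the same seed reproduces the same split, which is the property the seed argument exists for.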
- trojanzoo.utils.data.get_class_subset(dataset, class_list)
Get a subset from dataset with certain classes.
- Parameters:
dataset (torch.utils.data.Dataset) – The entire dataset.
class_list (list[int]) – The classes to keep.
- Returns:
torch.utils.data.Subset – The subset with labels in class_list.
- Example:
>>> import torch
>>> from trojanzoo.utils.data import get_class_subset, TensorListDataset
>>>
>>> data = torch.ones(11, 3, 32, 32)
>>> targets = list(range(11))
>>> dataset = TensorListDataset(data, targets)
>>> subset = get_class_subset(dataset, class_list=[2, 3])
>>> len(subset)
2
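A minimal sketch of class filtering, assuming each item is an (input, label) pair: collect the indices whose label is in class_list and wrap them in torch.utils.data.Subset. The name class_subset_sketch is hypothetical, and reading every full sample is a simplification; a real implementation might read a stored targets attribute instead.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def class_subset_sketch(dataset, class_list):
    """Hypothetical helper: keep only samples whose label is in class_list."""
    wanted = set(class_list)
    # Traverse once and record the indices of samples with a wanted label.
    indices = [i for i in range(len(dataset)) if int(dataset[i][1]) in wanted]
    return Subset(dataset, indices)


dataset = TensorDataset(torch.ones(11, 3, 32, 32), torch.arange(11))
subset = class_subset_sketch(dataset, class_list=[2, 3])
print(len(subset), subset.indices)  # 2 [2, 3]
```

Converting class_list to a set makes each membership test constant-time, which matters when many classes are kept.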