vit_dataset.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import torch
  17. from torchvision import datasets, transforms
  18. from megatron.data.autoaugment import ImageNetPolicy
  19. def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
  20. # training dataset
  21. train_data_path = os.path.join(data_path[0], "train")
  22. normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  23. process = [
  24. transforms.RandomResizedCrop(crop_size),
  25. transforms.RandomHorizontalFlip(),
  26. ]
  27. if color_jitter:
  28. process += [
  29. transforms.ColorJitter(
  30. brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
  31. )
  32. ]
  33. fp16_t = transforms.ConvertImageDtype(torch.half)
  34. process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
  35. transform_train = transforms.Compose(process)
  36. train_data = datasets.ImageFolder(
  37. root=train_data_path, transform=transform_train
  38. )
  39. # validation dataset
  40. val_data_path = os.path.join(data_path[0], "val")
  41. transform_val = transforms.Compose(
  42. [
  43. transforms.Resize(crop_size),
  44. transforms.CenterCrop(crop_size),
  45. transforms.ToTensor(),
  46. normalize,
  47. fp16_t
  48. ]
  49. )
  50. val_data = datasets.ImageFolder(
  51. root=val_data_path, transform=transform_val
  52. )
  53. return train_data, val_data