1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import torch
- from torchvision import datasets, transforms
- from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
    """Build the training and validation ``ImageFolder`` datasets.

    Args:
        data_path: Either a sequence whose first element is the dataset root
            directory, or a single ``str`` / ``os.PathLike`` root. The root
            must contain ``train`` and ``val`` sub-directories.
        crop_size: Side length passed to the crop/resize transforms.
        color_jitter: When true, a ``ColorJitter`` step is added to the
            training pipeline (never to the validation pipeline).

    Returns:
        A ``(train_data, val_data)`` tuple of ``datasets.ImageFolder`` objects.
    """
    # A bare string is indexable too, so ``data_path[0]`` on it would silently
    # yield its first character; treat single paths as the root itself.
    if isinstance(data_path, (str, os.PathLike)):
        root_dir = data_path
    else:
        root_dir = data_path[0]

    # Shared by both pipelines: per-channel normalization with mean/std 0.5,
    # followed by a conversion of the tensor to half precision.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    fp16_t = transforms.ConvertImageDtype(torch.half)

    # training dataset
    train_data_path = os.path.join(root_dir, "train")
    process = [
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
    ]
    if color_jitter:
        process.append(
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
            )
        )
    # The augmentation policy runs before ToTensor; normalization and the
    # dtype conversion run on the resulting tensor.
    process.extend([ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t])
    transform_train = transforms.Compose(process)
    train_data = datasets.ImageFolder(
        root=train_data_path, transform=transform_train
    )

    # validation dataset: deterministic resize + center crop, no augmentation
    val_data_path = os.path.join(root_dir, "val")
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
            fp16_t,
        ]
    )
    val_data = datasets.ImageFolder(
        root=val_data_path, transform=transform_val
    )

    return train_data, val_data
|