1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
| import os import torch.utils.data as data import torch import torch.optim as optim import torch.nn as nn from torch.optim import lr_scheduler from torchvision import datasets, models, transforms from PIL import Image import time import copy import pandas as pd import matplotlib.pyplot as plt import numpy as np %matplotlib inline
# --- Run configuration ---------------------------------------------------------
NUM_EPOCH = 5      # presumably the number of training epochs — confirm against the training loop
batch_size = 100   # mini-batch size used by both the 'Train' and 'Test' DataLoaders
# Prefer the first GPU, but fall back to the CPU so the script can also run on
# machines without CUDA (the original hard-coded 'cuda:0').
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
NUMCLASS = 4       # presumably the number of target classes — confirm
def default_loader(path):
    """Load the image stored at *path* and return an RGB copy of it.

    Both the underlying file handle and the PIL image opened from it are
    closed before returning; the caller receives the copy produced by
    ``convert('RGB')``.
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        rgb_copy = picture.convert('RGB')
    return rgb_copy
class CustomImageLoader(data.Dataset):
    """Dataset of ``(image, label)`` pairs read from a whitespace-separated list file.

    Each non-blank line of *txt_path* must contain an image path followed by an
    integer label, e.g. ``/some/dir/img001.jpg 2``. A single leading ``/`` is
    stripped from the path before it is joined onto *img_path*.

    Args:
        img_path: Directory prepended (via ``os.path.join``) to every listed path.
        txt_path: Path of the list file.
        dataset: Key used to pick a transform when *data_transforms* is a
            mapping (e.g. ``'Train'`` / ``'Test'``).
        data_transforms: A mapping from *dataset* keys to transforms, a single
            callable transform, or ``None`` to return images untransformed.
        loader: Callable turning an image path into an image object.

    Raises:
        ValueError: If a non-blank line of the list file has fewer than two fields.
    """

    def __init__(self, img_path, txt_path, dataset='', data_transforms=None,
                 loader=default_loader):
        im_list = []
        im_labels = []
        with open(txt_path, 'r') as files:
            for lineno, line in enumerate(files, start=1):
                items = line.split()
                if not items:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                if len(items) < 2:
                    raise ValueError(
                        "{}:{}: expected '<path> <label>', got {!r}".format(
                            txt_path, lineno, line))
                imname = items[0]
                if imname.startswith('/'):
                    # os.path.join discards img_path when the second part is
                    # absolute, so drop the leading slash first.
                    imname = imname[1:]
                im_list.append(os.path.join(img_path, imname))
                im_labels.append(int(items[1]))
        self.imgs = im_list
        self.labels = im_labels
        # Attribute name (sic) kept unchanged for existing callers.
        self.data_tranforms = data_transforms
        self.loader = loader
        self.dataset = dataset

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        """Return ``(transformed_image, label)`` for index *item*.

        Raises:
            KeyError: If *data_transforms* is a mapping without ``self.dataset``.
            RuntimeError: If applying the transform fails; the original error is
                chained as ``__cause__``.
        """
        img_name = self.imgs[item]
        label = self.labels[item]
        img = self.loader(img_name)
        transform = self.data_tranforms
        if transform is not None:
            if not callable(transform):
                # A mapping of per-split transforms: select this split's entry.
                transform = transform[self.dataset]
            try:
                img = transform(img)
            except Exception as err:
                # Previously the error was printed and the raw image returned,
                # which only failed later (and obscurely) during batch collation.
                raise RuntimeError(
                    "Cannot transform image: {}".format(img_name)) from err
        return img, label


# Per-split preprocessing pipelines, selected by CustomImageLoader's `dataset`
# argument. Training uses random crop/flip augmentation; testing uses a
# deterministic resize + center crop. The Normalize mean/std values are
# presumably the ImageNet channel statistics used by torchvision pretrained
# models — confirm.
data_tranforms = {
    'Train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'Test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
# One dataset per split. List files are read from /tmp/<split>Images.label and
# their image paths are resolved against '/'.
image_datasets = {
    x: CustomImageLoader('/', txt_path=os.path.join('/tmp', x + 'Images.label'),
                         data_transforms=data_tranforms, dataset=x)
    for x in ['Train', 'Test']
}
# Only the training split needs a random order each epoch; iterating the Test
# split in a fixed order keeps evaluation deterministic and reproducible.
dataloders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                   shuffle=(x == 'Train'))
    for x in ['Train', 'Test']
}
# Number of samples in each split, keyed by split name ('Train' / 'Test').
dataset_sizes = {
    split: len(image_datasets[split])
    for split in ('Train', 'Test')
}
|