Removing missing and empty data with a custom PyTorch Dataset and DataLoader

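1. Problem description

I previously wrote a post, "pytorch Dataset, DataLoader: generating custom training data", but it left one problem open: no data cleaning could be done inside the Dataset. If the data passed to the Dataset is itself faulty, iteration is bound to fail.

For example, suppose I pass a list of image paths to the Dataset. If every path is correct and every image exists and is undamaged, everything runs fine.

But if some of the paths point to images that do not exist, then reading the images in the Dataset and returning them during iteration produces an error like this: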

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp>
    return [default_collate(samples) for samples in transposed]
File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate
    raise TypeError((error_msg_fmt.format(type(batch[0]))))
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>
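
The failure can be reproduced without any image files. The sketch below is only an illustration: the sample values are made up, the helper name try_default_collate is hypothetical, and default_collate is called directly on a hand-built batch instead of going through a DataLoader. The exact wording of the error message depends on the PyTorch version.

```python
import torch
from torch.utils.data.dataloader import default_collate


def try_default_collate(batch):
    """Return the collated batch, or the error text if collation fails."""
    try:
        return default_collate(batch)
    except TypeError as e:
        return "TypeError: {}".format(e)


# Hypothetical samples in the (image, image_id) layout used in this post.
# The first "image" is None, as if the file could not be read.
bad_batch = [(None, "ddd.jpg"), (torch.zeros(3, 4, 4), "1.jpg")]
result = try_default_collate(bad_batch)
print(result)
```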

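2. The usual solution

The usual solution is simple and blunt: clean the data before passing it to the Dataset, removing missing images and corrupted data up front.

Yes, this is the simplest and most direct approach.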
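
As a rough sketch of that up-front cleaning, the helper below (a hypothetical name, not from the original post) keeps only the image ids whose file exists and is not empty. It uses the standard library only, so it does not detect images that are corrupted in other ways; for that, each file would have to be decoded.

```python
import os


def keep_readable_files(image_dir, image_ids):
    """Keep only the ids whose file exists under image_dir and is not empty."""
    kept = []
    for image_id in image_ids:
        path = os.path.join(image_dir, image_id)
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            kept.append(image_id)
    return kept
```

The cleaned list can then be passed to the Dataset in place of the original one.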

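3. Another solution: a custom rule for assembling batches, the collate_fn() function

We want the Dataset to process whatever it is given and to return None when a sample does not exist or is otherwise invalid. The DataLoader then drops every sample that is None.

This guarantees that every batch the DataLoader yields during iteration is valid.

For example, when reading image data with batch_size=5, if one (or more) of the images does not exist, the returned batch should have the missing data filtered out, i.e. a batch of size 5-1=4.

That is exactly the feature I want to implement: the returned batch is automatically cleaned of invalid data.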
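
The intended behaviour can be sketched with a toy dataset before looking at real images. Everything here is an assumption made for illustration: ToyDataset and skip_none_collate are hypothetical names, the tensors are dummy data, and the filtered batch is handed to PyTorch's own default_collate rather than to a modified copy of it.

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate


class ToyDataset(Dataset):
    """Hypothetical dataset that returns (None, id) for ids marked as bad."""

    def __init__(self, ids, bad_ids):
        self.ids = ids
        self.bad_ids = set(bad_ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        sample_id = self.ids[i]
        if sample_id in self.bad_ids:
            return None, sample_id  # stands in for an unreadable image
        return torch.full((3, 4, 4), float(i)), sample_id


def skip_none_collate(batch):
    """Drop samples whose first field is None, then collate the rest."""
    batch = [(x, sample_id) for (x, sample_id) in batch if x is not None]
    if not batch:
        return None, None
    return default_collate(batch)


ids = ["{}.jpg".format(i) for i in range(10)]
loader = DataLoader(ToyDataset(ids, bad_ids=["1.jpg", "2.jpg"]),
                    batch_size=5, shuffle=False, collate_fn=skip_none_collate)
sizes = [0 if images is None else images.shape[0] for images, _ in loader]
print(sizes)  # [3, 5]
```

The first batch loses its two bad samples and shrinks from 5 to 3, while the second batch is unaffected.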

3.1 PyTorch data handling classes: Dataset and DataLoader

PyTorch provides two data handling classes, Dataset and DataLoader:

from torch.utils.data import Dataset, DataLoader

Dataset defines how the data is read and preprocessed, while DataLoader loads the data and produces the training batches.

Parameters of torch.utils.data.DataLoader:

DataLoader(object) accepts the following parameters:

1. dataset (Dataset): the dataset to load from

2. batch_size (int, optional): number of samples per batch

3. shuffle (bool, optional): reshuffle the data at the start of every epoch

4. sampler (Sampler, optional): custom strategy for drawing samples from the dataset; if it is specified, shuffle must be False

5. batch_sampler (Sampler, optional): like sampler, but returns the indices of one whole batch at a time. Once it is specified, batch_size, shuffle, sampler and drop_last must not be set (they are mutually exclusive)

6. num_workers (int, optional): number of subprocesses used for data loading. 0 means all data is loaded in the main process (default: 0)

7. collate_fn (callable, optional): function that merges a list of samples into a mini-batch

8. pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them

9. drop_last (bool, optional): concerns the last, incomplete batch. If True, then with batch_size=64 and 100 samples per epoch, the trailing 36 samples are dropped during training. If False (default), execution continues normally and the last batch is simply smaller

10. timeout (numeric, optional): if positive, the time to wait for a batch to be collected from the worker processes; if no batch arrives within that time, the DataLoader raises an error instead of waiting any longer. The value must always be greater than or equal to 0 (default: 0)

11. worker_init_fn (callable, optional): per-worker initialization function. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None)
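
The drop_last arithmetic above (batch_size=64 with 100 samples) can be checked directly. This is a small sketch on dummy data; the TensorDataset of 100 integers and the helper name batch_lengths are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))  # 100 dummy samples


def batch_lengths(drop_last):
    """Return the size of every batch produced in one epoch."""
    loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=drop_last)
    return [len(batch[0]) for batch in loader]


print(batch_lengths(drop_last=False))  # [64, 36]
print(batch_lengths(drop_last=True))   # [64]
```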

The one we need here is the collate_fn() callback.

3.2 A custom collate_fn() function

The collate_fn argument of torch.utils.data.DataLoader determines how samples are merged into a batch. The default is the default_collate function, but when the batch contains None or similar data, default_collate raises an error. We therefore need a custom collate_fn() function.
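
To make "merging samples into a batch" concrete, the sketch below calls default_collate by hand on two made-up (image, image_id) samples. The tensors and file names are placeholders, not data from this post.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two hypothetical samples in the (image, image_id) layout used in this post.
samples = [
    (torch.zeros(3, 2, 2), "a.jpg"),
    (torch.ones(3, 2, 2), "b.jpg"),
]
images, image_ids = default_collate(samples)
print(images.shape)     # torch.Size([2, 3, 2, 2])
print(list(image_ids))  # ['a.jpg', 'b.jpg']
```

The list of per-sample tuples is transposed into one stacked tensor for the images and one sequence for the ids. A None among the images cannot be stacked, which is where the error comes from.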

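The method is simple: add the following lines to the original default_collate function. They check whether image is None and, if so, remove that sample from the batch, which avoids the error during iteration.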

    # Added: if image is None, drop the sample from the batch so that iteration does not fail
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)

dataset_collate.py:

# -*- coding: utf-8 -*-
"""
    @Project: pytorch-learning-tutorials
    @File   : dataset_collate.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-06-07 17:09:13

    Contains definitions of the methods used by the _DataLoaderIter workers to
    collate samples fetched from dataset into Tensor(s).
    These **needs** to be in global scope since Py2 doesn't support serializing
    static methods.
"""
import collections.abc as container_abcs
import re

import torch

# The original code imported these names from torch._six, which newer PyTorch
# releases no longer provide; these are the Python 3 equivalents.
string_classes = (str, bytes)
int_classes = int

_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def collate_fn(batch):
    '''
    collate_fn (callable, optional): merges a list of samples to form a mini-batch.
    Based on torch's default_collate, the DataLoader's default collate function,
    which raises an error when the batch contains None.
    One fix: drop every sample whose image is None before collating, so that
    iteration does not fail.
    :param batch: list of (image, image_id) samples
    :return: the collated batch, or (None, None) if no valid sample is left
    '''
    # Added: if image is None, drop the sample from the batch. The filter runs
    # only here, at the top level; the recursive calls in _collate() receive
    # lists of tensors or values that must not be unpacked as pairs.
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)
    return _collate(batch)


def _collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))

            return _collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: _collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [_collate(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))

Test:

# -*- coding: utf-8 -*-
"""
    @Project: pytorch-learning-tutorials
    @File   : dataset.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-03-07 18:45:06
"""
import os

import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from utils import dataset_collate


def read_image(path, mode='RGB'):
    '''
    :param path: image path
    :param mode: RGB or L
    :return: the PIL image converted to the given mode
    '''
    return Image.open(path).convert(mode)


class TorchDataset(Dataset):
    def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
        '''
        :param image_id_list: list of image file names, e.g. imge_name.jpg
        :param image_dir: image directory; image_dir + imge_name.jpg gives the full image path
        :param resize_height: no resizing when None
        :param resize_width: no resizing when None
            PS: when only one of resize_height and resize_width is None, the aspect ratio is kept
        :param repeat: how many times all samples are repeated, 1 by default;
            None means looping without end (__len__ then reports 10000000)
        :param transform: preprocessing
        '''
        self.image_dir = image_dir
        self.image_id_list = image_id_list
        self.len = len(image_id_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width
        self.transform = transform

    def __getitem__(self, i):
        index = i % self.len
        # print("i={},index={}".format(i, index))
        image_id = self.image_id_list[index]
        image_path = os.path.join(self.image_dir, image_id)
        img = self.load_data(image_path)

        # A missing or unreadable image is reported as None; collate_fn drops it
        if img is None:
            return None, image_id
        img = self.data_preproccess(img)
        return img, image_id

    def __len__(self):
        if self.repeat is None:
            data_len = 10000000
        else:
            data_len = len(self.image_id_list) * self.repeat
        return data_len

    def load_data(self, path):
        '''
        Load the data
        :param path: image path
        :return: the image, or None if it cannot be read
        '''
        try:
            image = read_image(path)
        except Exception as e:
            image = None
            print(e)
        # image = image_processing.read_image(path)  # read the image with opencv
        return image

    def data_preproccess(self, data):
        '''
        Preprocess the data
        :param data:
        :return:
        '''
        if self.transform is not None:
            data = self.transform(data)
        return data
 
if __name__ == '__main__':

    resize_height = 224
    resize_width = 224
    image_id_list = ["1.jpg", "ddd.jpg", "111.jpg", "3.jpg", "4.jpg", "5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg"]
    image_dir = "../dataset/test_images/images"
    # Set up the preprocessing.
    # torchvision.transforms.ToTensor converts a PIL.Image or numpy.ndarray of shape (H,W,C) with pixel
    # values in [0, 255] into a torch.FloatTensor of shape (C,H,W) scaled to [0.0, 1.0].
    train_transform = transforms.Compose([
        transforms.Resize(size=(resize_height, resize_width)),
        # transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(size=(resize_height, resize_width), padding=4),  # random crop
        transforms.ToTensor(),  # shape (H,W,C) -> (C,H,W), scaled to [0.0, 1.0] as torch.FloatTensor
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize with per-channel (R,G,B) mean and standard deviation
    ])

    epoch_num = 2  # number of passes over all samples
    batch_size = 5  # number of samples per training batch
    train_data_nums = 10
    max_iterate = (train_data_nums + batch_size - 1) // batch_size * epoch_num  # total number of iterations

    train_data = TorchDataset(image_id_list=image_id_list,
                              image_dir=image_dir,
                              resize_height=resize_height,
                              resize_width=resize_width,
                              repeat=1,
                              transform=train_transform)
    # The default default_collate raises an error
    # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # Use the custom collate_fn
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=dataset_collate.collate_fn)

    # [1] Iterate by epoch, with TorchDataset created with repeat=1
    for epoch in range(epoch_num):
        for step, (batch_image, batch_label) in enumerate(train_loader):
            # Every sample in this batch was invalid
            if batch_image is None and batch_label is None:
                print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
                continue
            image = batch_image[0, :]
            image = image.numpy()  # image=np.array(image)
            image = image.transpose(1, 2, 0)  # channels from [c,h,w] to [h,w,c]
            cv2.imshow("image", image)
            cv2.waitKey(2000)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))

Explanation of the output:

With batch_size=5 and the input image list image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"], in which "ddd.jpg" and "111.jpg" do not exist, and resize_width=224, a batch would normally have shape torch.Size([5, 3, 224, 224]). Because "ddd.jpg" and "111.jpg" do not exist and are filtered out, the first batch has shape torch.Size([3, 3, 224, 224]).

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')
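
The batch sizes in this log follow from simple counting. The sketch below redoes that arithmetic for the same file list, assuming shuffle=False so that the batches are consecutive slices of the list; the helper name expected_batch_sizes is hypothetical.

```python
def expected_batch_sizes(image_ids, missing, batch_size):
    """Count the valid samples left in each consecutive batch."""
    sizes = []
    for start in range(0, len(image_ids), batch_size):
        chunk = image_ids[start:start + batch_size]
        sizes.append(sum(1 for image_id in chunk if image_id not in missing))
    return sizes


image_id_list = ["1.jpg", "ddd.jpg", "111.jpg", "3.jpg", "4.jpg",
                 "5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg"]
sizes = expected_batch_sizes(image_id_list, {"ddd.jpg", "111.jpg"}, batch_size=5)
print(sizes)  # [3, 5]
```

Both missing files fall into the first batch, so it shrinks to 3 samples while the second batch keeps all 5, and the same pattern repeats in each of the two epochs.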
