Pytorch数据读入之Dataset

1793-席同学

发表文章数:14

首页 » Python » 正文

Pytorch数据读入之Dataset

Dataset:提供一种方式,获取数据和label

from torch.utils.data import Dataset
from PIL import Image  ## 读取图片的库,可以对图片进行可视化
import os  ## 关于系统操作的库,主要用来对文件路径操作

# dir_path='dataset/hymenoptera_data/train/ants'   # 获取数据相对路径/也可以是绝对路径
# import os          ## 对数据所在文件路径进行字符串操作
# img_path_list=os.listdir(dir_path)  ## 将地址传给os.listdir会吧所有数据地址汇成一个列表,我的数据地址是dir_path='dataset/hymenoptera_data/train/ants' 

# img_path_list[0]   #  这个列表存的就是每个数据的文件名,我们取出第一个看一下
# 输出:'0013035.jpg'  # 和列表操作一样可以逐个读取



创建读取数据的class类

class MyData(Dataset):
    '''
    读数据的类需要继承Dataset这个类
    '''

    def __init__(self, root_dir, label_dir):  # 初始化主要是为这个所建的类提供全局变量
        
        '''
        :param root_dir: 数据集根目录
        :param label_dir: 数据集标签的目录,这个只填标签目录即可,我们有函数可以合并这两个目录

        ####################################
        root_dir='dataset/hymenoptera_data/train'
        label_dir='ants'
        path=os.path.join(root_dir,label_dir)  os.path.join()这个函数可以拼接两个文件路径形成一个文件路径
        输出:'dataset/hymenoptera_data/train//ants'       我的数据就是放在dataset/hymenoptera_data/train//ants里面//环境里需要用//
        '''
        
        self.root_dir = root_dir  ## self 的作用就是把定义的变量设置成全局变量
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)  ## 获取数据地址,只不过这个数据文件名称是标签
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):  # item原本是item,在这里我们使用idx
        
        # 取图片中的某一个
        '''
        img_name = img_path[idx]
        print(img_name)
        0013035.jpg
        '''

        img_name = self.img_path[idx]  ## 根据idx索引来提取索引所☞的图片的编号
        img_item_path = os.path.join(self.root_dir, self.label_dir,
                                     img_name)  # 每一个图片的位置'dataset/hymenoptera_data/train//ants//0013035.jpg'
        ## 读取图片
        img = Image.open(img_item_path)
        ## 标签
        label = self.label_dir

        return img, label

    def __len__(self):
        '''
        返回数据集长度,数据集的长度就是列表文件的长度
        :return:
        '''

        return len(self.img_path)

实例化演示

这里获取,蚂蚁的数据集

root_dir = 'dataset/hymenoptera_data/train'
ant_label_dir = 'ants'
ant_dataset = MyData(root_dir, ant_label_dir)

这里获取,蜜蜂的数据集

bees_dir = 'dataset/hymenoptera_data/train'
bees_label_dir = 'bees'
bees_dataset = MyData(bees_dir, bees_label_dir)

整个数据集应该是,ants和bees的集合而不是单独的所以我们要进行整合

train_dataset = ant_dataset + bees_dataset

未经允许不得转载:作者:1793-席同学, 转载或复制请以 超链接形式 并注明出处 拜师资源博客
原文地址:《Pytorch数据读入之Dataset》 发布于2021-10-17

分享到:
赞(0) 打赏

评论 抢沙发

评论前必须登录!

  注册



长按图片转发给朋友

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

微信扫一扫打赏

Vieu3.3主题
专业打造轻量级个人企业风格博客主题!专注于前端开发,全站响应式布局自适应模板。

登录

忘记密码 ?

您也可以使用第三方帐号快捷登录

Q Q 登 录
微 博 登 录