大家好,又见面了,我是你们的朋友全栈君。
PyTorch使用LMDB数据库加速文件读取
文章目录
- PyTorch使用LMDB数据库加速文件读取
- 背景介绍
- 具体操作
- LMDB主要类
- `lmdb.Environment`
- `lmdb.Transaction`
- `Imdb.Cursor`
- 操作流程
- 创建图像数据集
- 配合DataLoader
- LMDB主要类
- 参考链接
原始文档:https://www.yuque.com/lart/ugkv9f/hbnym1
对于数据库的了解较少,文章中大部分的介绍主要来自于各种博客和LMDB的文档,但是文档中的介绍,默认是已经了解了数据库的许多知识,这导致目前只能囫囵吞枣,待之后仔细了解后再重新补充内容。
背景介绍
文章https://blog.csdn.net/jyl1999xxxx/article/details/53942824中介绍了使用LMDB的原因:
Caffe使用LMDB来存放训练/测试用的数据集,以及使用网络提取出的feature(为了方便,以下还是统称数据集)。数据集的结构很简单,就是大量的矩阵/向量数据平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe就选择了LMDB这个简单的数据库来存放数据。
LMDB的全称是Lightning Memory-Mapped Database,闪电般的内存映射数据库。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据库管理进程,只要在访问数据的代码里引用LMDB库,访问时给文件路径即可。图像数据集归根究底从图像文件而来。引入数据库存放数据集,是为了减少IO开销。读取大量小文件的开销是非常大的,尤其是在机械硬盘上。LMDB的整个数据库放在一个文件里,避免了文件系统寻址的开销。LMDB使用内存映射的方式访问文件,使得文件内寻址的开销非常小,使用指针运算就能实现。数据库单文件还能减少数据集复制/传输过程的开销。一个几万,几十万文件的数据集,不管是直接复制,还是打包再解包,过程都无比漫长而痛苦。LMDB数据库只有一个文件,你的介质有多块,就能复制多快,不会因为文件多而慢如蜗牛。
在文章http://shuokay.com/2018/05/14/python-lmdb/中类似提到:
为什么要把图像数据转换成大的二进制文件?
简单来说,是因为读写小文件的速度太慢。那么,不禁要问,图像数据也是二进制文件,单个大的二进制文件例如 LMDB 文件也是二进制文件,为什么单个图像读写速度就慢了呢?这里分两种情况解释。
- 机械硬盘的情况:机械硬盘的每次读写启动时间比较长,例如磁头的寻道时间占比很高,因此,如果单个小文件读写,尤其是随机读写单个小文件的时候,这个寻道时间占比就会很高,最后导致大量读写小文件的时候时间会很浪费;
- NFS 的情况:在 NFS 的场景下,系统的一次读写首先要进行上百次的网络通讯,并且这个通讯次数和文件的大小无关。因此,如果是读写小文件,这个网络通讯时间占据了整个读写时间的大部分。
固态硬盘的情况下应该也会有一些类似的开销,目前没有研究过。
总而言之,使用LMDB可以为我们的数据读取进行加速。
具体操作
LMDB主要类
<code style="margin-left:0">pip install lmdb</code>
lmdb.Environment
lmdb.open()
这个方法实际上是 class lmdb.Environment(path, map_size=10485760, subdir=True, readonly=False, metasync=True, sync=True, map_async=False, mode=493, create=True, readahead=True, writemap=False, meminit=True, max_readers=126, max_dbs=0, max_spare_txns=1, lock=True)
的一个别名(shortcut),二者是等价的。关于这个类:https://lmdb.readthedocs.io/en/release/#environment-class
这是数据库环境的结构。 一个环境可能包含多个数据库,所有数据库都驻留在同一共享内存映射和基础磁盘文件中。要写入环境,必须创建事务(Transaction)。 允许同时进行一次写入事务,但是即使存在写入事务,读取事务的数量也没有限制。
几个重要的实例方法:
- begin(db=None, parent=None, write=False, buffers=False): 可以调用事务类
lmdb.Transaction
- open_db(key=None, txn=None, reverse_key=False, dupsort=False, create=True, integerkey=False, integerdup=False, dupfixed=False): 打开一个数据库,返回一个不透明的句柄。重复
Environment.open_db()
调用相同的名称将返回相同的句柄。作为一个特殊情况,主数据库总是开放的。命名数据库是通过在主数据库中存储一个特殊的描述符来实现的。环境中的所有数据库共享相同的文件。因为描述符存在于主数据库中,所以如果已经存在与数据库名称匹配的key
,创建命名数据库的尝试将失败。此外,查找和枚举可以看到key
。如果主数据库keyspace
与命名数据库使用的名称冲突,则将主数据库的内容移动到另一个命名数据库。
<code style="margin-left:0">>>> env = lmdb.open('/tmp/test', max_dbs=2) >>> with env.begin(write=True) as txn ... txn.put('somename', 'somedata') >>> # Error: database cannot share name of existing key! >>> subdb = env.open_db('somename')</code>
lmdb.Transaction
这和事务对象有关。
class lmdb.Transaction(env, db=None, parent=None, write=False, buffers=False)
。
关于这个类的参数:https://lmdb.readthedocs.io/en/release/#transaction-class
所有操作都需要事务句柄,事务可以是只读或读写的。写事务可能不会跨越线程。事务对象实现了上下文管理器协议,因此即使面对未处理的异常,也可以可靠地释放事务:
<code style="margin-left:0"># Transaction aborts correctly: with env.begin(write=True) as txn: crash() # Transaction commits automatically: with env.begin(write=True) as txn: txn.put('a', 'b')</code>
这个类的实例包含着很多有用的操作方法。
- abort(): 中止挂起的事务。重复调用
abort()
在之前成功的commit()
或abort()
后或者在相关环境关闭后是没有效果的。 - commit(): 提交挂起的事务。
- cursor(db=None): Shortcut for
lmdb.Cursor(db, self)
- delete(key, value=’’, db=None): Delete a key from the database.
- key: The key to delete.
- value:如果数据库是以
dupsort = True
打开的,并且value
不是空的bytestring
,则删除仅与此(key, value)
对匹配的元素,否则该key
的所有值都将被删除。 - Returns
True
if at least one key was deleted.
- drop(db, delete=True): 删除命名数据库中的所有键,并可选地删除命名数据库本身。删除命名数据库会导致其不可用,并使现有cursors无效。
- get(key, default=None, db=None): 获取匹配键的第一个值,如果键不存在,则返回默认值。cursor必须用于获取
dupsort = True
数据库中的key
的所有值。 - id(): 返回事务的ID。这将返回与此事务相关联的标识符。对于只读事务,这对应于正在读取的快照; 并发读取器通常具有相同的事务ID。
- pop(key, db=None): 使用临时cursor调用
Cursor.pop()
。- db: 要操作的命名数据库。如果未指定,默认为事务构造函数被给定的数据库。
- put(key, value, dupdata=True, overwrite=True, append=False, db=None): 存储一条记录(record),如果记录被写入,则返回
True
,否则返回False
,以指示key已经存在并且overwrite = False
。成功后,cursor位于新记录上。- key: Bytestring key to store.
- value: Bytestring value to store.
- dupdata: 如果
True
,并且数据库是用dupsort = True
打开的,如果给定key
已经存在,则添加键值对作为副本。否则覆盖任何现有匹配的key
。 - overwrite: If
False
, do not overwrite any existing matching key. - append: 如果为
True
,则将对附加到数据库末尾,而不首先比较其顺序。附加不大于现有最高key
的key
将导致损坏。 - db: 要操作的命名数据库。如果未指定,默认为事务构造函数被给定的数据库。
- replace(key, value, db=None): 使用临时cursor调用
Cursor.replace()
. - db: Named database to operate on. If unspecified, defaults to the database given to the Transaction constructor.
- stat(db): Return statistics like
Environment.stat()
, except for a single DBI.db
must be a database handle returned byopen_db()
.
Imdb.Cursor
class lmdb.Cursor(db, txn)
是用于在数据库中导航(navigate)的结构。
- db: Database to navigate.
- txn: Transaction to navigate.
As a convenience, Transaction.cursor()
can be used to quickly return a cursor:
<code style="margin-left:0">>>> env = lmdb.open('/tmp/foo') >>> child_db = env.open_db('child_db') >>> with env.begin() as txn: ... cursor = txn.cursor() # Cursor on main database. ... cursor2 = txn.cursor(child_db) # Cursor on child database.</code>
游标以未定位的状态开始。如果在这种状态下使用 iternext()
或 iterprev()
,那么迭代将分别从开始处和结束处开始。迭代器直接使用游标定位,这意味着在同一游标上存在多个迭代器时会产生奇怪的行为。
从Python绑定的角度来看,一旦任何扫描或查找方法(例如
next()
、prev_nodup()
、set_range()
)返回False
或引发异常,游标将返回未定位状态。这主要是为了确保在面对任何错误条件时语义的安全性和一致性。
当游标返回到未定位的状态时,它的key()
和value()
返回空字符串,表示没有活动的位置,尽管在内部,LMDB游标可能仍然有一个有效的位置。
这可能会导致在迭代dupsort=True
数据库的key
时出现一些令人吃惊的行为,因为iternext_dup()
等方法将导致游标显示为未定位,尽管它返回False
只是为了表明当前键没有更多的值。在这种情况下,简单地调用next()
将导致在下一个可用键处继续迭代。
This behaviour may change in future.
Iterator methods such as iternext()
and iterprev()
accept keys and values arguments. If both are True
, then the value of item()
is yielded on each iteration. If only keys is True
, key()
is yielded, otherwise only value()
is yielded.
在迭代之前,游标可能定位在数据库中的任何位置
<code style="margin-left:0">>>> with env.begin() as txn: ... cursor = txn.cursor() ... if not cursor.set_range('5'): # Position at first key >= '5'. ... print('Not found!') ... else: ... for key, value in cursor: # Iterate from first key >= '5'. ... print((key, value))</code>
不需要迭代来导航,有时会导致丑陋或低效的代码。在迭代顺序不明显的情况下,或者与正在读取的数据相关的情况下,使用 set_key()
、 set_range()
、 key()
、 value()
和 item()
可能是更好的选择。
<code style="margin-left:0">>>> # Record the path from a child to the root of a tree. >>> path = ['child14123'] >>> while path[-1] != 'root': ... assert cursor.set_key(path[-1]), \ ... 'Tree is broken! Path: %s' % (path,) ... path.append(cursor.value())</code>
几个实例方法:
- set_key(key): Seek exactly to key, returning
True
on success orFalse
if the exact key was not found. 对于set_key()
,空字节串是错误的。对于使用dupsort=True
打开的数据库,移动到键的第一个值(复制)。 - set_range(key): Seek to the first
key
greater than or equal tokey
, returningTrue
on success, orFalse
to indicatekey
was past end of database. Behaves likefirst()
if key is the empty bytestring. 对于使用dupsort=True
打开的数据库,移动到键的第一个值(复制)。 - get(key, default=None): Equivalent to
set_key()
, exceptvalue()
is returned when key is found, otherwise default. - item(): Return the current
(key, value)
pair. - key(): Return the current key.
- value(): Return the current value.
操作流程
概况地讲,操作LMDB的流程是:
- 通过
env = lmdb.open()
打开环境 - 通过
txn = env.begin()
建立事务 - 通过
txn.put(key, value)
进行插入和修改 - 通过
txn.delete(key)
进行删除 - 通过
txn.get(key)
进行查询 - 通过
txn.cursor()
进行遍历 - 通过
txn.commit()
提交更改
这里要注意:
put
和delete
后一定注意要commit
,不然根本没有存进去- 每一次
commit
后,需要再定义一次txn=env.begin(write=True)
来自https://github.com/kophy/py4db的代码:
<code style="margin-left:0">#!/usr/bin/env python import lmdb import os, sys def initialize(): env = lmdb.open("students"); return env; def insert(env, sid, name): txn = env.begin(write = True); txn.put(str(sid), name); txn.commit(); def delete(env, sid): txn = env.begin(write = True); txn.delete(str(sid)); txn.commit(); def update(env, sid, name): txn = env.begin(write = True); txn.put(str(sid), name); txn.commit(); def search(env, sid): txn = env.begin(); name = txn.get(str(sid)); return name; def display(env): txn = env.begin(); cur = txn.cursor(); for key, value in cur: print (key, value); env = initialize(); print "Insert 3 records." insert(env, 1, "Alice"); insert(env, 2, "Bob"); insert(env, 3, "Peter"); display(env); print "Delete the record where sid = 1." delete(env, 1); display(env); print "Update the record where sid = 3." update(env, 3, "Mark"); display(env); print "Get the name of student whose sid = 3." name = search(env, 3); print name; env.close(); os.system("rm -r students");</code>
创建图像数据集
这里主要借鉴自https://github.com/open-mmlab/mmsr/blob/master/codes/data_scripts/create_lmdb.py的代码。
改写为:
<code style="margin-left:0">import glob import os import pickle import sys import cv2 import lmdb import numpy as np from tqdm import tqdm def main(mode): proj_root = '/home/lart/coding/TIFNet' datasets_root = '/home/lart/Datasets/' lmdb_path = os.path.join(proj_root, 'datasets/ECSSD.lmdb') data_path = os.path.join(datasets_root, 'RGBSaliency', 'ECSSD/Image') if mode == 'creating': opt = { 'name': 'TrainSet', 'img_folder': data_path, 'lmdb_save_path': lmdb_path, 'commit_interval': 100, # After commit_interval images, lmdb commits 'num_workers': 8, } general_image_folder(opt) elif mode == 'testing': test_lmdb(lmdb_path, index=1) def general_image_folder(opt): """ Create lmdb for general image folders If all the images have the same resolution, it will only store one copy of resolution info. Otherwise, it will store every resolution info. """ img_folder = opt['img_folder'] lmdb_save_path = opt['lmdb_save_path'] meta_info = { 'name': opt['name']} if not lmdb_save_path.endswith('.lmdb'): raise ValueError("lmdb_save_path must end with 'lmdb'.") if os.path.exists(lmdb_save_path): print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path)) sys.exit(1) # read all the image paths to a list print('Reading image path list ...') all_img_list = sorted(glob.glob(os.path.join(img_folder, '*'))) # cache the filename, 这里的文件名必须是ascii字符 keys = [] for img_path in all_img_list: keys.append(os.path.basename(img_path)) # create lmdb environment # 估算大概的映射空间大小 data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes print('data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(all_img_list) env = lmdb.open(lmdb_save_path, map_size=data_size * 10) # map_size: # Maximum size database may grow to; used to size the memory mapping. If database grows larger # than map_size, an exception will be raised and the user must close and reopen Environment. # write data to lmdb txn = env.begin(write=True) resolutions = [] tqdm_iter = tqdm(enumerate(zip(all_img_list, keys)), total=len(all_img_list), leave=False) for idx, (path, key) in tqdm_iter: tqdm_iter.set_description('Write {}'.format(key)) key_byte = key.encode('ascii') data = cv2.imread(path, cv2.IMREAD_UNCHANGED) if data.ndim == 2: H, W = data.shape C = 1 else: H, W, C = data.shape resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W)) txn.put(key_byte, data) if (idx + 1) % opt['commit_interval'] == 0: txn.commit() # commit 之后需要再次 begin txn = env.begin(write=True) txn.commit() env.close() print('Finish writing lmdb.') # create meta information # check whether all the images are the same size assert len(keys) == len(resolutions) if len(set(resolutions)) <= 1: meta_info['resolution'] = [resolutions[0]] meta_info['keys'] = keys print('All images have the same resolution. Simplify the meta info.') else: meta_info['resolution'] = resolutions meta_info['keys'] = keys print('Not all images have the same resolution. Save meta info for each image.') pickle.dump(meta_info, open(os.path.join(lmdb_save_path, 'meta_info.pkl'), "wb")) print('Finish creating lmdb meta info.') def test_lmdb(dataroot, index=1): env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False) meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), "rb")) print('Name: ', meta_info['name']) print('Resolution: ', meta_info['resolution']) print('# keys: ', len(meta_info['keys'])) # read one image key = meta_info['keys'][index] print('Reading {} for test.'.format(key)) with env.begin(write=False) as txn: buf = txn.get(key.encode('ascii')) img_flat = np.frombuffer(buf, dtype=np.uint8) C, H, W = [int(s) for s in meta_info['resolution'][index].split('_')] img = img_flat.reshape(H, W, C) cv2.namedWindow('Test') cv2.imshow('Test', img) cv2.waitKeyEx() if __name__ == "__main__": # mode = creating or testing main(mode='creating')</code>
配合DataLoader
这里仅对训练集进行LMDB处理,测试机依旧使用的原始的读取图片的方式。
<code style="margin-left:0">import os import pickle import lmdb import numpy as np from PIL import Image from prefetch_generator import BackgroundGenerator from torch.utils.data import DataLoader, Dataset from torchvision import transforms from utils import joint_transforms def _get_paths_from_lmdb(dataroot): """get image path list from lmdb meta info""" meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) paths = meta_info['keys'] sizes = meta_info['resolution'] if len(sizes) == 1: sizes = sizes * len(paths) return paths, sizes def _read_img_lmdb(env, key, size): """read image from lmdb with key (w/ and w/o fixed size) size: (C, H, W) tuple""" with env.begin(write=False) as txn: buf = txn.get(key.encode('ascii')) img_flat = np.frombuffer(buf, dtype=np.uint8) C, H, W = size img = img_flat.reshape(H, W, C) return img def _make_dataset(root, prefix=('.jpg', '.png')): img_path = os.path.join(root, 'Image') gt_path = os.path.join(root, 'Mask') img_list = [ os.path.splitext(f)[0] for f in os.listdir(gt_path) if f.endswith(prefix[1]) ] return [(os.path.join(img_path, img_name + prefix[0]), os.path.join(gt_path, img_name + prefix[1])) for img_name in img_list] class TestImageFolder(Dataset): def __init__(self, root, in_size, prefix): self.imgs = _make_dataset(root, prefix=prefix) self.test_img_trainsform = transforms.Compose([ # 输入的如果是一个tuple,则按照数据缩放,但是如果是一个数字,则按比例缩放到短边等于该值 transforms.Resize((in_size, in_size)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def __getitem__(self, index): img_path, gt_path = self.imgs[index] img = Image.open(img_path).convert('RGB') img_name = (img_path.split(os.sep)[-1]).split('.')[0] img = self.test_img_trainsform(img) return img, img_name def __len__(self): return len(self.imgs) class TrainImageFolder(Dataset): def __init__(self, root, in_size, scale=1.5, use_bigt=False): self.use_bigt = use_bigt self.in_size = in_size self.root = root self.train_joint_transform = joint_transforms.Compose([ joint_transforms.JointResize(in_size), joint_transforms.RandomHorizontallyFlip(), joint_transforms.RandomRotate(10) ]) self.train_img_transform = transforms.Compose([ transforms.ColorJitter(0.1, 0.1, 0.1), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 处理的是Tensor ]) # ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状为 D×H×W, # 数值范围为 [0.0, 1.0] 的 torch.Tensor。 self.train_target_transform = transforms.ToTensor() self.gt_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_GT.lmdb' self.img_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_IMG.lmdb' self.paths_gt, self.sizes_gt = _get_paths_from_lmdb(self.gt_root) self.paths_img, self.sizes_img = _get_paths_from_lmdb(self.img_root) self.gt_env = lmdb.open(self.gt_root, readonly=True, lock=False, readahead=False, meminit=False) self.img_env = lmdb.open(self.img_root, readonly=True, lock=False, readahead=False, meminit=False) def __getitem__(self, index): gt_path = self.paths_gt[index] img_path = self.paths_img[index] gt_resolution = [int(s) for s in self.sizes_gt[index].split('_')] img_resolution = [int(s) for s in self.sizes_img[index].split('_')] img_gt = _read_img_lmdb(self.gt_env, gt_path, gt_resolution) img_img = _read_img_lmdb(self.img_env, img_path, img_resolution) if img_img.shape[-1] != 3: img_img = np.repeat(img_img, repeats=3, axis=-1) img_img = img_img[:, :, [2, 1, 0]] # bgr => rgb img_gt = np.squeeze(img_gt, axis=2) gt = Image.fromarray(img_gt, mode='L') img = Image.fromarray(img_img, mode='RGB') img, gt = self.train_joint_transform(img, gt) gt = self.train_target_transform(gt) img = self.train_img_transform(img) if self.use_bigt: gt = gt.ge(0.5).float() # 二值化 img_name = self.paths_img[index] return img, gt, img_name def __len__(self): return len(self.paths_img) class DataLoaderX(DataLoader): def __iter__(self): return BackgroundGenerator(super(DataLoaderX, self).__iter__())</code>
参考链接
- 文档:https://lmdb.readthedocs.io/en/release/
- http://shuokay.com/2018/05/14/python-lmdb/
- 关于LMDB的介绍:https://blog.csdn.net/jyl1999xxxx/article/details/53942824
- 代码示例:
- https://github.com/kophy/py4db
- https://www.jianshu.com/p/66496c8726a1
- https://blog.csdn.net/ayst123/article/details/44077903
- https://github.com/open-mmlab/mmsr
- torchvision中涉及到lmdb使用的一部分代码:https://github.com/pytorch/vision/blob/master/torchvision/datasets/lsun.py
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/187821.html原文链接:https://javaforall.cn
未经允许不得转载:木盒主机 » PyTorch使用LMDB数据库加速文件读取[通俗易懂]