class MAKEUP(Dataset):
def __init__(self, image_path, transform, mode, transform_mask, cls_list):
self.image_path = image_path ##图片目录
self.transform = transform ##图片预处理接口
self.mode = mode ##模式,为训练或者测试
self.transform_mask = transform_mask ##掩膜预处理接口
self.cls_list = cls_list ##分类类别,为妆造和非妆造两类
self.cls_A = cls_list[0] ##第一类:makeup
self.cls_B = cls_list[1] ##第二类:non-makeup
##设置训练相关的属性变量,包括txt文件路径,每一行的内容以及行数
for cls in self.cls_list:
setattr(self, "train_" + cls + "_list_path", os.path.join(self.image_path, "train_" + cls + ".txt"))
setattr(self, "train_" + cls + "_lines", open(getattr(self, "train_" + cls + "_list_path"), 'r').readlines())
setattr(self, "num_of_train_" + cls + "_data", len(getattr(self, "train_" + cls + "_lines")))
##设置测试相关的属性变量,包括txt文件路径,每一行的内容以及行数
for cls in self.cls_list:
setattr(self, "test_" + cls + "_list_path", os.path.join(self.image_path, "test_" + cls + ".txt"))
setattr(self, "test_" + cls + "_lines", open(getattr(self, "test_" + cls + "_list_path"), 'r').readlines())
setattr(self, "num_of_test_" + cls + "_data", len(getattr(self, "test_" + cls + "_lines")))
self.preprocess() ##对数据文件进行预处理
def preprocess(self):
## 对makeup类和non-makeup类的训练txt文件进行随机打乱操作,取得RGB和MASK文件路径
for cls in self.cls_list:
setattr(self, "train_" + cls + "_filenames", [])
setattr(self, "train_" + cls + "_mask_filenames", [])
lines = getattr(self, "train_" + cls + "_lines")
random.shuffle(lines) ##对txt文件进行shuffle
for i, line in enumerate(lines):
splits = line.split()
getattr(self, "train_" + cls + "_filenames").append(splits[0])
getattr(self, "train_" + cls + "_mask_filenames").append(splits[1])
for cls in self.cls_list:
setattr(self, "test_" + cls + "_filenames", [])
setattr(self, "test_" + cls + "_mask_filenames", [])
lines = getattr(self, "test_" + cls + "_lines")
for i, line in enumerate(lines):
splits = line.split()
getattr(self, "test_" + cls + "_filenames").append(splits[0])
getattr(self, "test_" + cls + "_mask_filenames").append(splits[1])
## 从文件路径中获取RGB图片文件和MASK掩膜文件
def __getitem__(self, index):
##训练模式,随机设置A类(makeup)和B类(non-makeup)的indexA和indexB,需要读入RGB图像和对应的掩膜图像
if self.mode == 'train':
index_A = random.randint(0, getattr(self, "num_of_train_" + self.cls_A + "_data") - 1)
index_B = random.randint(0, getattr(self, "num_of_train_" + self.cls_B + "_data") - 1)
image_A = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_filenames")[index_A])).convert("RGB") ##读取RGB
image_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_filenames")[index_B])).convert("RGB") ##读取RGB
mask_A = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_mask_filenames")[index_A])) ##读取MASK
mask_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_mask_filenames")[index_B])) ##读取MASK
## 调用transform和transform_mask处理RGB图像和MASK图像
return self.transform(image_A), self.transform(image_B), self.transform_mask(mask_A), self.transform_mask(mask_B)
##测试模式,使用输入的index变量从A类(makeup)和B类(non-makeup)中各自取出一张图做测试,不需要读入掩膜
if self.mode in ['test', 'test_all']:
image_A = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_A + "_filenames")[index // getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB")
image_B = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_B + "_filenames")[index % getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB")
## 调用transform和transform_mask处理RGB图像和MASK图像
return self.transform(image_A), self.transform(image_B)