这部分课程我不明白这段代码做了什么:for file in os.listdir(path): if(os.path.isfile(os.path.join(path,file)) and select in file): temp = scipy.io.loadmat(os.path.join(path,file)) temp = {k:v for k, v in temp.items() if k[0] != '_'} for i in range(len(temp[patch_type+"_patches"])): self.tensors.append(temp[patch_type+"_patches"][i]) self.labels.append(temp[patch_type+"_labels"][0][i])self.tensors = np.array(self.tensors)self.labels = np.array(self.labels)尤其是这一行:temp = {k:v for k, v in temp.items() if k[0] != '_'}全班如下:class Datasets(Dataset): def __init__(self,path,train,transform=None): if(train): select ="Training" patch_type = "train" else: select = "Testing" patch_type = "testing" self.tensors = [] self.labels = [] self.transform = transform for file in os.listdir(path): if(os.path.isfile(os.path.join(path,file)) and select in file): temp = scipy.io.loadmat(os.path.join(path,file)) temp = {k:v for k, v in temp.items() if k[0] != '_'} for i in range(len(temp[patch_type+"_patches"])): self.tensors.append(temp[patch_type+"_patches"][i]) self.labels.append(temp[patch_type+"_labels"][0][i]) self.tensors = np.array(self.tensors) self.labels = np.array(self.labels) def __len__(self): try: if len(self.tensors) != len(self.labels): raise Exception("Lengths of the tensor and labels list are not the same") except Exception as e: print(e.args[0]) return len(self.tensors) def __getitem__(self,idx): sample = (self.tensors[idx],self.labels[idx]) # print(self.labels) sample = (torch.from_numpy(self.tensors[idx]),torch.from_numpy(np.array(self.labels[idx])).long()) return sample #tuple containing the image patch and its corresponding label
2 回答

凤凰求蛊
TA贡献1825条经验 获得超4个赞
这是一个字典理解;在这种特殊情况下,它dict从现有的 dict创建一个新的temp,但仅适用于键k不以下划线开头的项目。该检查由if ...零件执行。
它相当于
new = {}
for k, v in temp.items():
if key[0] != '_':
new[k] = value
temp = new
或者,略有不同:
new = {}
for key, value in temp.items():
if not key.startswith('_'):
new[key] = value
temp = new
您可以看到它作为单行看起来更好一些,因为它避免了临时 dict (new; 在幕后,它仍然创建了一个无名的临时 dict )。
添加回答
举报
0/150
提交
取消