Warning: file_get_contents(/data/phpspider/zhask/data//catemap/8/python-3.x/17.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 我尝试建立一个基于ICARL的16类植物图像数据增量学习系统,但训练损失保持在0.6931左右_Python_Python 3.x_Conv Neural Network_Trainingloss - Fatal编程技术网

Python 我尝试建立一个基于ICARL的16类植物图像数据增量学习系统,但训练损失保持在0.6931左右

Python 我尝试建立一个基于ICARL的16类植物图像数据增量学习系统,但训练损失保持在0.6931左右,python,python-3.x,conv-neural-network,trainingloss,Python,Python 3.x,Conv Neural Network,Trainingloss,我是一名初学者,希望构建一个基于ICARL的增量学习系统用于植物识别,我使用NORSE软件包构建了自己的resnet实现。自定义imagefolder数据加载器根据给定的类号范围提供数据,我的数据集是16类花卉图像,每个类有2688个列车图像和768个测试图像。我试了几天来调整学习速度、批量大小和历元数,但训练损失没有太大变化 主要编码: import os #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

我是一名初学者,希望构建一个基于ICARL的增量学习系统用于植物识别,我使用NORSE软件包构建了自己的resnet实现。自定义imagefolder数据加载器根据给定的类号范围提供数据,我的数据集是16类花卉图像,每个类有2688个列车图像和768个测试图像。我试了几天来调整学习速度、批量大小和历元数,但训练损失没有太大变化

主要编码:

import os
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

#import all module
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset 
from torchvision.datasets import ImageFolder, DatasetFolder 
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import SequentialSampler
from torch.utils.data import BatchSampler
from torch.nn.parameter import Parameter
import torchvision
import numpy as np
from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils
from torchvision import transforms
from tqdm import tqdm
import struct
import glob
import random
from PIL import Image
from torch import Tensor

import torchvision.datasets as dsets
import torchvision.models as models
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch import multiprocessing
from data_loader import FIORA16, iCIFAR10
from model import iCaRLNet
from tqdm import tqdm
from tqdm.notebook import trange
from tqdm.contrib import tenumerate
import time
from memory_profiler import profile

use_cuda = True
#torch.autograd.set_detect_anomaly(True)


random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)

def show_images(images):
    N = images.shape[0]
    fig = plt.figure(figsize=(1, N))
    gs = gridspec.GridSpec(1, N)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in tenumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img)
    plt.show()

transform = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            #transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
        #transforms.Grayscale(),
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])


def run():
    multiprocessing.freeze_support()
    print('loop')

    
# Hyper Parameters
total_classes = 16
num_classes = 2

# Initialize CNN
if __name__ == '__main__':
    run()
    K = 2000 # total number of exemplars
    icarl = iCaRLNet(2048,1)
    icarl.cuda()

    
    for s in trange(0, total_classes, num_classes):
        t1_curr1 = perf_counter()
        print("training for class:",s,"~",s+num_classes-1)
        print("Load Datasets")
        train_set = FIORA16(root='data/train',
                             class_num=range(s,s+num_classes),
                             transform=transform_test)
        
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=100,
                                                   shuffle=True, num_workers=2)
                
        test_set = FIORA16(root='data/test',
                            class_num=range(s+num_classes),
                            transform=transform_test)
        
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=100,
                                                  shuffle=True, num_workers=2)
        
        print("Update representation via BackProp")
        icarl.update_representation(train_set)
        m = K // icarl.n_classes
        
        print("Reduce exemplar sets for known classes")
        icarl.reduce_exemplar_sets(m)
        
        print("Construct exemplar sets for new classes")
        for y in trange(icarl.n_known, icarl.n_classes):
            print ("Constructing exemplar set for class-%d..." %(y)),
            images = train_set.get_image_class(y)
            icarl.construct_exemplar_set(images, m, transform_test)
            print ("Done")
        
        for y, P_y in tenumerate(icarl.exemplar_sets):
            print ("Exemplar set for class-%d:" % (y), len(P_y))
            #for i in range(10):
                #show_images(P_y[i])

        icarl.n_known = icarl.n_classes
        print ("iCaRL known classes: %d" % icarl.n_known)
        
        print("Classify images by neares-means-of-exemplars")
        print("train data test")
        total = 0.0
        correct = 0.0
        for indices, images, labels in tqdm(train_loader):
            images = Variable(images).cuda()
            preds = icarl.classify(images, transform_test)
            total += labels.size(0)
            correct += (preds.data.cpu() == labels).sum()

        print('Train Accuracy: %d %%' % (100 * correct / total))
        
        print("test data test")
        total = 0.0
        correct = 0.0
        for indices, images, labels in tqdm(test_loader):
            images = Variable(images).cuda()
            preds = icarl.classify(images, transform_test)
            total += labels.size(0)
            correct += (preds.data.cpu() == labels).sum()

        print('Test Accuracy: %d %%' % (100 * correct / total))
        
        print("################################################################################")    
        
        #per batch timer
        t1_curr2 = perf_counter()
  
        print("Elapsed time:", t1_curr2, t1_curr1)

        print("Elapsed time during the whole program in seconds:",t1_curr2-t1_curr1)

        print("Elapsed time during the whole program in minutes:",(t1_curr2-t1_curr1)/60)
        
        print("################################################################################")
所使用的模型:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from PIL import Image
from torch import multiprocessing
from resnet import resnet18
from My_resnet import Norse
from tqdm import tqdm
from tqdm.notebook import trange
from tqdm.notebook import trange
from tqdm.contrib import tenumerate
from norse.torch.module.leaky_integrator import LILinearCell
import os
import random
from memory_profiler import profile

random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
#.requires_grad_()

#check device
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

# Hyper Parameters
num_epochs = 50
batch_size = 32
learning_rate = 0.002

class iCaRLNet(nn.Module):
    def __init__(self, feature_size, n_classes):
        # Network architecture
        super(iCaRLNet, self).__init__()
        self.feature_extractor = Norse()
        #self.feature_extractor = resnet18()
        self.feature_extractor.fc =\
            LILinearCell(self.feature_extractor.snn.fc.input_size, feature_size)
        #nn.Linear(self.feature_extractor.fc.in_features, feature_size)
        self.bn = nn.BatchNorm1d(feature_size, momentum=0.01)
        self.ReLU = nn.ReLU()
        self.fc = nn.Linear(feature_size, n_classes, bias=False)

        self.n_classes = n_classes
        self.n_known = 0

        # List containing exemplar_sets
        # Each exemplar_set is a np.array of N images
        # with shape (N, C, H, W)
        self.exemplar_sets = []

        # Learning method
        self.cls_loss = nn.CrossEntropyLoss()
        self.dist_loss = nn.BCELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate,
                                    weight_decay=0.00001)
        #self.optimizer = optim.SGD(self.parameters(), lr=learning_rate,
                                   #weight_decay=0.00001)

        # Means of exemplars
        self.compute_means = True
        self.exemplar_means = []

    def forward(self, x):
        #print('shape a1 ' + str(x.shape))
        x = self.feature_extractor(x)
        x = self.bn(x)
        x = self.ReLU(x)
        x = self.fc(x)
        return x

    def increment_classes(self, n):
        """Add n classes in the final fc layer"""
        in_features = self.fc.in_features
        out_features = self.fc.out_features
        weight = self.fc.weight.data

        self.fc = nn.Linear(in_features, out_features+n, bias=False)
        self.fc.weight.data[:out_features] = weight
        self.n_classes += n

    def classify(self, x, transform):
        #print("Classify images by neares-means-of-exemplars")
        """
        Args:
            x: input image batch
        Returns:
            preds: Tensor of size (batch_size,)
        """
        batch_size = x.size(0)
        self.eval()
        if self.compute_means:
            print ("Computing mean of exemplars..."),
            exemplar_means = []
            for P_y in tqdm(self.exemplar_sets):
                features = []
                # Extract feature for each exemplar in P_y
                with torch.no_grad():
                    for ex in (P_y):
                        #ex = ex.detach()
                        ex = ex.cuda()
                        #ex = ex.requires_grad_(False).cuda()
                        feature = self.feature_extractor(ex.unsqueeze(0))
                        feature = feature.squeeze()
                        feature.data = feature.data / feature.data.norm() # Normalize
                        features.append(feature)
                    features = torch.stack(features)
                    mu_y = features.mean(0).squeeze()
                    mu_y.data = mu_y.data / mu_y.data.norm() # Normalize
                    exemplar_means.append(mu_y)
            self.exemplar_means = exemplar_means
            self.compute_means = False
            print ("Done")

        exemplar_means = self.exemplar_means
        means = torch.stack(exemplar_means) # (n_classes, feature_size)
        means = torch.stack([means] * batch_size) # (batch_size, n_classes, feature_size)
        means = means.transpose(1, 2) # (batch_size, feature_size, n_classes)
        
        with torch.no_grad():
            feature = self.feature_extractor(x) # (batch_size, feature_size)
            for i in range(feature.size(0)): # Normalize
                feature.data[i] = feature.data[i] / feature.data[i].norm()
            feature = feature.unsqueeze(2) # (batch_size, feature_size, 1)
            feature = feature.expand_as(means) # (batch_size, feature_size, n_classes)

            dists = (feature - means).pow(2).sum(1).squeeze() #(batch_size, n_classes)
            _, preds = dists.min(1)

        return preds
        
    def construct_exemplar_set(self, images, m, transform):
        """Construct an exemplar set for image set

        Args:
            images: np.array containing images of a class
        """
        print("Compute and cache features for each example")
        features = []
        self.eval()
        
        for img in tqdm(images):
            x = img.detach().cuda()
            #x = img.requires_grad_(False).cuda()
            #x = transform(Image.fromarray(img))#.cuda()
            feature = self.feature_extractor(x.unsqueeze(0)).data.cpu().numpy()
            feature = feature / np.linalg.norm(feature) # Normalize
            features.append(feature[0])

        features = np.array(features)
        class_mean = np.mean(features, axis=0)
        class_mean = class_mean / np.linalg.norm(class_mean) # Normalize

        exemplar_set = []
        exemplar_features = [] # list of Variables of shape (feature_size,)
        print(m)
        
        for k in trange(m):
            S = np.sum(exemplar_features, axis=0)
            phi = features
            mu = class_mean
            mu_p = 1.0/(k+1) * (phi + S)
            mu_p = mu_p / np.linalg.norm(mu_p)
            i = np.argmin(np.sqrt(np.sum((mu - mu_p) ** 2, axis=1)))

            exemplar_set.append(images[i])
            exemplar_features.append(features[i])
            """
            print "Selected example", i
            print "|exemplar_mean - class_mean|:",
            print np.linalg.norm((np.mean(exemplar_features, axis=0) - class_mean))
            #features = np.delete(features, i, axis=0)
            """
                
        #self.exemplar_sets.append(np.array(exemplar_set))
        self.exemplar_sets.append(exemplar_set)
                

    def reduce_exemplar_sets(self, m):
        for y, P_y in tenumerate(self.exemplar_sets):
            self.exemplar_sets[y] = P_y[:m]


    def combine_dataset_with_exemplars(self, dataset):
        for y, P_y in tenumerate(self.exemplar_sets):
            exemplar_images = P_y
            exemplar_labels = [y] * len(P_y)
            dataset.append(exemplar_images, exemplar_labels)

    def update_representation(self, dataset):

        self.compute_means = True

        print("Increment number of weights in final fc layer")
        classes = list(set(dataset.targets))
        new_classes = [cls for cls in classes if cls > self.n_classes - 1]
        self.increment_classes(len(new_classes))
        self.cuda()
        print ("%d new classes" % (len(new_classes)))

        print("Form combined training set")
        self.combine_dataset_with_exemplars(dataset)

        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=2, drop_last=True)

        print("Store network outputs with pre-update parameters")
        q = torch.zeros(len(dataset), self.n_classes).cuda()
        
        for indices, images, labels in tqdm(loader):
            images = Variable(images).cuda()
            #images = images.requires_grad_().cuda()
            indices = indices.cuda()
            labels = labels
            g = torch.sigmoid(self.forward(images))
            q[indices] = g.data
        
        q = Variable(q).cuda()
        #q = q.requires_grad_().cuda()

        print("Run network training")
        optimizer = self.optimizer
        #scaler = torch.cuda.amp.GradScaler()
        self.train()

        for epoch in trange(num_epochs):
            for i, (indices, images, labels) in enumerate(loader):
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()
                #images = images.requires_grad_().cuda()
                #labels = labels.requires_grad_().cuda()
                indices = indices.cuda()
                
                #with torch.cuda.amp.autocast():
                optimizer.zero_grad()
                g = self.forward(images)
                
                # Classification loss for new classes
                loss = self.cls_loss(g, labels)
                #loss = loss / len(range(self.n_known, self.n_classes))

                # Distilation loss for old classes
                if self.n_known > 0:
                    g = torch.sigmoid(g)
                    q_i = q[indices]
                    dist_loss = sum(self.dist_loss(g[:,y], q_i[:,y])\
                            for y in range(self.n_known))
                    dist_loss = dist_loss / self.n_known
                    loss += dist_loss

                #loss.backward()
                #optimizer.step()
                #scaler.scale(loss).backward()
                #scaler.step(optimizer)
                #scaler.update()
                #if (i+1) % 2 == 0:
                    #scaler.step(optimizer)
                    #scaler.update()
                    #optimizer.zero_grad()    

                if (i+1) % 10 == 0:
                    print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' 
                           %(epoch+1, num_epochs, i+1, len(dataset)//batch_size, loss.data))
神经网络:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import multiprocessing
from norse.torch import LIFParameters, LIFState
from norse.torch.module.lif import LIFCell, LIFRecurrentCell
from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch.functional.lif import LIFFeedForwardState
from norse.torch.functional.leaky_integrator import LIState
from tqdm.notebook import trange

import importlib
from norse.torch.module import encode
encode = importlib.reload(encode)

from torch.autograd import Variable
from typing import NamedTuple

def decode(x):
    x, _ = torch.max(x, 0)
    log_p_y = nn.functional.log_softmax(x, dim=1)
    return log_p_y
    #return x

def decode_last(x):
    x = x[-1]
    log_p_y = nn.functional.log_softmax(x, dim=1)
    return log_p_y

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class ConvNet(nn.Module):
    def __init__(
        self,  num_channels=64, feature_size=32, method="super", alpha=100, num_classes=16
    ):
        super(ConvNet, self).__init__()

        #self.features = int(((feature_size - 4) / 2 - 4) / 2)#+1

        self.conv1 = torch.nn.Conv2d(num_channels, 128, 2, 2, 0)
        self.conv2 = torch.nn.Conv2d(128, 256, 9, 2, 2)
        self.conv3 = torch.nn.Conv2d(256, 512, 9, 2, 2)
        #self.fc1 = torch.nn.Linear(self.features * self.features * 50, 512)
        self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=alpha))
        self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=alpha))
        self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=alpha))
        self.lif3 = LIFCell(p=LIFParameters(method=method, alpha=alpha))
        #self.avgpool = nn.AvgPool2d(4)
        self.fc = LILinearCell(512, 2048)#feature_size to be tuned
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(512)
        
    def forward(self, x):
        seq_length = x.shape[0]
        batch_size = x.shape[1]
        
        # specify the initial states
        s0 = s1 = s2 = s3 = so = None

        voltages = torch.zeros(
            seq_length, batch_size, 2048, device=x.device, dtype=x.dtype
        )

        for ts in range(seq_length):
            z = self.conv1(x[ts, :])
            #print('shape z1 ' + str(z.shape))
            z = self.bn1(z)
            #print('shape z2 ' + str(z.shape))
            z, s0 = self.lif0(F.relu(z), s0)
            #print('shape z3 ' + str(z.shape))
            z = self.conv2(z)
            #print('shape z4 ' + str(z.shape))
            z = self.bn2(z)
            #print('shape z5 ' + str(z.shape))
            z, s1 = self.lif1(F.relu(z), s1)
            #print('shape z6 ' + str(z.shape))
            z = self.conv3(z)
            #print('shape z7 ' + str(z.shape))
            z = self.bn3(z)
            #print('shape z8 ' + str(z.shape))
            z, s2 = self.lif1(F.relu(z), s2)
            #print('shape z9 ' + str(z.shape))
            #z = self.avgpool(z)
            #print('shape z10 ' + str(z.shape))
            z = z.view(z.size(0), -1) 
            #print('shape z11 ' + str(z.shape))
            z, s3 = self.lif2(F.relu(z), s3)
            #print('shape z12 ' + str(z.shape))
            v, so = self.fc(F.relu(z), so)
            #print('shape z13 ' + str(v.shape))
            voltages[ts, :, :] = v
        return voltages    
    
class Norse(nn.Module):
    def __init__(self):
        super(Norse, self).__init__()
        self.in_planes = 64
        T = 15
        self.pre_layers = nn.Sequential(conv3x3(3,64),
                                        nn.BatchNorm2d(64),
                                        nn.ReLU(True),)
        #self.encoder = encode.SpikeLatencyLIFEncoder(T)
        self.encoder = encode.ConstantCurrentLIFEncoder(T)
        self.snn = ConvNet(alpha=80)
        self.decoder = decode

    def forward(self, x):
        #print('shape x1 ' + str(x.shape))
        x = self.pre_layers(x)
        #print('shape x2 ' + str(x.shape))
        x = self.encoder(x)
        #print('shape x3 ' + str(x.shape))
        x = self.snn(x)
        #print('shape x4 ' + str(x.shape))
        x = self.decoder(x)
        #print('shape x5 ' + str(x.shape))
        return x    
数据加载器:

from torchvision.datasets import VisionDataset
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from torchvision.datasets import CIFAR10
from torch import multiprocessing
from PIL import Image
from tqdm import tqdm
from tqdm.notebook import trange
from tqdm.contrib import tenumerate
from torchvision.datasets import ImageFolder
from torchvision.datasets import DatasetFolder
from torch.utils.data import SubsetRandomSampler
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

import numpy as np
import torch


  
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: str,
    class_num: list,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:

    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    if class_index in class_num:
                        instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

class FIORA16(VisionDataset):
    def __init__(
            self,
            root: str,
            class_num: list,
            loader: Callable[[str], Any] = default_loader,
            extensions: Optional[Tuple[str, ...]]=('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'),
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(FIORA16, self).__init__(root, transform=transform,
                                      target_transform=target_transform,
                                      )
        classes, class_to_idx = self.find_classes(self.root)
        self.class_num=list(class_num)
        samples = self.make_dataset(self.root, self.class_num, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.data =[]
        for i in trange(len(self.samples)):
            sample = self.loader(self.samples[i][0])
            if self.transform is not None:
                sample = self.transform(sample)
            self.data.append(sample)
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_num: list,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        return make_dataset(directory, class_num, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """Same as :func:`find_classes`.
        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.
        """
        return find_classes(dir)
    
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.data[index], self.targets[index]
        #path, _ = self.samples[index]
        #sample = self.loader(path)
        #if self.transform is not None:
            #sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        return len(self.targets)

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def get_image_class(self, label):
        datac = []
        for i in trange(len(self.targets)):
            if self.targets[i]==label:
                datac.append(self.data[i])
        return datac

    def append(self, images, labels):
        self.data.extend(images)
        self.targets = self.targets + labels

    
    def pil_loader(path: str) -> Image.Image:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')


    # TODO: specify the return type
    def accimage_loader(path: str) -> Any:
        import accimage
        try:
            return accimage.Image(path)
        except IOError:
            # Potentially a decoding problem, fall back to PIL.Image
            return pil_loader(path)


    def default_loader(path: str) -> Any:
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
            return accimage_loader(path)
        else:
            return pil_loader(path)


对不起,代码还是很脏