
Hello author, why is the classification prediction always daisy? #7

Open
ZZHOO1 opened this issue Jul 16, 2024 · 4 comments

Comments


ZZHOO1 commented Jul 16, 2024

No description provided.

ZZHOO1 (Author) commented Jul 16, 2024

Uploading predict.jpg…

codecat0 (Owner) commented

Did you load the weights?

ZZHOO1 (Author) commented Jul 17, 2024

I checked the code. The problem is in data loading: it does not load the whole dataset, only the data from the first loop iteration.
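One quick way to confirm this symptom is to count how many samples each class index received after the split. If every label is 0, the model only ever trained on the first class. Assuming the standard five flower_photos classes sorted alphabetically, index 0 would be daisy, which would explain why every prediction is daisy. The helper `check_label_coverage` below is a hypothetical sketch, not part of this project:

```python
from collections import Counter


def check_label_coverage(labels, num_classes):
    """Return the class indices that received no samples (hypothetical helper)."""
    counts = Counter(labels)  # Counter returns 0 for indices that never appear
    return [c for c in range(num_classes) if counts[c] == 0]


# Illustrative labels from a split that only kept class 0 (e.g. daisy)
train_labels = [0, 0, 0, 0]
print(check_label_coverage(train_labels, num_classes=5))  # [1, 2, 3, 4]
```

In practice you would pass the train_images_label list returned by read_split_data, with num_classes set to the number of class folders; a non-empty result means some classes never reached training.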

ZZHOO1 (Author) commented Jul 17, 2024

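Dataset split error:
Because the return statement in read_split_data was placed at the end of the first iteration of the class loop, the function processed only the first class's data and returned early. Moving the return after the loop ensures the dataset split includes every class: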
import json
import os
import random

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    # Fix the random seed so the split is reproducible
    random.seed(0)
    assert os.path.exists(root), f'dataset root {root} does not exist.'

    # Walk the root folder; each subfolder corresponds to one class
    flower_classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]

    # Sort so the class order (and therefore the indices) is consistent
    flower_classes.sort()

    # Encode each class name as a numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_classes))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    # Paths and labels of all training images
    train_images_path, train_images_label = [], []

    # Paths and labels of all validation images
    val_images_path, val_images_label = [], []

    # Total number of samples in each class
    every_class_num = []

    # Supported image extensions
    images_format = [".jpg", ".JPG", ".png", ".PNG"]

    # Iterate over the files in each class folder
    for cla in flower_classes:
        cla_path = os.path.join(root, cla)

        # Collect the paths of all images in this class folder
        images = [os.path.join(cla_path, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in images_format]

        # Look up the index for this class
        image_class = class_indices[cla]

        # Record the number of samples in this class
        every_class_num.append(len(images))

        # Randomly sample the validation set at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    # Everything below runs once, after all classes have been processed
    print(f"{sum(every_class_num)} images found in dataset.")
    print(f"{len(train_images_path)} images for training.")
    print(f"{len(val_images_path)} images for validation.")

    if plot_image:
        plt.bar(range(len(flower_classes)), every_class_num, align='center')
        plt.xticks(range(len(flower_classes)), flower_classes)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    # The return sits at function level, after the loop, so every class is included
    return train_images_path, train_images_label, val_images_path, val_images_label
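The effect of the misplaced return can be reproduced without any image files. The sketch below is a hypothetical, stripped-down stand-in for read_split_data (the names split_buggy and split_fixed and the in-memory class-to-files mapping are illustrative assumptions); the only difference between the two functions is the indentation of the return:

```python
def split_buggy(class_to_files):
    labels = []
    for idx, name in enumerate(sorted(class_to_files)):
        labels.extend([idx] * len(class_to_files[name]))
        return labels  # BUG: returns inside the loop, right after the first class


def split_fixed(class_to_files):
    labels = []
    for idx, name in enumerate(sorted(class_to_files)):
        labels.extend([idx] * len(class_to_files[name]))
    return labels  # correct: returns only after every class has been processed


data = {"daisy": ["a.jpg", "b.jpg"], "roses": ["c.jpg"], "tulips": ["d.jpg", "e.jpg"]}
print(split_buggy(data))  # [0, 0]
print(split_fixed(data))  # [0, 0, 1, 2, 2]
```

With the buggy version, only label 0 (the alphabetically first class) ever reaches the training set, so a classifier trained on it can only learn to output that class.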
