实验三:KMeans图像聚类

作者:超米 / 2024-05-26 / 暂无评论 / 115 个足迹

示例代码:

# Import the required libraries
import numpy as np  # used by np.unique below (was missing -> NameError)
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people

# Fetch the LFW dataset, keeping only people with at least 70 images
lfw_dataset = fetch_lfw_people(min_faces_per_person=70)

# Inspect the dataset: number of images, per-image shape, number of people
print("图像总数:", lfw_dataset.images.shape[0])
print("图像形状:", lfw_dataset.images.shape[1:])
print("人物(类别)数:", len(np.unique(lfw_dataset.target)))

# Show (up to) the first 100 faces in a 10x10 grid, labelled by person id.
# zip() stops at the shorter sequence, so a small dataset cannot IndexError.
fig, ax = plt.subplots(10, 10, figsize=(10, 10))
for axi, face, label in zip(ax.flat, lfw_dataset.images, lfw_dataset.target):
    axi.imshow(face, cmap='gray')
    axi.set(xticks=[], yticks=[], xlabel=label)
plt.show()

# Flatten every image into one feature vector (one row per image)
data = lfw_dataset.images.reshape((len(lfw_dataset.images), -1))

# Cluster whole faces into 10 groups. n_init is given explicitly because its
# default changed in scikit-learn 1.4 (10 -> 'auto') and older versions warn.
kmeans = KMeans(n_clusters=10, n_init=10)
# Equivalent to fit(data) followed by predict(data)
labels = kmeans.fit_predict(data)

# "Compress" each face by replacing it with its cluster centre, then restore
# the original 2-D image shape for display
compressed_data = kmeans.cluster_centers_[labels]
reconstructed_images = compressed_data.reshape(lfw_dataset.images.shape)

# Show the reconstructed (cluster-centre) faces in the same grid layout
fig, ax = plt.subplots(10, 10, figsize=(10, 10))
for axi, face, label in zip(ax.flat, reconstructed_images, lfw_dataset.target):
    axi.imshow(face, cmap='gray')
    axi.set(xticks=[], yticks=[], xlabel=label)
plt.show()

答案:

import numpy as np
import matplotlib.pyplot as plt
from PIL.Image import Image
from sklearn.datasets import fetch_lfw_people
from sklearn.utils import Bunch

print("-------------------------------------")
print("|物联2201  220223433  米热地力·买买提|")
print("-------------------------------------")
# Load lfw_people: only people with at least 20 images, each image
# downscaled by a factor of 0.7
people: Bunch = fetch_lfw_people(min_faces_per_person=20, resize=0.7)
print(people.target)  # integer person label of each image
print(people.target_names)  # person name for each label
print(people.data.shape)  # shape of the data array
print(people.target.shape)  # shape of the label array
# Inspect and display some face images
image_shape = people.images[0].shape
print(image_shape)
print("Number of classes:", len(people.target_names))
print("shape of targets:", people.target.shape)
fig, axes = plt.subplots(2, 5, figsize=(15, 8))
for target, image, ax in zip(people.target, people.images, axes.ravel()):
    # fetch_lfw_people returns single-channel images unless color=True, so
    # draw them in grey rather than matplotlib's default 'viridis' colormap.
    ax.imshow(image, cmap='gray')
    ax.set_title(people.target_names[target])
plt.show()

# Count the images per person and print them four entries per row
counts = np.bincount(people.target)
rows = list(zip(counts, people.target_names))
for i, (count, name) in enumerate(rows, start=1):
    print("{0:25} {1:3}".format(name, count), end=' ')
    if i % 4 == 0:
        print()
if len(rows) % 4:
    print()  # terminate the last, partially filled row

# - * - coding: utf-8 - * -
from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt


def restore_image(cb, cluster, shape):
    """Rebuild an image from a colour codebook and per-pixel cluster labels.

    Args:
        cb: Codebook of shape ``(k, channels)``; row ``i`` is the colour of
            cluster ``i`` (e.g. ``KMeans.cluster_centers_``).
        cluster: Integer cluster index per pixel in row-major order; only the
            first ``row * col`` entries are used.
        shape: ``(row, col, channels)`` of the image to build.

    Returns:
        A float64 array of ``shape`` whose pixel ``(r, c)`` is
        ``cb[cluster[r * col + c]]``.
    """
    row, col, dummy = shape
    image = np.empty((row, col, dummy))
    # Vectorised codebook lookup instead of a per-pixel Python double loop.
    codebook = np.asarray(cb)
    # View the codebook as 2-D (k, width) so a 1-D codebook still broadcasts
    # one value across all channels, as the element-wise assignment did.
    codebook = codebook.reshape(codebook.shape[0], -1)
    labels = np.asarray(cluster)[:row * col]
    # Assigning into the preallocated buffer keeps the float64 result dtype.
    image[...] = codebook[labels].reshape(row, col, codebook.shape[1])
    return image


if __name__ == '__main__':
    # Use a CJK font so the Chinese titles render, and draw minus signs with
    # an ASCII hyphen (CJK fonts such as SimHei may lack the Unicode minus).
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False
    # Number of colour clusters (codebook size); try 2, 6 or 30
    num_vq = 2

    # Raw string: the Windows path contains backslash sequences (\s, \l, \T)
    # that are invalid escapes and produce warnings on newer Python versions.
    image_path = r'F:\shujuji\lfw_funneled\Tiger_Woods\Tiger_Woods_0023.jpg'
    # Open in a context manager so the file handle is released, and convert
    # to RGB so greyscale, palette or RGBA inputs all yield exactly three
    # channels (this replaces the former [:, :, :3] slice, which failed on
    # 2-D greyscale arrays). np.float64 is used because the np.float_ alias
    # was removed in NumPy 2.0.
    with Image.open(image_path) as im:
        image = np.asarray(im.convert('RGB'), dtype=np.float64) / 255
    # One row per pixel, one column per RGB channel
    image_v = image.reshape((-1, 3))
    # n_init is explicit because its default changed in scikit-learn 1.4.
    kmeans = KMeans(n_clusters=num_vq, init='k-means++', n_init=10)

    N = image_v.shape[0]  # total number of pixels
    # Fit the codebook on a random 70% sample of pixels (drawn with
    # replacement, presumably to reduce fitting cost), then label every pixel.
    idx = np.random.randint(0, N, size=int(N * 0.7))
    image_sample = image_v[idx]
    kmeans.fit(image_sample)

    result = kmeans.predict(image_v)  # cluster index of every pixel
    print('聚类结果:\n', result)
    print('聚类中心:\n', kmeans.cluster_centers_)
    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(211)
    plt.axis('off')
    plt.title(u'原始图片', fontsize=18)
    plt.imshow(image)

    # Optionally call plt.savefig('原始图片.png') to save the original for comparison
    plt.subplot(212)
    # Replace every pixel by its cluster centre (the vector-quantised image)
    vq_image = restore_image(kmeans.cluster_centers_, result, image.shape)
    plt.axis('off')
    plt.title(u'聚类个数:%d' % num_vq, fontsize=20)
    plt.imshow(vq_image)

    # Optionally call plt.savefig('矢量化图片.png') to save the quantised image for comparison
    plt.tight_layout()
    plt.show()

独特见解