Hanah

Hanah

TinyDiffusion

Reference:Datawhale tinydiffusion

Kaggle 实现:#

数据准备#

  • git clone 下载 tinydiffusion 到本地
  • 下载cifar-10-python数据集到 datasets/ 文件夹,保持目录结构为 datasets/cifar-10-batches-py。
  • 打包上传到 kaggle dataset,命名为 cifar-10-python。

运行#

  • 新建 notebook,input 添加数据集 cifar-10-python。
  • 运行默认代码
  • 配置环境
!pip install -r /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/requirements.txt
  • 复制数据集到 output 路径下
import shutil
import os

# 定义数据集输入目录和输出目录
input_dir = '/kaggle/input/tinydiffusion'
output_dir = '/kaggle/output/tinydiffusion'

# 检查输出目录是否存在,如果不存在则创建
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# 遍历输入目录中的所有文件和文件夹
for root, dirs, files in os.walk(input_dir):
    # 构建对应的输出目录
    relative_path = os.path.relpath(root, input_dir)
    output_sub_dir = os.path.join(output_dir, relative_path)
    
    # 如果输出子目录不存在,则创建
    if not os.path.exists(output_sub_dir):
        os.makedirs(output_sub_dir)
    
    # 复制文件到输出目录
    for file in files:
        input_file_path = os.path.join(root, file)
        output_file_path = os.path.join(output_sub_dir, file)
        shutil.copy2(input_file_path, output_file_path)

print("数据集复制完成!")
  • 创建路径,放置每 10 个 epoch 产生的图片
import os

# 定义保存图像的目标目录
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'

# 检查目录是否存在,如果不存在则创建
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f"成功创建目录: {save_dir}")
else:
    print(f"目录 {save_dir} 已经存在。")
  • 进入 output 路径
cd /kaggle/output/tinydiffusion/TinyDiffusion
  • 处理数据
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/dataloader.py
  • 打印结构
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/unet.py
  • 图像加噪
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/diffusion.py
  • 训练,这里只训练 100 次,如果要修改,在命令行修改即可
!python /kaggle/output/tinydiffusion/TinyDiffusion/ddpm/train.py --epochs 100 --batch_size 128 --lr 1e-4 --img_size 32
  • 查看图片
import os
from PIL import Image
import matplotlib.pyplot as plt

# 定义图片目录
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'

# 检查目录是否存在
if not os.path.exists(save_dir):
    print(f"目录 {save_dir} 不存在。")
else:
    # 获取目录下的所有文件
    all_files = os.listdir(save_dir)
    
    # 过滤出图片文件(这里假设图片文件的扩展名是 .png, .jpg, .jpeg)
    image_files = [f for f in all_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if not image_files:
        print(f"目录 {save_dir} 中没有图片文件。")
    else:
        print(f"找到以下图片文件:")
        for idx, image_file in enumerate(image_files):
            print(f"{idx + 1}. {image_file}")
            # 显示图片(可选)
            image_path = os.path.join(save_dir, image_file)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.title(image_file)
            plt.axis('off')
            plt.show()

结果#

image
image
image
image
image
image
image
image
image
image

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。