Reference:Datawhale tinydiffusion
Kaggle 实现:#
数据准备#
- git clone 下载 tinydiffusion 到本地
- 下载cifar-10-python数据集到 datasets/ 文件夹,保持目录结构为 datasets/cifar-10-batches-py。
- 打包上传到 kaggle dataset,命名为 cifar-10-python。
运行#
- 新建 notebook,input 添加数据集 cifar-10-python。
- 运行默认代码
- 配置环境
!pip install -r /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/requirements.txt
- 复制数据集到 output 路径下
import shutil
import os
# 定义数据集输入目录和输出目录
input_dir = '/kaggle/input/tinydiffusion'
output_dir = '/kaggle/output/tinydiffusion'
# 检查输出目录是否存在,如果不存在则创建
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 遍历输入目录中的所有文件和文件夹
for root, dirs, files in os.walk(input_dir):
# 构建对应的输出目录
relative_path = os.path.relpath(root, input_dir)
output_sub_dir = os.path.join(output_dir, relative_path)
# 如果输出子目录不存在,则创建
if not os.path.exists(output_sub_dir):
os.makedirs(output_sub_dir)
# 复制文件到输出目录
for file in files:
input_file_path = os.path.join(root, file)
output_file_path = os.path.join(output_sub_dir, file)
shutil.copy2(input_file_path, output_file_path)
print("数据集复制完成!")
- 创建路径,放置每 10 个 epoch 产生的图片
import os
# 定义保存图像的目标目录
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'
# 检查目录是否存在,如果不存在则创建
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print(f"成功创建目录: {save_dir}")
else:
print(f"目录 {save_dir} 已经存在。")
- 进入 output 路径
cd /kaggle/output/tinydiffusion/TinyDiffusion
- 处理数据
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/dataloader.py
- 打印结构
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/unet.py
- 图像加噪
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/diffusion.py
- 训练,这里只训练 100 次,如果要修改,在命令行修改即可
!python /kaggle/output/tinydiffusion/TinyDiffusion/ddpm/train.py --epochs 100 --batch_size 128 --lr 1e-4 --img_size 32
- 查看图片
import os
from PIL import Image
import matplotlib.pyplot as plt
# 定义图片目录
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'
# 检查目录是否存在
if not os.path.exists(save_dir):
print(f"目录 {save_dir} 不存在。")
else:
# 获取目录下的所有文件
all_files = os.listdir(save_dir)
# 过滤出图片文件(这里假设图片文件的扩展名是 .png, .jpg, .jpeg)
image_files = [f for f in all_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if not image_files:
print(f"目录 {save_dir} 中没有图片文件。")
else:
print(f"找到以下图片文件:")
for idx, image_file in enumerate(image_files):
print(f"{idx + 1}. {image_file}")
# 显示图片(可选)
image_path = os.path.join(save_dir, image_file)
img = Image.open(image_path)
plt.imshow(img)
plt.title(image_file)
plt.axis('off')
plt.show()
结果#