Reference:Datawhale tinydiffusion
Kaggle 實現:#
數據準備#
- git clone 下載 tinydiffusion 到本地
- 下載cifar-10-python數據集到 datasets/ 文件夾,保持目錄結構為 datasets/cifar-10-batches-py。
- 打包上傳到 kaggle dataset,命名為 cifar-10-python。
運行#
- 新建 notebook,input 添加數據集 cifar-10-python。
- 運行默認代碼
- 配置環境
!pip install -r /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/requirements.txt
- 複製數據集到 output 路徑下
import shutil
import os
# 定義數據集輸入目錄和輸出目錄
input_dir = '/kaggle/input/tinydiffusion'
output_dir = '/kaggle/output/tinydiffusion'
# 檢查輸出目錄是否存在,如果不存在則創建
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 遍歷輸入目錄中的所有文件和文件夾
for root, dirs, files in os.walk(input_dir):
# 構建對應的輸出目錄
relative_path = os.path.relpath(root, input_dir)
output_sub_dir = os.path.join(output_dir, relative_path)
# 如果輸出子目錄不存在,則創建
if not os.path.exists(output_sub_dir):
os.makedirs(output_sub_dir)
# 複製文件到輸出目錄
for file in files:
input_file_path = os.path.join(root, file)
output_file_path = os.path.join(output_sub_dir, file)
shutil.copy2(input_file_path, output_file_path)
print("數據集複製完成!")
- 創建路徑,放置每 10 個 epoch 產生的圖片
import os
# 定義保存圖像的目標目錄
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'
# 檢查目錄是否存在,如果不存在則創建
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print(f"成功創建目錄: {save_dir}")
else:
print(f"目錄 {save_dir} 已經存在。")
- 進入 output 路徑
cd /kaggle/output/tinydiffusion/TinyDiffusion
- 處理數據
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/dataloader.py
- 打印結構
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/unet.py
- 圖像加噪
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/diffusion.py
- 訓練,這裡只訓練 100 次,如果要修改,在命令行修改即可
!python /kaggle/output/tinydiffusion/TinyDiffusion/ddpm/train.py --epochs 100 --batch_size 128 --lr 1e-4 --img_size 32
- 查看圖片
import os
from PIL import Image
import matplotlib.pyplot as plt
# 定義圖片目錄
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'
# 檢查目錄是否存在
if not os.path.exists(save_dir):
print(f"目錄 {save_dir} 不存在。")
else:
# 獲取目錄下的所有文件
all_files = os.listdir(save_dir)
# 過濾出圖片文件(這裡假設圖片文件的擴展名是 .png, .jpg, .jpeg)
image_files = [f for f in all_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if not image_files:
print(f"目錄 {save_dir} 中沒有圖片文件。")
else:
print(f"找到以下圖片文件:")
for idx, image_file in enumerate(image_files):
print(f"{idx + 1}. {image_file}")
# 顯示圖片(可選)
image_path = os.path.join(save_dir, image_file)
img = Image.open(image_path)
plt.imshow(img)
plt.title(image_file)
plt.axis('off')
plt.show()
結果#