Hanah

Hanah

TinyDiffusion

Reference:Datawhale tinydiffusion

Kaggle 實現:#

數據準備#

  • git clone 下載 tinydiffusion 到本地
  • 下載cifar-10-python數據集到 datasets/ 文件夾,保持目錄結構為 datasets/cifar-10-batches-py。
  • 打包上傳到 kaggle dataset,命名為 cifar-10-python。

運行#

  • 新建 notebook,input 添加數據集 cifar-10-python。
  • 運行默認代碼
  • 配置環境
!pip install -r /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/requirements.txt
  • 複製數據集到 output 路徑下
import shutil
import os

# 定義數據集輸入目錄和輸出目錄
input_dir = '/kaggle/input/tinydiffusion'
output_dir = '/kaggle/output/tinydiffusion'

# 檢查輸出目錄是否存在,如果不存在則創建
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# 遍歷輸入目錄中的所有文件和文件夾
for root, dirs, files in os.walk(input_dir):
    # 構建對應的輸出目錄
    relative_path = os.path.relpath(root, input_dir)
    output_sub_dir = os.path.join(output_dir, relative_path)
    
    # 如果輸出子目錄不存在,則創建
    if not os.path.exists(output_sub_dir):
        os.makedirs(output_sub_dir)
    
    # 複製文件到輸出目錄
    for file in files:
        input_file_path = os.path.join(root, file)
        output_file_path = os.path.join(output_sub_dir, file)
        shutil.copy2(input_file_path, output_file_path)

print("數據集複製完成!")
  • 創建路徑,放置每 10 個 epoch 產生的圖片
import os

# 定義保存圖像的目標目錄
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'

# 檢查目錄是否存在,如果不存在則創建
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f"成功創建目錄: {save_dir}")
else:
    print(f"目錄 {save_dir} 已經存在。")
  • 進入 output 路徑
cd /kaggle/output/tinydiffusion/TinyDiffusion
  • 處理數據
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/dataloader.py
  • 打印結構
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/unet.py
  • 圖像加噪
!python /kaggle/input/tinydiffusion/TinyDiffusion/ddpm/diffusion.py
  • 訓練,這裡只訓練 100 次,如果要修改,在命令行修改即可
!python /kaggle/output/tinydiffusion/TinyDiffusion/ddpm/train.py --epochs 100 --batch_size 128 --lr 1e-4 --img_size 32
  • 查看圖片
import os
from PIL import Image
import matplotlib.pyplot as plt

# 定義圖片目錄
save_dir = '/kaggle/output/tinydiffusion/TinyDiffusion/samples'

# 檢查目錄是否存在
if not os.path.exists(save_dir):
    print(f"目錄 {save_dir} 不存在。")
else:
    # 獲取目錄下的所有文件
    all_files = os.listdir(save_dir)
    
    # 過濾出圖片文件(這裡假設圖片文件的擴展名是 .png, .jpg, .jpeg)
    image_files = [f for f in all_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if not image_files:
        print(f"目錄 {save_dir} 中沒有圖片文件。")
    else:
        print(f"找到以下圖片文件:")
        for idx, image_file in enumerate(image_files):
            print(f"{idx + 1}. {image_file}")
            # 顯示圖片(可選)
            image_path = os.path.join(save_dir, image_file)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.title(image_file)
            plt.axis('off')
            plt.show()

結果#

image
image
image
image
image
image
image
image
image
image

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。