๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ˜ŽAI/Generative AI

[Paper Review] Classifier-Free Diffusion Guidance

by SolaKim 2025. 2. 5.

https://arxiv.org/abs/2207.12598

 

Classifier-Free Diffusion Guidance

Classifier guidance is a recently introduced method to trade off mode coverage and sample fidelity in conditional diffusion models post training, in the same spirit as low temperature sampling or truncation in other types of generative models. Classifier g

arxiv.org

 

 

Introduce

 

์ด ๋…ผ๋ฌธ์€ classifier guidance ๋…ผ๋ฌธ์—์„œ classifier ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ ๋„ controllablity ๋ฅผ ๋ถ€์—ฌํ•  ์ˆ˜ ์žˆ๋Š”์ง€์— ๋Œ€ํ•ด์„œ ์—ฐ๊ตฌ๋ฅผ ํ•œ ๋…ผ๋ฌธ์ž…๋‹ˆ๋‹ค.

Classifier Guidance ๋…ผ๋ฌธ์€ diffusion ๋ชจ๋ธ์—์„œ controllablity ๋ฅผ ๋ถ€์—ฌํ•œ ๋…ผ๋ฌธ์œผ๋กœ, mode coverage(๋‹ค์–‘์„ฑ)๊ณผ sample fidelity(์ •ํ™•์„ฑ) ์˜ trade off ๋ฅผ ํ†ตํ•ด์„œ ๊ฒฐ๊ณผ๋ฌผ์„ ์›ํ•˜๋Š” ๋ฐฉํ–ฅ์œผ๋กœ ๋„์ถœํ•ด๋‚ผ ์ˆ˜ ์žˆ๋„๋ก ์—ฐ๊ตฌ๋ฅผ ํ–ˆ์Šต๋‹ˆ๋‹ค. 

Classifier Guidance ๋Š” ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ’ˆ์งˆ์„ ๋†’์ด๊ธฐ ์œ„ํ•ด ํ™•์‚ฐ ๋ชจ๋ธ์— ๋ถ„๋ฅ˜๊ธฐ(Classifier)์˜ ๊ทธ๋ž˜๋””์–ธํŠธ(Gradient) ๋ฅผ ํ™œ์šฉํ•˜๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.

  • ํ•ต์‹ฌ ์•„์ด๋””์–ด: ํ™•์‚ฐ ๋ชจ๋ธ์˜ ์ƒ˜ํ”Œ์ด ๋ถ„๋ฅ˜๊ธฐ์— ๋” ์ž˜ ๋งž๋„๋ก ์ƒ˜ํ”Œ๋ง ๊ณผ์ •์„ ์œ ๋„(guidance) ํ•ฉ๋‹ˆ๋‹ค.

  • ฯตθโ€‹(z,c): ํ™•์‚ฐ ๋ชจ๋ธ์˜ ์Šค์ฝ”์–ด ์ถ”์ •์น˜ (denoising score)
  • zโ€‹ log pclassifierโ€‹(cโˆฃz): ๋ถ„๋ฅ˜๊ธฐ์˜ ๊ทธ๋ž˜๋””์–ธํŠธ
  • w: guidance strength (๊ฐ€์ค‘์น˜)
  • ์ฆ‰, ๋ถ„๋ฅ˜๊ธฐ์˜ ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ํ™œ์šฉํ•ด "๋” ์ •ํ™•ํ•œ" ์ƒ˜ํ”Œ์„ ๋งŒ๋“ค๋„๋ก ๋ชจ๋ธ์„ ์œ ๋„ํ•ฉ๋‹ˆ๋‹ค.

 

ํ•˜์ง€๋งŒ, ์ด Classifier Guidance ๋Š” ๋‹ค์–‘ํ•œ ํ•œ๊ณ„์ ์ด ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.

  1. ๋ชจ๋ธ ํ•™์Šต ๊ณผ์ •์˜ ๋ณต์žก์„ฑ
    •  ํ™•์‚ฐ ๋ชจ๋ธ ์™ธ์— ๋ณ„๋„์˜ ๋ถ„๋ฅ˜๊ธฐ(classifier) ์„ ์ถ”๊ฐ€๋กœ ํ•™์Šตํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
    •  ์ด๋Š” ๋ชจ๋ธ์˜ ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ณต์žกํ•˜๊ฒŒ ๋งŒ๋“ค๊ณ  ๊ณ„์‚ฐ์  ์ž์›์„ ์š”๊ตฌํ•ฉ๋‹ˆ๋‹ค.
  2. ์‚ฌ์ „ ํ•™์Šต๋œ ๋ถ„๋ฅ˜๊ธฐ ์‚ฌ์šฉ ๋ถˆ๊ฐ€๋Šฅ
    •  ๋…ธ์ด์ฆˆ๊ฐ€ ์žˆ๋Š” ๋ฐ์ดํ„ฐ์— ๋งž์ถฐ ํ•™์Šต๋œ ์ƒˆ๋กœ์šด ๋ถ„๋ฅ˜๊ธฐ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. (๋ณ„๋„์˜ ๋ถ„๋ฅ˜๊ธฐ ์ถ”๊ฐ€ ํ•™์Šต ํ•„์š”)
  3. ์ž ์žฌ์ ์ธ ์ ๋Œ€์  ํ•™์Šต
    •  classifier guidance ์˜ ์ƒ˜ํ”Œ๋ง ๊ณผ์ •(์œ„์˜ ์‹ ์ฐธ๊ณ )์€ ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ์†์ด๊ธฐ ์œ„ํ•ด ๊ฒฝ๊ณ„์„ ์„ ๊ณต๊ฒฉํ•˜๋Š” ์ ๋Œ€์  ๊ณต๊ฒฉ(adversarial attack) ๊ณผ ์œ ์‚ฌํ•ฉ๋‹ˆ๋‹ค. 
  4. ํ‰๊ฐ€ ์ง€ํ‘œ์˜ ์‹ ๋ขฐ์„ฑ ๋ฌธ์ œ
    •  3๋ฒˆ์˜ ์ด์œ ๋กœ, ์ƒ˜ํ”Œ ํ’ˆ์งˆ์ด ์ข‹์•„์ง€๋Š” ์ด์œ ๊ฐ€ ์ง„์งœ ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ํ–ฅ์ƒ ๋•Œ๋ฌธ์ธ์ง€, ์•„๋‹ˆ๋ฉด ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ์ž˜ ์†์—ฌ์„œ(=์ ๋Œ€์  ๊ณต๊ฒฉ) ํ‰๊ฐ€ ์ง€ํ‘œ๊ฐ€ ๋†’์•„์ง„ ๊ฒƒ์ธ์ง€ ๋ถˆ๋ถ„๋ช…ํ•ฉ๋‹ˆ๋‹ค.

 

์ด ํ•œ๊ณ„์ ์„ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด Classifier Free ์ธ ์ด ๋…ผ๋ฌธ์ด ๋“ฑ์žฅํ•˜๊ฒŒ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

  • ๋ถ„๋ฅ˜๊ธฐ ์—†์ด๋„ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ๊ฐœ์„ 
    • Conditional ๊ณผ Unconditional ๋ชจ๋ธ์˜ score ๋ฅผ ์กฐํ•ฉํ•˜์—ฌ ๋ถ„๋ฅ˜๊ธฐ ์—†์ด๋„ ํ’ˆ์งˆ์„ ๊ฐœ์„ ํ•˜๊ณ ์ž ํ•ฉ๋‹ˆ๋‹ค.
  • ์ ๋Œ€์  ๊ณต๊ฒฉ ๋ฐฉ์ง€
    • classifier gradient ๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์—, ์ ๋Œ€์  ๊ณต๊ฒฉ๊ณผ ๊ฐ™์€ ์ด์Šˆ๊ฐ€ ์‚ฌ๋ผ์ง‘๋‹ˆ๋‹ค.
  • ํ‰๊ฐ€ ์ง€ํ‘œ์˜ ์‹ ๋ขฐ์„ฑ ํ™•๋ณด
    • FID, IS ๋“ฑ์˜ ํ‰๊ฐ€ ์ง€ํ‘œ ๊ฐœ์„ ์ด ์‹ค์ œ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ํ–ฅ์ƒ๊ณผ ๋” ๋ฐ€์ ‘ํ•˜๊ฒŒ ์—ฐ๊ฒฐ๋ฉ๋‹ˆ๋‹ค.

 


Background

 

Diffusion Model

 

<Forward Process>

์ˆ˜์‹ (1) ์„ค๋ช…

  • ๐Ÿ“Œ λ : Signal-to-Noise Ratio(SNR) ์˜ ๋กœ๊ทธ ์Šค์ผ€์ผ ํ‘œํ˜„
    • λ ๊ฐ€ ํด์ˆ˜๋ก ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์Œ(๊นจ๋—ํ•œ ๋ฐ์ดํ„ฐ ์ƒํƒœ) (๋†’์€ SNR)
    • λ ๊ฐ€ ์ž‘์„์ˆ˜๋ก ๋…ธ์ด์ฆˆํ™”๋œ ๋ฐ์ดํ„ฐ (๋‚ฎ์€ SNR)
    • λ ๋Š” ์‹ค์ œ๋กœ ์‹œ๊ฐ„ ์ถ•์ฒ˜๋Ÿผ ์‚ฌ์šฉ๋จ
      • "๊นจ๋—ํ•œ ๋ฐ์ดํ„ฐ" → "์™„์ „ํ•œ ๋…ธ์ด์ฆˆ"๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๊ณผ์ •์„ ๋‚˜ํƒ€๋ƒ„
      • → ๊ฑฐ์˜ ๋…ธ์ด์ฆˆ๋กœ ๊ฐ€๋“ ์ฐฌ ์ƒํƒœ
      • λmax=20 → ์›๋ณธ ๋ฐ์ดํ„ฐ์— ๊ฐ€๊นŒ์šด ์ƒํƒœ
  • q(zλโˆฃx): ๋ฐ์ดํ„ฐ x ์— ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•œ ๊ฒฐ๊ณผ์ธ zλ์˜ ํ™•๋ฅ  ๋ถ„ํฌ(๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ)
  • αλโ€‹: ์›๋ณธ ๋ฐ์ดํ„ฐ์˜ ์Šค์ผ€์ผ๋ง ๊ณ„์ˆ˜. ๋…ธ์ด์ฆˆ๊ฐ€ ์ถ”๊ฐ€๋ ์ˆ˜๋ก ์ด ๊ฐ’์€ ์ ์  ์ž‘์•„์ง
    • ๋…ธ์ด์ฆˆ๊ฐ€ ๋งŽ์„์ˆ˜๋ก ์›๋ณธ๋ฐ์ดํ„ฐ์˜ ์˜ํ–ฅ์ด ์ค„์–ด๋“ฌ

  • σλ^2*I: ๋…ธ์ด์ฆˆ์˜ ๋ถ„์‚ฐ์„ ๋‚˜ํƒ€๋‚ด๋Š” ํ•ญ
    • ๋…ธ์ด์ฆˆ๊ฐ€ ๋งŽ์ด ์ถ”๊ฐ€๋  ์ˆ˜๋ก ์ด ๊ฐ’์€ ์ปค์ง

  • : ๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ๋ฅผ ์˜๋ฏธํ•˜๋ฉฐ, ํ‰๊ท ์€ αλx, ๋ถ„์‚ฐ์€ σλ^2
  • Forward Process ์—์„œ๋Š” ์‹œ๊ฐ„ t ๊ฐ€ ์ฆ๊ฐ€ํ•จ์— ๋”ฐ๋ผ:
    • αλ๋Š” ๊ฐ์†Œ → ์›๋ณธ ๋ฐ์ดํ„ฐ์˜ ์˜ํ–ฅ ๊ฐ์†Œ (์‹ ํ˜ธ ์•ฝํ™”)
    • σλ^2๋Š” ์ฆ๊ฐ€ → ๋…ธ์ด์ฆˆ์˜ ์˜ํ–ฅ ์ฆ๊ฐ€
    • ๊ทธ๋Ÿฌ๋‚˜ SNR์˜ ๋กœ๊ทธ ์Šค์ผ€์ผ λโ€‹๋Š” ์ด์™€ ๋ฐ˜๋Œ€๋กœ, λโ€‹๊ฐ€ ํด์ˆ˜๋ก ๋…ธ์ด์ฆˆ๊ฐ€ ์ ๊ณ (๊นจ๋—ํ•œ ๋ฐ์ดํ„ฐ), λโ€‹๊ฐ€ ์ž‘์„์ˆ˜๋ก ๋…ธ์ด์ฆˆ๊ฐ€ ๋งŽ๋‹ค. 
    • Forward Process ์—์„œ๋Š” ์‹œ๊ฐ„์ด ์ง€๋‚จ์— ๋”ฐ๋ผ λโ€‹๊ฐ€ ์ž‘์•„์ง‘๋‹ˆ๋‹ค.
  • ๊ฒฐ๊ตญ zλ ๋Š” ์ ์  ์™„์ „ํžˆ ๋…ธ์ด์ฆˆํ™”๋œ ๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ์— ๊ฐ€๊นŒ์›Œ์ง‘๋‹ˆ๋‹ค.

 

์ˆ˜์‹ (2) ์„ค๋ช…

  • q(zλโ€‹โˆฃzλโ€‹): ์‹œ๊ฐ„ λ ๋‹จ๊ณ„์˜ ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ zλโ€‹ ์—์„œ ๋” ๋งŽ์€ ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ๋‹ค์Œ ๋‹จ๊ณ„ λ′ ์˜ ๋ฐ์ดํ„ฐ zλ′ ๋กœ ์ „ํ™˜ํ•˜๋Š” ํ™•๋ฅ  ๋ถ„ํฌ
  • โ€‹(αλโ€‹/ αλ)โ€‹ zλ: ์ด์ „ ๋‹จ๊ณ„์˜ ๋ฐ์ดํ„ฐ zλ ๋ฅผ ์Šค์ผ€์ผ๋งํ•œ ๊ฐ’์œผ๋กœ, ๊ธฐ์กด ์ •๋ณด์˜ ์œ ์ง€ ์ •๋„๋ฅผ ๊ฒฐ์ •

  • ์ด ์ˆ˜์‹์€ ๋…ธ์ด์ฆˆ ์ถ”๊ฐ€๋ฅผ ์—ฌ๋Ÿฌ ๋‹จ๊ณ„๋กœ ๋‚˜๋ˆ ์„œ ์ ์ง„์ ์œผ๋กœ ์ ์šฉํ•˜๋Š” ๊ณผ์ •์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด ๊ณผ์ •์ด Markov Process (๋งˆ๋ฅด์ฝ”ํ”„ ๊ณผ์ •) ์ฒ˜๋Ÿผ ์ด์ „ ์ƒํƒœ์—๋งŒ ์˜์กดํ•˜์—ฌ ๋‹ค์Œ ์ƒํƒœ๋กœ ์ „์ด๋ฉ๋‹ˆ๋‹ค.

 

<Back Process>

์ˆ˜์‹ (3) ์„ค๋ช…

  • :
    • ์‹œ๊ฐ„ λ ๋‹จ๊ณ„์—์„œ์˜ ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ ๋กœ๋ถ€ํ„ฐ
      ๋” ์ ์€ ๋…ธ์ด์ฆˆ ๋‹จ๊ณ„์ธ ๋กœ ์ „ํ™˜ํ•˜๋Š” ํ™•๋ฅ  ๋ถ„ํฌ
    • ์ •๊ทœ ๋ถ„ํฌ(๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ)๋ฅผ ๋”ฐ๋ฆ„
  • ํ‰๊ท (Mean)
    • ์ฒซ๋ฒˆ์งธ ํ•ญ: ๊ธฐ์กด ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ zλ ์˜ ์˜ํ–ฅ
    • ๋‘๋ฒˆ์งธ ํ•ญ: ์›๋ณธ ๋ฐ์ดํ„ฐ x ์˜ ์˜ํ–ฅ
    • : ๋…ธ์ด์ฆˆ ๋‹จ๊ณ„ ๊ฐ„์˜ ๊ฐ€์ค‘์น˜๋กœ, λ′>λ ์ผ์ˆ˜๋ก ๋” ๋งŽ์€ ๋ณต์›(๋…ธ์ด์ฆˆ ์ œ๊ฑฐ)์ด ์ด๋ฃจ์–ด์ง

  • ๋ถ„์‚ฐ(Variance)
    • ๋…ธ์ด์ฆˆ ๊ฐ์†Œ๋Ÿ‰์„ ์กฐ์ ˆํ•˜๋Š” ๋ถ„์‚ฐ ํ•ญ

  • ์ด ์ˆ˜์‹์€ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋ฉด์„œ ์›๋ณธ๋ฐ์ดํ„ฐ๋กœ ๋ณต์›ํ•˜๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
  • ์—ญ๋ฐฉ ๊ณผ์ •์—์„œ ์ƒ˜ํ”Œ์€ ๊ธฐ์กด ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ์˜ ์ •๋ณด์™€ ์›๋ณธ ๋ฐ์ดํ„ฐ์˜ ์ •๋ณด๋ฅผ ์ ์ ˆํžˆ ํ˜ผํ•ฉํ•ฉ๋‹ˆ๋‹ค.
  • ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ๋‹จ๊ณ„๋กœ ๊ฐˆ์ˆ˜๋ก ์›๋ณธ ๋ฐ์ดํ„ฐ์˜ ๋น„์ค‘์ด ์ ์  ์ปค์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

 

์ˆ˜์‹ (4) ์„ค๋ช…

  • pθ(zλ′โˆฃzλ)
    • ๋ชจ๋ธ θ ๊ฐ€ ํ•™์Šตํ•œ ์—ญ๋ฐฉํ–ฅ ํ™•๋ฅ  ๋ถ„ํฌ
      → ์ฆ‰, ๋…ธ์ด์ฆˆ๋ฅผ ์ ์ง„์ ์œผ๋กœ ์ œ๊ฑฐํ•˜๋ฉฐ ๋‹ค์Œ ์ƒ˜ํ”Œ zλ′ ๋กœ ์ด๋™ํ•˜๋Š” ๊ณผ์ •
    • ์ •๊ทœ๋ถ„ํฌ(๊ฐ€์šฐ์‹œ์•ˆ๋ถ„ํฌ)๋กœ ํ‘œํ˜„๋จ : N(ํ‰๊ท (Mean),๋ถ„์‚ฐ(Variance))
  • ํ‰๊ท (Mean):
    • : ํ˜„์žฌ ๋‹จ๊ณ„์˜ ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ
    • xθ(zλ)
      • ๋ชจ๋ธ์ด ์˜ˆ์ธกํ•œ ์›๋ณธ ๋ฐ์ดํ„ฐ ๋ณต์› ๊ฒฐ๊ณผ (denoising output)
      • ๋ชจ๋ธ์ด ๋กœ๋ถ€ํ„ฐ "์ด๋Ÿฐ ๋ฐ์ดํ„ฐ๊ฐ€ ์›๋ณธ์ผ ๊ฒƒ์ด๋‹ค"๋ผ๊ณ  ์ถ”์ •ํ•œ ๊ฐ’
  • ๋ถ„์‚ฐ(Variance):
    • :
      • ๋ชจ๋ธ์ด ์ถ”์ •ํ•œ ์—ญ๋ฐฉํ–ฅ ๊ณผ์ •์˜ ๋ถ„์‚ฐ
      • ์‹ค์ œ๋กœ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋Š” ๊ณผ์ •์—์„œ ์‚ฌ์šฉํ•˜๋Š” ๋ถ„์‚ฐ ๊ฐ’
    • ๋’ค์˜ ํ•ญ:
      • ์ด๋ก ์ ์ธ ์ „๋ฐฉ ๊ณผ์ •์˜ ๋ถ„์‚ฐ ๊ฐ’ (forward process์˜ ground truth ๋ถ„์‚ฐ)
      • ์‹ค์ œ ๋ฐ์ดํ„ฐ ๋ถ„ํฌ์—์„œ ์œ ๋„๋œ ์ด์ƒ์ ์ธ ๋ถ„์‚ฐ ๊ฐ’
    • ν:
      • ๋ถ„์‚ฐ ๋ณด๊ฐ„(interpolation) ๊ณ„์ˆ˜
      • 0 ≤ ν ≤ 10์‚ฌ์ด์˜ ๊ฐ’์œผ๋กœ,
        ๋ชจ๋ธ์ด ์ถ”์ •ํ•œ ๋ถ„์‚ฐ๊ณผ ์ด๋ก ์ ์ธ ๋ถ„์‚ฐ ์‚ฌ์ด๋ฅผ ์กฐ์ ˆ

 

์ˆ˜์‹ (5) ์„ค๋ช…

  • ์†์‹ค ํ•จ์ˆ˜๋Š” "๋ชจ๋ธ์ด ์‹ค์ œ ๋…ธ์ด์ฆˆ๋ฅผ ์–ผ๋งˆ๋‚˜ ์ž˜ ๋ณต์›(denoise)ํ•˜๋Š”์ง€"๋ฅผ ํ‰๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
  1. ์›๋ณธ ๋ฐ์ดํ„ฐ x ์— ๋…ธ์ด์ฆˆ ฯต ๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ์ƒ์„ฑ
  2. ๋ชจ๋ธ θ ๊ฐ€ ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„ ๋…ธ์ด์ฆˆ ฯต ๋ฅผ ์˜ˆ์ธก
  3. ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’ ฯตθ(zλ) ์™€ ์‹ค์ œ ๋…ธ์ด์ฆˆ ฯต ์˜ ์ฐจ์ด๋ฅผ ๊ณ„์‚ฐ
  4. ์ด ์ฐจ์ด๋ฅผ ์ตœ์†Œํ™”ํ•˜๋„๋ก ๋ชจ๋ธ์„ ํ•™์Šต

 

๊ธฐ์กด ๋‚ด์šฉ vs ์ƒˆ๋กœ์šด ๊ธฐ์—ฌ

 

๊ฐœ๋… ๊ธฐ์กด Diffusion ๋ชจ๋ธ ์ด ๋…ผ๋ฌธ์—์„œ์˜ ์ƒˆ๋กœ์šด ๊ธฐ์—ฌ
Denoising Score Matching (DSM) โœ… Vincent (2011), Song & Ermon (2019) ๊ธฐ์กด DSM์„ ํ™œ์šฉ, ๋‹ค์–‘ํ•œ ๋…ธ์ด์ฆˆ ์Šค์ผ€์ผ์—์„œ ํ™•์žฅ ์ ์šฉ
Variational Lower Bound (VLB) โœ… Kingma et al. (2021) Weighted VLB๋กœ ํ•ด์„ํ•˜์—ฌ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ์กฐ์ ˆ ๊ฐœ์„ 
Noise Schedule (Cosine) โœ… Nichol & Dhariwal (2021) ํ•˜์ดํผ๋ณผ๋ฆญ ์‹œ์ปจํŠธ ๋ถ„ํฌ ๊ธฐ๋ฐ˜ ๋…ธ์ด์ฆˆ ์Šค์ผ€์ค„๋ง ์ œ์•ˆ
Classifier Guidance โœ… Ho et al. (2020) Classifier-Free Guidance๋กœ ๋ถ„๋ฅ˜๊ธฐ ์—†์ด ํ’ˆ์งˆ ๊ฐœ์„ 
Langevin Dynamics โœ… Song & Ermon (2019) ๊ธฐ์กด ๊ฐœ๋… ์œ ์ง€, ๋ชจ๋ธ ํšจ์œจ์„ฑ ๋ฐ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ๊ฐœ์„ 

 

  • ๊ธฐ์กด (๋น„์กฐ๊ฑด๋ถ€): θ(zλ) → ๋‹จ์ˆœํžˆ ๋…ธ์ด์ฆˆ ๋งŒ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์Œ
  • ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ: θ(zλ,c) → ๋…ธ์ด์ฆˆ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์กฐ๊ฑด c๋„ ํ•จ๊ป˜ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„์„œ "์กฐ๊ฑด์— ๋งž๋Š” ๊ฒฐ๊ณผ"

 


Guidance

 

GAN์ด๋‚˜ flow-based model์˜ ๊ฒฝ์šฐ, ์ƒ˜ํ”Œ๋ง ์‹œ์— ๋ถ„์‚ฐ์ด๋‚˜ ์ž…๋ ฅ noise์˜ ๋ฒ”์œ„๋ฅผ ์ค„์—ฌ truncated sampling์ด๋‚˜ low temperature sampling์„ ์ˆ˜ํ–‰ํ•œ๋‹ค. ์ด๋Ÿฐ ๋ฐฉ๋ฒ•๋“ค์€ ์ƒ˜ํ”Œ์˜ ๋‹ค์–‘์„ฑ์„ ์ค„์ด๋ฉด์„œ ๊ฐ ์ƒ˜ํ”Œ์˜ ํ’ˆ์งˆ์„ ๋†’์ธ๋‹ค. ํ•˜์ง€๋งŒ, diffusion model์˜ ๊ฒฝ์šฐ ์ด๋Ÿฌํ•œ ๋ฐฉ๋ฒ•๋“ค์ด ํšจ๊ณผ์ ์ด์ง€ ์•Š๋‹ค.

 

Classifier Guidance

 

  • ฯต ฬƒθ(zλ, c)
    • Guided Score Function (์ˆ˜์ •๋œ ์Šค์ฝ”์–ด ํ•จ์ˆ˜)
    • ์ƒ˜ํ”Œ๋ง ๊ณผ์ •์—์„œ ์‚ฌ์šฉ๋˜๋Š” ์ตœ์ข… ์Šค์ฝ”์–ด์ด๋‹ค.
  • ฯต θ(zλ, c)
    • ๊ธฐ์กด ํ™•์‚ฐ ๋ชจ๋ธ์˜ ์Šค์ฝ”์–ด ํ•จ์ˆ˜
    • ์›๋ž˜๋Š” ์ด ํ•จ์ˆ˜๋กœ ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•œ๋‹ค.
  • w
    • classifier guidance strength (๊ฐ€์ค‘์น˜)
    • ํด์ˆ˜๋ก ์ƒ˜ํ”Œ ํ’ˆ์งˆ(fidelity) ๊ฐ€ ํ–ฅ์ƒ๋˜์ง€๋งŒ, ๋‹ค์–‘์„ฑ(diversity) ๋Š” ๊ฐ์†Œํ•œ๋‹ค.
    • Classifier์˜ ๊ทธ๋ž˜๋””์–ธํŠธ(gradient) 
    • ๋ฐ์ดํ„ฐ ๊ฐ€ ์กฐ๊ฑด c์— ๋” ์ ํ•ฉํ•˜๋„๋ก ์ƒ˜ํ”Œ์„ "๋Œ์–ด๋‹น๊ธฐ๋Š”" ์—ญํ• ์„ ํ•ฉ๋‹ˆ๋‹ค.

 

โœ… Classifier Guidance ์˜ ํšจ๊ณผ

  • Inception Score (IS) ํ–ฅ์ƒ
    • inception score ์€ ๋ชจ๋ธ์ด ์ƒ์„ฑํ•œ ์ƒ˜ํ”Œ์ด ๋ถ„๋ฅ˜ํ•˜๊ธฐ ์‰ฌ์šธ์ˆ˜๋ก ๋†’๊ฒŒ ๋‚˜์˜ต๋‹ˆ๋‹ค.
    • classifier guidance ๋Š” ๋ถ„๋ฅ˜๊ธฐ๊ฐ€ ์ž˜ ๋งž์ถœ ์ˆ˜ ์žˆ๋Š” ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๊ธฐ ๋•Œ๋ฌธ์— inception score ๊ฐ€ ๋†’์•„์ง‘๋‹ˆ๋‹ค.
  • ๋‹ค์–‘์„ฑ ๊ฐ์†Œ
    • ์ƒ˜ํ”Œ๋“ค์ด ๋ถ„๋ฅ˜๊ธฐ ๊ธฐ์ค€์œผ๋กœ ํ™•์‹ ํ•  ์ˆ˜ ์žˆ๋Š” ์˜์—ญ์— ๋ชฐ๋ฆฌ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค
    • ์ด๋กœ ์ธํ•ด ๋‹ค์–‘ํ•œ ์ƒ˜ํ”Œ์ด ์ค„์–ด๋“ค๊ณ , ๋™์ผํ•œ ํŒจํ„ด์˜ ์ƒ˜ํ”Œ์ด ๋ฐ˜๋ณต๋˜๋Š” ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•ฉ๋‹ˆ๋‹ค.

 

์ด์ „ Classifier Guidance ๋…ผ๋ฌธ์—์„œ๋Š”

"๋น„์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์— Classifier Guidance๋ฅผ ์ ์šฉํ•œ ๊ฒƒ๋ณด๋‹ค
์ด๋ฏธ ์กฐ๊ฑด๋ถ€๋กœ ํ•™์Šต๋œ ๋ชจ๋ธ์— Guidance๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ์ด ์ƒ˜ํ”Œ ํ’ˆ์งˆ์ด ๋” ์šฐ์ˆ˜"

ํ•˜๋‹ค๋Š” ๊ฒฐ๋ก ์„ ๋„์ถœํ–ˆ์—ˆ์Šต๋‹ˆ๋‹ค. 

์ด๋Š” ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์—์„œ w์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์ ์šฉํ•œ ๊ฒฐ๊ณผ๊ฐ€
๋น„์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์—์„œ w+1 ๊ฐ€์ค‘์น˜๋ฅผ ์ ์šฉํ•œ ๊ฒฐ๊ณผ์™€ ๊ฐ™์•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

 

 

โœ… ๊ฒฐ๋ก :

  • ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ + Classifier Guidance = ๋” ๋‚˜์€ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ๋ฐ ๋‹ค์–‘์„ฑ ์œ ์ง€
  • ๋น„์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ + ๊ฐ•ํ•œ Guidance = ํ’ˆ์งˆ์€ ๋†’์ง€๋งŒ, ๋‹ค์–‘์„ฑ ๊ฐ์†Œ ๋ฐ ๋ถˆ์•ˆ์ •์„ฑ ์ฆ๊ฐ€

 

์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ
๋น„์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ

 

 

Classifier-Free Guidance

 

โœ… ๊ธฐ๋ณธ ๊ฐœ๋…

  • Unconditional Diffusion Model:
    • ์กฐ๊ฑด ์—†์ด ์ˆœ์ˆ˜ํ•œ ๋ฐ์ดํ„ฐ ๋ถ„ํฌ p(x) ๋ฅผ ํ•™์Šต
    • ์ƒ˜ํ”Œ๋ง ์‹œ ๋‹จ์ˆœํžˆ ๋…ธ์ด์ฆˆ์—์„œ ์‹œ์ž‘ํ•˜์—ฌ ๋ฐ์ดํ„ฐ๋กœ ๋ณต์›
  • Conditional Diffusion Model:
    • ์กฐ๊ฑด c (์˜ˆ: ํด๋ž˜์Šค ๋ ˆ์ด๋ธ”, ํ…์ŠคํŠธ ๋“ฑ)์— ๋”ฐ๋ผ ์ƒ˜ํ”Œ ์ƒ์„ฑ
    • ์กฐ๊ฑด์— ๋งž๋Š” ๋ฐ์ดํ„ฐ ๋ถ„ํฌ p(xโˆฃc) ๋ฅผ ํ•™์Šต

 

๋…ผ๋ฌธ์—์„œ๋Š” ๋‘ ๊ฐœ์˜ ๋ชจ๋ธ์„ ๋”ฐ๋กœ ํ•™์Šตํ•˜์ง€ ์•Š๊ณ , ํ•˜๋‚˜์˜ ์‹ ๊ฒฝ๋ง ฯตθโ€‹(zλโ€‹,c) ์œผ๋กœ Unconditional ๊ณผ Conditional ๋ชจ๋ธ์„ ๋™์‹œ์— ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค. 

 

โœ… ํ•ต์‹ฌ ์•„์ด๋””์–ด:

  • ์กฐ๊ฑด c ๋ฅผ ์ผ๋ถ€ ํ™•๋ฅ ๋กœ ๋น„ํ™œ์„ฑํ™”ํ•˜์—ฌ Unconditional ๋ชจ๋ธ๋กœ ํ•™์Šต
    • ์—ฌ๊ธฐ์„œ ๋Š” "์กฐ๊ฑด ์—†์Œ"์„ ๋‚˜ํƒ€๋‚ด๋Š” ํŠน์ˆ˜ ํ† ํฐ ๋˜๋Š” Null ๊ฐ’

  • ๋‚˜๋จธ์ง€ ๊ฒฝ์šฐ์—๋Š” ์กฐ๊ฑด c ๋ฅผ ํ™œ์„ฑํ™”ํ•˜์—ฌ Conditional ๋ชจ๋ธ๋กœ ํ•™์Šต

 

์•„๋ž˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜ 1์€ Single Neural Network ๋กœ conditional diffusion model ๊ณผ unconditional diffusion model ๋‘˜๋‹ค ๊ฒฐํ•ฉํ•˜์—ฌ ํ•™์Šต์‹œํ‚ค๋Š” ๊ณผ์ •์„ ๋‹ด์€ Joint Training ๊ณผ์ • ์ž…๋‹ˆ๋‹ค. 

์œ„์˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜ 1์—์„œ 3๋ฒˆ์งธ ์ค„์„ ๋ณด๋ฉด Puncond ํ™•๋ฅ ์„ ์‚ฌ์šฉํ•˜์—ฌ Unconditional ๋กœ ์‚ฌ์šฉํ•˜๋Š”์ง€ ์•„๋‹Œ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.

  • ๋น„์œจ๋กœ ์กฐ๊ฑด์„ ์ œ๊ฑฐํ•˜์—ฌ Unconditional ๋ชจ๋ธ๋กœ ํ•™์Šต
  • 1−puncond๋น„์œจ๋กœ ์‹ค์ œ ์กฐ๊ฑด c ๋ฅผ ์ œ๊ณตํ•˜์—ฌ Conditional ๋ชจ๋ธ๋กœ ํ•™์Šต
  • ์—ฌ๊ธฐ์„œ Puncond ๋Š” ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ

 

 Joint Training ์˜ ์†์‹คํ•จ์ˆ˜๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. 

    • (7๋ฒˆ์งธ ์ค„) ๋ชจ๋ธ์€ Loss ๊ฐ’์„ ํ™œ์šฉํ•˜์—ฌ 
    • ๋…ธ์ด์ฆˆ ฯต ๋ฅผ ์ •ํ™•ํžˆ ์˜ˆ์ธกํ•˜๋„๋ก ํ•™์Šต๋ฉ๋‹ˆ๋‹ค.
    • Classifier-Free Guidance ํ•™์Šต:
    • ํ™•๋ฅ  puncond ๋กœ ์กฐ๊ฑด ์—†์ด(Unconditional) ํ•™์Šต
    • ๋‚˜๋จธ์ง€ ํ™•๋ฅ ๋กœ๋Š” ์กฐ๊ฑด๋ถ€(Conditional) ํ•™์Šต ์ˆ˜ํ–‰

 

์ž ์ด์ œ ์ด ๋…ผ๋ฌธ์˜ Classifier-Free Guidance ์˜ main ์‹ ์ž…๋‹ˆ๋‹ค!! ๐Ÿ™Œ๐Ÿ‘

  • : ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์˜ ๋…ธ์ด์ฆˆ ์ถ”์ •
  • ฯตθ(z,∅): ๋น„์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์˜ ๋…ธ์ด์ฆˆ ์ถ”์ •
  • w: Guidance Strength (์ƒ˜ํ”Œ ํ’ˆ์งˆ๊ณผ ๋‹ค์–‘์„ฑ ์กฐ์ ˆ)

 

 

์•Œ๊ณ ๋ฆฌ์ฆ˜ 2๋Š” ์กฐ๊ฑด ์ƒ˜ํ”Œ๋ง ๊ณผ์ •์„ ๋‚˜ํƒ€๋‚ธ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ์ž…๋‹ˆ๋‹ค.

์•Œ๊ณ ๋ฆฌ์ฆ˜ 2์˜ ์ฃผ์š” ๊ตฌ์„ฑ ์š”์†Œ

  • ์ž…๋ ฅ ๋งค๊ฐœ๋ณ€์ˆ˜
    • w: Guidance Strength. ์ด ๊ฐ’์€ ์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์—์„œ ์กฐ๊ฑด ์ •๋ณด์˜ ์ค‘์š”๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.
    • c: Conditioning information. ์ด๋Š” ์ƒ์„ฑ ๊ณผ์ •์—์„œ ๋ชจ๋ธ์ด ๋”ฐ๋ผ์•ผ ํ•  ์กฐ๊ฑด์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.
    • : ๋กœ๊ทธ ์‹ ํ˜ธ ๋Œ€ ์žก์Œ ๋น„์œจ(log SNR)์˜ ์ฆ๊ฐ€ํ•˜๋Š” ์‹œํ€€์Šค. ์ƒ์„ฑ ๊ณผ์ •์—์„œ์˜ ๋…ธ์ด์ฆˆ ์ˆ˜์ค€์„ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.
  • ์ดˆ๊ธฐํ™”
    • z1∼N(0,I): ์ดˆ๊ธฐ ์ƒ˜ํ”Œ์€ ํ‘œ์ค€ ์ •๊ทœ ๋ถ„ํฌ์—์„œ ์ถ”์ถœ๋ฉ๋‹ˆ๋‹ค.
  • ๋ฐ˜๋ณต๊ณผ์ •
    • ๊ฐ ์‹œ๊ฐ„ ๋‹จ๊ณ„ t ์—์„œ, Classifier-Free Guidance ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ๊ณผ ๋ฌด์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์„ ๊ฒฐํ•ฉํ•ฉ๋‹ˆ๋‹ค.
      • 3๋ฒˆ์งธ ์ค„: ์กฐ๊ฑด๋ถ€์™€ ๋ฌด์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์˜ ๊ฐ€์ค‘ ํ‰๊ท ์„ ๊ณ„์‚ฐํ•˜์—ฌ ๊ฐ€์ด๋˜์Šค๋ฅผ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
    • ์ƒ˜ํ”Œ๋ง ๋‹จ๊ณ„์—์„œ๋Š” ๊ณ„์‚ฐ๋œ ์Šค์ฝ”์–ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹ค์Œ ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
      • 4๋ฒˆ์งธ ์ค„: ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๊ณ  ์ƒ˜ํ”Œ์„ ๊ฐœ์„ ํ•ฉ๋‹ˆ๋‹ค. 
    • 5๋ฒˆ์งธ ์ค„: ๋‹ค์Œ ์ƒ˜ํ”Œ zt+1 ์€ ์ •๊ทœ ๋ถ„ํฌ์—์„œ ์ถ”์ถœ๋˜๋ฉฐ, ์ด ๋ถ„ํฌ์˜ ํ‰๊ท ๊ณผ ๋ถ„์‚ฐ์€ ์ด์ „ ์ƒ˜ํ”Œ๊ณผ ๊ณ„์‚ฐ๋œ ๊ฐ’์— ์˜ํ•ด ๊ฒฐ์ •๋ฉ๋‹ˆ๋‹ค.

 

์ •๋ฆฌ ๐Ÿง

ํ•ญ๋ชฉ Algorithm 1 (ํ•™์Šต) Algorithm 2 (์ƒ˜ํ”Œ๋ง)
๋ชฉ์  ๋ชจ๋ธ ํ•™์Šต (๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋Šฅ๋ ฅ ํ•™์Šต) ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ ์ƒ์„ฑ (์ƒ˜ํ”Œ๋ง)
๊ณผ์ • Forward Process (๋…ธ์ด์ฆˆ ์ถ”๊ฐ€) Reverse Process (๋…ธ์ด์ฆˆ ์ œ๊ฑฐ)
์ถœ๋ ฅ ๊ฒฐ๊ณผ ๋…ธ์ด์ฆˆ๋ฅผ ์ •ํ™•ํ•˜๊ฒŒ ์˜ˆ์ธกํ•˜๋Š” ๋ชจ๋ธ ฯตθ ์›๋ณธ๊ณผ ์œ ์‚ฌํ•œ ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
ํ•ต์‹ฌ ํŒŒ๋ผ๋ฏธํ„ฐ puncond (Unconditional ํ•™์Šต ํ™•๋ฅ ) w (Guidance Strength)
์—ญํ•  ์กฐ๊ฑด๋ถ€/๋น„์กฐ๊ฑด๋ถ€ ๋ฐ์ดํ„ฐ๋กœ ๋ชจ๋ธ ํ•™์Šต ํ•™์Šต๋œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ณ ํ’ˆ์งˆ ์ƒ˜ํ”Œ ์ƒ์„ฑ
CFG ์ ์šฉ ์—ฌ๋ถ€ โœ… Classifier-Free Guidance๋ฅผ ์œ„ํ•œ ์Šค์ฝ”์–ด ํ•™์Šต โœ… Classifier-Free Guidance๋ฅผ ์ ์šฉํ•˜์—ฌ ์ƒ˜ํ”Œ ํ’ˆ์งˆ ํ–ฅ์ƒ

 


Experiments

 

์ด ๋…ผ๋ฌธ์—์„œ๋Š” Classifier-Free Guidance (CFG) ์˜ ํšจ๊ณผ๋ฅผ ๊ฒ€์ฆํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์–‘ํ•œ ํ•˜์ดํผ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์กฐ์ ˆํ•˜๋ฉด์„œ ์‹คํ—˜์„ ์ง„ํ–‰ํ–ˆ๋‹ค. 

์‹คํ—˜์€ ์ƒ˜ํ”Œ ํ’ˆ์งˆ(Fidelity), ๋‹ค์–‘์„ฑ(Diversity), ํšจ์œจ์„ฑ(Efficiency) ์— ๋ฏธ์น˜๋Š” ์˜ํ–ฅ์„ ๋ถ„์„ํ•˜๋Š” ๋ฐ ์ดˆ์ ์„ ๋งž์ท„๋‹ค.

  • Varying the Classifier-Free Guidance Strength
    • w ∈ {0, 0.1, 0.2, . . . , 4}
    • log SNR endpoints
      • λmin = −20 and λmax = 20
    • 64 x 64 models
      • noise interpolation coefficient: v = 0.3
      • trained for 400,000 steps
    • 128 x 128 models
      • v = 0.2
      • trained for 2,700,000 steps

 

  • Varying the Unconditional Training Probability
    • puncond ∈ {0.1, 0.2, 0.5}
    • 0.1,0.2 ์—์„œ ๊ฐ€์žฅ ์ข‹์€ ์„ฑ๋Šฅ
    • 0.5 ์—์„œ๋Š” ์„ฑ๋Šฅ์ด ์ €ํ•˜๋จ

 

  • Varying the Number of Smapling Steps
    • T ∈ {128, 256, 1024}
    • 256 ์ด ๊ฒฐ๊ณผ๊ฐ€ ๊ฐ€์žฅ ๋ฐธ๋Ÿฐ์Šค๊ฐ€ ์ข‹์•˜์Œ

 


Discussion

 

  • CFG๋Š” ์ƒ˜ํ”Œ์˜ Unconditional Likelihood(p(x))๋ฅผ ๊ฐ์†Œ์‹œํ‚ค๊ณ , Conditional Likelihood(p(xโˆฃc))๋ฅผ ์ฆ๊ฐ€์‹œํ‚ค๋Š” ๋ฐฉ์‹์œผ๋กœ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.

 

  • Classifier-Free Guidance๋Š” ์ƒ˜ํ”Œ๋ง ์†๋„(Sampling Speed)์— ๋ถˆ๋ฆฌํ•  ์ˆ˜ ์žˆ๋‹ค. ์™œ๋ƒํ•˜๋ฉด Diffusion Model์˜ Forward Pass๋ฅผ ๋‘ ๋ฒˆ ์‹คํ–‰ํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
    • (1) ํ•™์Šต ๊ณผ์ •(Training Phase)์—์„œ๋Š”?
      • Forward Process(๋…ธ์ด์ฆˆ ์ถ”๊ฐ€)๋Š” Unconditional๊ณผ Conditional์„ ๋”ฐ๋กœ ๋‘ ๋ฒˆ ์ง„ํ–‰ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
      • ํ™•๋ฅ ์ ์œผ๋กœ ์กฐ๊ฑด์„ ์ œ๊ฑฐํ•˜์—ฌ Unconditional ๋ฐ์ดํ„ฐ๋ฅผ ํ•จ๊ป˜ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
      • ์ฆ‰, ํ•™์Šต ์‹œ์—๋Š” ํ•œ ๋ฒˆ์˜ Forward Pass๋งŒ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.
    • (2) ์ƒ˜ํ”Œ๋ง ๊ณผ์ •(Sampling Phase)์—์„œ๋Š”?
      • ์ƒ˜ํ”Œ๋ง ๋‹จ๊ณ„์—์„œ๋Š” Classifier-Free Guidance (CFG)๊ฐ€ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค.
      • ์ƒ˜ํ”Œ ํ’ˆ์งˆ์„ ๊ฐœ์„ ํ•˜๊ธฐ ์œ„ํ•ด Conditional Score์™€ Unconditional Score๋ฅผ ํ•จ๊ป˜ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
      • ์ด๋ฅผ ์œ„ํ•ด ์ƒ˜ํ”Œ๋ง ๊ณผ์ •์—์„œ ๋‘ ๋ฒˆ์˜ Forward Pass๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
      ์™œ ๋‘ ๋ฒˆ์˜ Forward Pass๊ฐ€ ํ•„์š”ํ•œ๊ฐ€?
      1. ์ฒซ ๋ฒˆ์งธ Forward Pass:
        • Conditional Score ฯตθ(z,c) ๊ณ„์‚ฐ
        • ์ฆ‰, ์กฐ๊ฑด c ๋ฅผ ํฌํ•จํ•œ ์˜ˆ์ธก ์ˆ˜ํ–‰
      2. ๋‘ ๋ฒˆ์งธ Forward Pass:
        • Unconditional Score ฯตθ(z) ๊ณ„์‚ฐ
        • ์ฆ‰, ์กฐ๊ฑด ์—†์ด ์˜ˆ์ธก ์ˆ˜ํ–‰
      3. Guidance ์ ์šฉ:
        • ๋‘ ๊ฐœ์˜ ์Šค์ฝ”์–ด๋ฅผ ์กฐํ•ฉํ•˜์—ฌ ์ตœ์ข… ์ƒ˜ํ”Œ๋ง ์ˆ˜ํ–‰
        • ฯต~θ(z,c)=(1+w)⋅ฯตθ(z,c)−w⋅ฯตθ(z)
        • ์—ฌ๊ธฐ์„œ ฯตθ(z,c)์™€ ฯตθ(z)๋ฅผ ๋‘˜ ๋‹ค ์‚ฌ์šฉํ•ด์•ผ ํ•˜๋ฏ€๋กœ
          Forward Pass๋ฅผ ๋‘ ๋ฒˆ ์‹คํ–‰ํ•ด์•ผ ํ•จ.
      ๊ฒฐ๋ก : โŒ ์ƒ˜ํ”Œ๋ง ๊ณผ์ •์—์„œ๋Š” Forward Pass๊ฐ€ ๋‘ ๋ฒˆ ์ง„ํ–‰๋จ.

 

 

CFG๊ฐ€ Unconditional ๋ชจ๋ธ์„ ํ•„์š”๋กœ ํ•˜๋Š” ํ•ต์‹ฌ ์ด์œ 

๊ธฐ์กด ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ๋งŒ ์‚ฌ์šฉํ•œ ๊ฒฝ์šฐ Unconditional ๋ชจ๋ธ์„ ํ•จ๊ป˜ ํ•™์Šตํ•œ ๊ฒฝ์šฐ
ํŠน์ • ์กฐ๊ฑด์„ ๋„ˆ๋ฌด ๊ฐ•ํ•˜๊ฒŒ ๋”ฐ๋ฆ„ ์กฐ๊ฑด์„ ๋”ฐ๋ฅด๋ฉด์„œ๋„ ์ƒ˜ํ”Œ ๋‹ค์–‘์„ฑ์„ ์œ ์ง€ ๊ฐ€๋Šฅ
๋ชจ๋ธ์ด ํŠน์ • ๋ชจ๋“œ(mode)์— ๊ฐ‡ํž˜ ๋” ๋„“์€ ๋ฐ์ดํ„ฐ ๋ถ„ํฌ๋ฅผ ๋ฐ˜์˜ํ•˜์—ฌ ์ƒ˜ํ”Œ๋ง ํ’ˆ์งˆ ํ–ฅ์ƒ
๋‹ค์–‘ํ•œ ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๊ธฐ ์–ด๋ ค์›€ ww ๊ฐ’์„ ์กฐ์ ˆํ•˜์—ฌ ํ’ˆ์งˆ๊ณผ ๋‹ค์–‘์„ฑ ์กฐ์ ˆ ๊ฐ€๋Šฅ

 

 

๐Ÿ’ก๐Ÿ’ก๐Ÿ’ก

๋…ผ๋ฌธ์„ ์ฝ๋‹ค๊ฐ€ ๋“  ์˜๋ฌธ์ ์ž…๋‹ˆ๋‹ค.

์ด๋ฏธ ์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ์ด ์กด์žฌํ•˜๋Š”๋ฐ,
์™œ CG (Classifier Guidance)์™€ CFG (Classifier-Free Guidance) ๋…ผ๋ฌธ์ด ์ค‘์š”ํ•œ๊ฐ€? 

 

๊ทธ์— ๋Œ€ํ•œ ๋‹ต ์ž…๋‹ˆ๋‹ค.

 

๐Ÿš€ 1.  ์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ(Conditional Diffusion Model)์€ ์–ด๋–ป๊ฒŒ ๋งŒ๋“ค์–ด์ง€๋Š”๊ฐ€?

์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ์€ ์ฃผ์–ด์ง„ ์กฐ๊ฑด c๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ํŠน์ •ํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ํ™•์‚ฐ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
์ด ๋ชจ๋ธ์€ ๊ธฐ๋ณธ์ ์ธ Diffusion Process๋ฅผ ๋”ฐ๋ฅด์ง€๋งŒ, ํ•™์Šต ๊ณผ์ •์—์„œ ์ถ”๊ฐ€์ ์ธ ์กฐ๊ฑด์„ ๋ชจ๋ธ์— ์ œ๊ณตํ•˜์—ฌ ์ œ์–ด ๊ฐ€๋Šฅํ•˜๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

โœ… (1) ๊ธฐ๋ณธ์ ์ธ ์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ์˜ ๊ตฌ์กฐ

ฯตθ(z,c)

  • z: ๋…ธ์ด์ฆˆ๊ฐ€ ์ถ”๊ฐ€๋œ ์ด๋ฏธ์ง€ (Diffusion Step)
  • c: ์กฐ๊ฑด (์˜ˆ: ํด๋ž˜์Šค ๋ ˆ์ด๋ธ”, ํ…์ŠคํŠธ, ํฌ์ฆˆ ์ •๋ณด ๋“ฑ)
  • ฯตθ(z,c): ์กฐ๊ฑด๋ถ€ ๋ชจ๋ธ์ด ์˜ˆ์ธกํ•˜๋Š” ๋…ธ์ด์ฆˆ

๐Ÿ“Œ ์†์‹ค ํ•จ์ˆ˜ (Loss Function)

  • ๋ชจ๋ธ์ด ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ ๋ฐฉํ–ฅ์„ ํ•™์Šตํ•˜๋„๋ก ํ•จ.

๐Ÿ“Œ ์กฐ๊ฑด์„ ์ž…๋ ฅํ•˜๋Š” ๋ฐฉ๋ฒ•

  1. ํด๋ž˜์Šค ๋ ˆ์ด๋ธ” ์‚ฌ์šฉ (One-hot Encoding or Embedding)
    • ์˜ˆ: "๊ฐ•์•„์ง€" ํด๋ž˜์Šค๋ผ๋ฉด c๋ฅผ ์ˆซ์ž ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ชจ๋ธ์— ์ž…๋ ฅ.
  2. ํ…์ŠคํŠธ ์‚ฌ์šฉ (CLIP, Cross-Attention ๋“ฑ ํ™œ์šฉ)
    • ํ…์ŠคํŠธ๋ฅผ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•œ ํ›„ ๋ชจ๋ธ ๋‚ด๋ถ€์—์„œ ์กฐ๊ฑด์œผ๋กœ ์‚ฌ์šฉ.
  3. ํฌ์ฆˆ, Depth Map ๋“ฑ ๊ตฌ์กฐ ์ •๋ณด ํ™œ์šฉ
    • ํŠน์ •ํ•œ ํฌ์ฆˆ๋‚˜ ๋ชจ์–‘์„ ์œ ์ง€ํ•˜๋ฉด์„œ ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๋„๋ก ์œ ๋„.

 

ํ•˜์ง€๋งŒ ์ด CG ์™€ CFG ๋…ผ๋ฌธ์—์„œ๋Š” CLIP ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์•˜๊ณ , Cross-Attention ์„ ์ด์šฉํ•œ Stable Diffusion ๋ชจ๋ธ์ด ๋ฐœํ‘œ๋˜๊ธฐ ์ „์— ๋‚˜์™”๊ธฐ ๋•Œ๋ฌธ์— "1๋ฒˆ: ํด๋ž˜์Šค ๋ ˆ์ด๋ธ”" ์„ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค. 

์ด CFG ๋…ผ๋ฌธ์—์„œ๋Š” ํด๋ž˜์Šค ๋ ˆ์ด๋ธ”์ด ๋ถ™์–ด์žˆ๋Š” ๋ฐ์ดํ„ฐ๋งŒ์„ ์‚ฌ์šฉํ•ด์•ผํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๋ฐ์ดํ„ฐ๋“ค์„ ์‚ฌ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค.

๋ฐ์ดํ„ฐ์…‹ ํ•ด์ƒ๋„ ์นดํ…Œ๊ณ ๋ฆฌ ์ˆ˜ ์„ค๋ช…
CIFAR-10 32×32 10 ์†Œํ˜• ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์…‹ (๊ฐ•์•„์ง€, ๊ณ ์–‘์ด, ์ž๋™์ฐจ ๋“ฑ)
ImageNet 64×64 64×64 1,000 ImageNet์„ 64×64๋กœ ๋‹ค์šด์ƒ˜ํ”Œ๋งํ•œ ๋ฒ„์ „
LSUN-Bedroom 256×256 1 ์‹ค๋‚ด ์ธํ…Œ๋ฆฌ์–ด (์นจ์‹ค) ์ด๋ฏธ์ง€ ์ƒ์„ฑ
LSUN-Cat 256×256 1 ๊ณ ์–‘์ด ์ด๋ฏธ์ง€ ์ƒ์„ฑ
LSUN-Horse 256×256 1 ๋ง ์ด๋ฏธ์ง€ ์ƒ์„ฑ
FFHQ (Flickr-Faces-HQ) 256×256 1 ๊ณ ํ•ด์ƒ๋„ ์–ผ๊ตด ์ด๋ฏธ์ง€

 

 

 

๐ŸŽฏ 2. ์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ์ด ์ด๋ฏธ ์žˆ์Œ์—๋„ CG์™€ CFG ๋…ผ๋ฌธ์ด ์ค‘์š”ํ•œ ์ด์œ 

 

โœ… (1) Classifier Guidance (CG) ๋…ผ๋ฌธ์˜ ์˜์˜

  • ๊ธฐ์กด ์กฐ๊ฑด๋ถ€ ํ™•์‚ฐ ๋ชจ๋ธ์ด ์—†๋”๋ผ๋„, Unconditional Diffusion Model์—์„œ ํŠน์ •ํ•œ ์กฐ๊ฑด์„ ์ถ”๊ฐ€ํ•  ๋ฐฉ๋ฒ•์„ ์ œ์•ˆํ•จ.
  • ๊ธฐ์กด ๋ถ„๋ฅ˜๊ธฐ(Classifier)์˜ ๊ทธ๋ผ๋””์–ธํŠธ๋ฅผ ํ™œ์šฉํ•˜์—ฌ ํŠน์ • ํด๋ž˜์Šค๋ฅผ ๊ฐ•ํ™”ํ•˜๋Š” ๋ฐฉ์‹์„ ์‚ฌ์šฉ.

๐Ÿ“Œ CG์˜ ํ•ต์‹ฌ ์•„์ด๋””์–ด

  • ๋ถ„๋ฅ˜๊ธฐ์˜ ์ถœ๋ ฅ ํ™•๋ฅ  p(yโˆฃx) ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ƒ˜ํ”Œ๋ง์„ ์กฐ์ ˆ.
  • ์ฆ‰, Unconditional Diffusion Model์„ ์ถ”๊ฐ€์ ์ธ ํ•™์Šต ์—†์ด Class-Conditional ๋ฐฉ์‹์œผ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Œ.
  • ์ƒˆ๋กœ์šด ์กฐ๊ฑด์ด ํ•„์š”ํ•  ๋•Œ ๋ชจ๋ธ์„ ๋‹ค์‹œ ํ•™์Šตํ•  ํ•„์š” ์—†์ด, ๊ธฐ์กด ๋ชจ๋ธ์— ์กฐ๊ฑด์„ ์ถ”๊ฐ€ ๊ฐ€๋Šฅ.

๐Ÿ“Œ CG ๋…ผ๋ฌธ์˜ ํ•ต์‹ฌ ๊ธฐ์—ฌ

  1. ๊ธฐ์กด Diffusion Model์ด Class-Conditional๋กœ ํ•™์Šต๋˜์ง€ ์•Š์•„๋„, ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ์‚ฌ์šฉํ•ด ํŠน์ • ์กฐ๊ฑด์„ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Œ.
  2. ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ์…‹์ด ์ถ”๊ฐ€๋  ๋•Œ, Diffusion Model์„ ๋‹ค์‹œ ํ•™์Šตํ•  ํ•„์š” ์—†์ด ๋ถ„๋ฅ˜๊ธฐ๋งŒ ํ•™์Šตํ•˜๋ฉด ๋จ.
  3. "๊ฐ•์•„์ง€"๋ผ๋Š” ์กฐ๊ฑด๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, "์›ƒ๋Š” ๊ฐ•์•„์ง€", "์Šฌํ”ˆ ๊ฐ•์•„์ง€" ๋“ฑ์˜ ์ถ”๊ฐ€์ ์ธ ์†์„ฑ๋„ ๋ถ„๋ฅ˜๊ธฐ์˜ ํ™œ์šฉ์„ ํ†ตํ•ด ์กฐ์ ˆ ๊ฐ€๋Šฅ.
    1. ์ด๊ฑด ์ด์ œ ๋ฐ์ดํ„ฐ์…‹์— ์›ƒ๋Š”๋‹ค/ ์Šฌํ”„๋‹ค ์— ๊ด€ํ•œ ๋ ˆ์ด๋ธ”์ด ๋˜์–ด์žˆ์–ด์•ผ์ง€ ์‹คํ˜„์ด ๊ฐ€๋Šฅ

 

 

โœ… (2) Classifier-Free Guidance (CFG) ๋…ผ๋ฌธ์˜ ์˜์˜

  • CG ๋ฐฉ์‹์€ ๋ถ„๋ฅ˜๊ธฐ(Classifier)๊ฐ€ ํ•„์š”ํ•˜์ง€๋งŒ, CFG๋Š” ๋ถ„๋ฅ˜๊ธฐ ์—†์ด๋„ ์กฐ๊ฑด์„ ์กฐ์ ˆํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์ œ์•ˆํ•จ.
  • ๋ถ„๋ฅ˜๊ธฐ ์—†์ด๋„ ์ƒ˜ํ”Œ ํ’ˆ์งˆ์„ ๋†’์ผ ์ˆ˜ ์žˆ๋Š” ๋ฐฉ์‹์„ ์—ฐ๊ตฌํ•จ.

๐Ÿ“Œ CFG์˜ ํ•ต์‹ฌ ์•„์ด๋””์–ด

  • Unconditional ๋ชจ๋ธ๊ณผ Conditional ๋ชจ๋ธ์„ ํ•จ๊ป˜ ํ•™์Šตํ•˜์—ฌ ๋ถ„๋ฅ˜๊ธฐ ์—†์ด ์กฐ๊ฑด์„ ์ ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ• ์ œ์•ˆ.
  • ๋ถ„๋ฅ˜๊ธฐ์˜ ๊ทธ๋ผ๋””์–ธํŠธ๊ฐ€ ํ•„์š” ์—†์œผ๋ฏ€๋กœ, ์ถ”๊ฐ€์ ์ธ ๋ถ„๋ฅ˜๊ธฐ ํ•™์Šต ์—†์ด๋„ ์กฐ๊ฑด์„ ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ์Œ.

๐Ÿ“Œ CFG ๋…ผ๋ฌธ์˜ ํ•ต์‹ฌ ๊ธฐ์—ฌ

  1. ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ๋”ฐ๋กœ ํ•™์Šตํ•  ํ•„์š” ์—†์ด, Diffusion Model ์ž์ฒด๋งŒ์œผ๋กœ ์กฐ๊ฑด์„ ์กฐ์ ˆํ•  ์ˆ˜ ์žˆ์Œ.
  2. ์ƒ˜ํ”Œ๋ง ์†๋„๋Š” CG๋ณด๋‹ค ๋А๋ฆด ์ˆ˜ ์žˆ์ง€๋งŒ, ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ€ ๋‹จ์ˆœํ•˜๊ณ  ๋ถ„๋ฅ˜๊ธฐ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•˜์ง€ ์•Š์Œ.
  3. ํ…์ŠคํŠธ, ์ด๋ฏธ์ง€ ๋“ฑ ๋‹ค์–‘ํ•œ ์กฐ๊ฑด์„ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ์Œ (Stable Diffusion ๊ฐ™์€ ๋ชจ๋ธ์—์„œ ํ™œ์šฉ๋จ).

 

 

 

๊ธด ๊ธ€ ์ฝ์–ด์ฃผ์…”์„œ gamsahapnida

์‚ฌ์‹ค ๊ฐœ์ธ ๊ณต๋ถ€ ๊ธฐ๋ก์šฉ์ด๊ธด ํ•œ๋ฐ...