Dreambooth, Textual Inversion 등 stable diffusion과 같은 foundataion 모델에서 학습하지 않은 custom concept을 적은 데이터셋으로 학습해 원하는 이미지를 만드는 연구는 계속되고 있지만, 여전히 프롬프트에 맞는 그림을 그려내지 못하거나, 복잡한 프롬프트를 제대로 그리지 못하는 문제점이 있다.
Dreambooth와 Textual Inversion 각기 다른 두 방식은 꽤 좋은 결과를 그려주지만, 방법에 따른 명확한 한계점을 가지고 있는데, 한계점이 있다. AttnDreamBooth는 이러한 두 방법의 한계점을 극복하고 해결하는 연구이며, 2024 Neurips에 accept되었다.
AttnDreamBooth가 어떠한 문제점을 어떻게 해결했는지 아래에 간단히 살펴보겠다.
https://arxiv.org/pdf/2406.05000
기존 방법의 한계점
Dreambooth는 sks, vkv 같은 새로은 토큰을 사용해 unet을 fine-tuning하기 때문에 text embedding을 제대로 배울 수 없다. 보다 자세히는, Dreambooth는 text encoder를 고정시킨채 새로운 토큰이 원하는 concept을 그려내도록 학습하는데, 이 과정에서 새로운 concept과 토큰의 임베딩을 충분히 배울 수 없다.
Textual Inversion은 새로운 토큰을 추가하지 않고, 기존 단어 (cat, dog 등)가 표현하는 이미지가 원하는 concept이 되도록 학습하기 때문에, overfitting 될 수 있다.
아이디어
위와 같은 한계점에도, Textual Inversion은 학습의 초기 단계에서 텍스트와 이미지의 embedding alignment를 잘 학습하고, Dreambooth는 ebmedding alignment를 학습하는데는 어려움을 겪지만, 새로운 concept의 identity는 잘 학습한다. 해결하는 가장 쉬운 아이디어는 잘 하는 두가지 방법을 섞어 unet과 text encoder를 함께 학습시키는 것인데, unet과 text encoder의 학습 속도가 다르기 때문에 여전히 새로운 concept을 제대로 학습하지 못했다.
그래서 저자들은 3단계로 나누어, embedding alignment를 먼저 학습하고, attention map을 refine한 뒤 subject identity를 얻을 수 있는 AttnDreamBooth 를 제안한다.
Embedding alignement 학습
새로운 concept에 대한 embedding alignment가 제대로 학습되었다면 cross attention 맵을 제대로 찾을 수 있게 된다. 특히, overfitting을 줄이는 방향으로 text embedding을 학습하는데, textual inversion과 비교해 적은 optimization step, 낮은 learning rate을 사용하며, photo of [v] [category] 의 프롬프트를 이용해 cross attention의 overfitting을 방지한다. 또한, cross attention 맵의 학습을 위해서 [v] 의 attention map과 [category]의 attention map의 L2 loss 를 사용해 최적화한다.
Refine Cross-Attention Map
Dreambooth가 학습에 어려움을 겪던 embedding misalignment 이슈를 완화시키기 위해, 위 단계에서 얻은 cross attention map을 refine하는 과정이 필요하며, 저자들은 이를 unet의 모든 cross attention layer에서 fine-tuning한다.
Capture subject identity
위의 두 단계의 학습을 통해 embedding alignment는 어느정도 가능해졌다. 아래 그림에서 [v]에 대한 부분이 붉게 나타나는 것을 볼 수 있다. 하지만, 여전히 Input의 이미지와는 다른 모습을 보이는데, 이를 위해 Dreambooth의 방법과 유사하게 unet의 모든 layer를 freeze한 뒤, 학습해 identity preservation을 보다 잘 할 수 있었다.
결론
기존에 굉장히 잘 알려진 두 가지 연구들의 한계점을 명확히 분석하고, 비교적 간단한 아이디어지만, 이를 해결할 수 있는 방법을 적용해 좋은 결과를 그려냈다. 여전히 overfitting과 복잡한 프롬프트에 대한 한계점을 어느정도 가지고 있어 저자가 주장하는 embedding alignment 관점에서는 개선할 부분이 있을 것 같다.
Stable diffusion을 활용한 fine-tuning과 dreambooth에 대해 궁금하신 분들은 아래 글들을 참고해주세요.
Stable Diffusion과 파인 튜닝 방법: 완벽 가이드
DreamBooth를 활용한 이미지 파인튜닝: 코드와 구현 방법
텍스트-투-이미지 변환: Hugging Face Diffusers 라이브러리를 사용한 실습