Prototyping XGen-Image-1

TLDR

Generative AI methods for image generation have a wide variety of potential applications in marketing, sales, and e-commerce. With these applications in mind, the Salesforce Research team has developed several techniques based on image-generative diffusion models, including methods for image editing, improved classifier guidance, and improved controlled generation methods. In this blog post, we document our experience of training large text-to-image diffusion models from scratch. We describe our design decisions, training process, and performance metrics for the first generation of our image generation models, called XGen-Image-1.  In summary:

  • We train XGen-Image-1, an 860-million-parameter latent diffusion model, using 1.1 billion publicly available images from the LAION dataset.
  • Combining a latent-space VAE with off-the-shelf pixel upsamplers allows training at very low resolution, reducing compute cost.
  • A competitive image generation model can be trained on Google’s TPU stack for ~$75k
  • XGen-Image-1 matches the prompt alignment performance of Stable Diffusion 1.5 and 2.1, which are among the best-performing image generation models.
  • Automated touch-ups on prompted regions (e.g. “face”) are effective enhancements
  • Rejection sampling at inference time can dramatically improve results

What are we going to train?

At present, there are two main classes of diffusion models: pixel-based and latent-based. Pixel-based diffusion models include DeepFloyd-IF (Shonenkov et al. 2023), Imagen (Saharia et al. 2022), EDiffi (Balaji et al. 2022), Kandinsky (Shakhmatov et al. 2022), and DALLE-2 (Ramesh et al. 2022). Latent-based diffusion models (LDMs) include the Stable Diffusion family of models, pioneered in Rombach et al. (2021), as well as Wuerstchen (Pernias et al. 2023).

The key difference between these classes is that latent diffusion models denoise autoencoded image representations in a compressed latent space (typically downsampled 8x spatially), while pixel-based diffusion models operate directly on pixels.

Bigger images are expensive to train on, so the majority of these approaches typically end up training a “base” model at 64x64 resolution. For the pixel models, this results in fairly small images as seen below. Content is observable, but not detail, and it won’t look good when enlarged. This necessitates following up the base generation with upsamplers - typically also diffusion models.

For all of these approaches, various pretrained models are used:

  • Text conditioning: CLIP/T5 language embeddings
  • Latent diffusion models: VAE

Likewise, in the cascading upsampling models, each component is trained independently; in effect, each component is pretrained with respect to the others (and vice versa).

Diversely trained autoencoders (and, it turns out, upsamplers) are extremely adaptable to different types of imagery; they typically need no changes even when the generative process itself has to shift. That isn’t to say these models are perfect: the Stable Diffusion VAE’s artifacts on detailed features (e.g. text, small faces) are well known, and upsamplers can introduce artifacts of their own. Still, these models are a lot more reusable than their base generative model counterparts!

With this in mind, we decided that the more direct/straightforward problems of upsampling, conditioning, and encoding didn’t need to be the focus of our model. Instead we asked: since these elements are robust and reusable, how much can we reuse them?

We decided to test the limits of efficient training and see how low a resolution we could train at by combining pretrained autoencoding and pixel upsampling models.

As illustrated in the pipeline above, we use both a pretrained autoencoder and optional pixel-based upsamplers. This allows us to generate at a low (latent) resolution of 32x32 while still outputting 1024x1024 images. In the future, we want to push the lower bound of practical resolution even further. Unless otherwise noted, we report quantitative results and human evaluation at the 256x256 stage directly after the VAE, without any upsamplers. Qualitative results (at the start and end of this post) use both a “re-upsampler” from 256→64→256 (analogous to the “Refiner” of SDXL) and an upsampler from 256→1024.
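
To make the decode-and-upsample path concrete, here is a minimal sketch using public stand-ins: a diffusers AutoencoderKL as the VAE and the Stable Diffusion x4 upscaler as the pixel upsampler. These are assumptions for illustration, not the exact components in our pipeline.

# Minimal sketch of the low-resolution-latent -> image path (stand-in components).
import torch
from diffusers import AutoencoderKL, StableDiffusionUpscalePipeline
from diffusers.image_processor import VaeImageProcessor

device = "cuda"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
upscaler = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to(device)

prompt = "a photograph of an astronaut riding a horse"
# Placeholder for latents sampled by the 32x32-resolution base diffusion model.
latents = torch.randn(1, 4, 32, 32, device=device)

with torch.no_grad():
    decoded = vae.decode(latents / vae.config.scaling_factor).sample  # 8x decode: 32 -> 256
image_256 = VaeImageProcessor().postprocess(decoded, output_type="pil")[0]
image_1024 = upscaler(prompt=prompt, image=image_256).images[0]       # 4x upsample: 256 -> 1024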

What data will we train it on?

Following Stable Diffusion, we train our model using the LAION-2B dataset with an aesthetic score filter of 4.5, which retains ~42% of the dataset. The LAION dataset consists of web-scraped images; let’s take a quick look at what the captions and images look like. As seen below, there are a lot of product images with basic descriptions, a lot of clothing, etc. The beauty of huge datasets, however, is that the long tail of concepts contains instances of very rare nouns, and when you multiply this long tail by the scale of the dataset, there are still a lot of instances! For example, in a 1M image sample, the words “dragon” and “astronaut” occur in 411 and 86 captions respectively. Across the entire 2B image dataset, that means there are >800k and >170k labeled training instances of each concept (not to mention variants on those words)! For perspective, that means there are almost as many “dragon” instances to learn from as there are images in the entire original ImageNet-1k challenge.
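
Counts like these are easy to reproduce from the LAION caption metadata. Below is a rough sketch; the parquet filename and column name are assumptions about how a local metadata sample is stored, not part of our pipeline.

# Rough sketch: count keyword frequency in a sample of LAION caption metadata.
# "laion2b_sample.parquet" and the "caption" column name are assumed local names.
import pandas as pd

df = pd.read_parquet("laion2b_sample.parquet", columns=["caption"])
captions = df["caption"].fillna("").str.lower()

for word in ["dragon", "astronaut"]:
    # \b word boundaries so e.g. "dragonfly" is not counted.
    count = captions.str.contains(rf"\b{word}\b", regex=True).sum()
    print(f"{word}: {count} of {len(captions)} captions")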


Training Infrastructure

We trained our model on TPU v4s. We found Ronghang Hu’s Moco-V3 TPU code to be an invaluable starting point. As part of training on TPUs, we used Google Cloud Storage (GCS) for model saving and gcloud-mounted drives to store large datasets. We trained our model on a v4-512 TPU machine for 1.1M steps, taking about 9 days, with estimated hardware costs of approximately $73k. For comparison, the original StableDiffusion cost $600k.

Training hiccups

Early on, our loss would swing wildly from step to step, even with a large batch size. We found that this was due to all workers receiving the same random seed. Seeding each worker with its rank gave an even distribution of noise timesteps across the global batch and much smoother loss curves, as sketched below.
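
A minimal sketch of per-worker seeding in a PyTorch/XLA setup (the helper name and base seed are our own, for illustration):

# Minimal sketch: derive a distinct seed for each worker from its rank so that
# diffusion timesteps (and other randomness) differ across workers.
import random
import numpy as np
import torch
import torch_xla.core.xla_model as xm

def seed_by_rank(base_seed: int = 1234) -> None:
    rank = xm.get_ordinal()  # this worker's index in the TPU pod
    seed = base_seed + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)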

Saving model checkpoints proved to be a surprisingly hairy problem in our infrastructure setup. The local ~ directories on TPUs aren’t persistent, and our code drive had slow I/O. Saving to GCS in parallel didn’t work out of the box: with Pytorch/XLA you need to be cautious about what’s being executed by all workers vs. just the master. In this case, some operations will take a lock on the GCS entry, resulting in the other workers hanging.

The code block below solved this issue by saving to GCS while only opening the blob (which takes out a lock) on the master worker.


from google.cloud import storage  # google-cloud-storage client
import torch

if 'gs://' in file_or_path:
    # Parse "gs://bucket-name/path/to/ckpt.pt" into bucket and blob names.
    gcs_path = file_or_path.replace('gs://', '')
    bucket_name = gcs_path.split('/')[0]
    blob_name = '/'.join(gcs_path.split('/')[1:])

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)

    # Only the master worker opens the blob and writes the checkpoint;
    # should_write_data is True only on the master, and cpu_data is the
    # state dict already moved to CPU.
    if should_write_data:
        with blob.open('wb', ignore_flush=True) as f:
            torch.save(cpu_data, f)

Loss curves - we expect that further training will continue to improve our model.

Automated Metric Evaluation

We perform automatic evaluation of our model (and checkpoints) by measuring CLIP score (alignment to the prompt) on the x-axis and FID (dataset-level similarity of appearance) on the y-axis. These metrics are computed across 15 guidance scales, over 30k image-prompt pairs in the first figure (vs. StableDiffusion versions) and 1k pairs for inter-checkpoint comparison. Data pairs are randomly drawn from the COCO Captions dataset with “A photograph of” prepended to the caption to avoid FID penalties associated with different graphic styles (e.g. illustrations).
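
A rough sketch of this kind of sweep using torchmetrics is below; generate_images and the guidance-scale grid are placeholders rather than our actual evaluation code.

# Rough sketch: sweep guidance scales and record (CLIP score, FID) per scale.
# generate_images(prompts, guidance_scale) stands in for the model being evaluated;
# real_images and fake_images are uint8 tensors of shape [N, 3, H, W].
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.multimodal.clip_score import CLIPScore

def sweep_guidance(generate_images, real_images, prompts, scales):
    results = []
    for scale in scales:
        fake_images = generate_images(prompts, guidance_scale=scale)
        fid = FrechetInceptionDistance(feature=2048)
        fid.update(real_images, real=True)
        fid.update(fake_images, real=False)
        clip = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
        clip.update(fake_images, prompts)
        results.append((scale, clip.compute().item(), fid.compute().item()))
    return results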

CLIP-FID evaluations are limited as noted in the literature (e.g. SDXL, Podell et al. 2023) but are still a useful big-picture metric to have. We see that our model performs competitively with SD1.5 and SD2.1, actually exceeding the pretrained StableDiffusion models in both metrics - indicating high photorealism and prompt faithfulness. As a sanity check, we note that sequential “epochs” (12.5k steps) generally improve along both dimensions.

Human Evaluation

Following SDXL (Podell et al. 2023), we performed human evaluation of our model vs. SD1.5 and 2.1 on the PartiPrompts (Yu et al. 2022) benchmark, measuring prompt alignment, using Amazon Mechanical Turk. We asked users “Which image better follows the prompt?”, collecting responses for all 1632 prompts in the benchmark in 6 separate trials, resulting in ~10k responses in total per comparison. Error bars indicate 95% confidence intervals.
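
For reference, a minimal sketch of the interval computation (a normal-approximation confidence interval on a pairwise win rate; the counts below are hypothetical, not our actual results):

# Minimal sketch: 95% confidence interval for a pairwise win rate.
import math

def win_rate_ci(wins: int, total: int, z: float = 1.96):
    p = wins / total
    half_width = z * math.sqrt(p * (1 - p) / total)
    return p, (p - half_width, p + half_width)

rate, (low, high) = win_rate_ci(wins=5000, total=10000)  # hypothetical counts
print(f"win rate {rate:.3f}, 95% CI [{low:.3f}, {high:.3f}]")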

We report the overall average (across all prompt categories) in the figure above. We see that XGen-Image is rated almost identically to SD1.5 while being marginally (though not significantly) behind SD2.1.

We do not directly evaluate against the recent SDXL (Podell et al. 2023) which is a much larger LDM that demonstrates far superior performance to SD1.x and 2.x. We are currently working on scaling up XGen-Image and addressing specific areas of improvement, with the goal of  matching the performance of SDXL.

Consistently Generating High Quality Images

Building on our trained XGen-Image model, we implemented two commonly used tricks for consistently generating high-quality images in our inference pipeline.

  1. Generate a ton of images and choose the best one
  2. Inpaint things that don’t look good

We wanted to keep the 1-prompt 1-output setup, so we looked to automate both steps.

For (1), we tried rejection sampling - generating multiple images and automatically selecting the best one. We initially explored aesthetic score and CLIP score, but found PickScore (Kirstain et al. 2023) to provide a great catch-all metric that, as stated in their paper, correlated well with human preference.

To make these batches efficient, we used half precision, efficient attention, and the PNDM scheduler (Liu et al. 2022). This allows us to generate 32 images (4x8) in ~5 seconds on an A100 GPU. As shown in the example below, the success rate of generating an image aligned to a prompt isn’t always going to be 0% or 100% - by giving the model multiple chances to get it right and automatically selecting a good candidate, we can enhance the reliability of the overall pipeline.
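
A rough sketch of PickScore-based rejection sampling over a batch of candidates, following the usage described in the PickScore repository (the candidate images would come from the generation pipeline; the batching here is our own illustration):

# Rough sketch: score candidate images against the prompt with PickScore and
# keep the highest-scoring one (rejection sampling).
import torch
from transformers import AutoModel, AutoProcessor

device = "cuda"
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(device)

def pick_best(prompt, images):
    """images: a list of PIL images generated for the same prompt."""
    image_inputs = processor(images=images, return_tensors="pt").to(device)
    text_inputs = processor(
        text=prompt, padding=True, truncation=True, max_length=77, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
        scores = (model.logit_scale.exp() * text_embs @ image_embs.T).squeeze(0)
    return images[scores.argmax().item()]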

As an example of (2), we applied a fairly standard process for regional improvement to faces (though it generalizes to any object); see the sketch after the list:

  1. Get a segmentation mask for the object (from a text prompt)
  2. Crop based on the segmentation mask
  3. Enlarge the crop
  4. Run img2img on the crop with a caption matching/paralleling the segmentation prompt (for faces we segment with “a face” and run img2img with “a photograph of a face”)
  5. Use the segmentation mask to blend the enhanced crop back into the original image
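
Here is a rough sketch of that loop using off-the-shelf stand-ins (CLIPSeg for text-prompted segmentation and a public diffusers img2img pipeline in place of our own model); the model names, threshold, and strength are illustrative assumptions.

# Rough sketch: prompt-driven regional touch-up
# (segment -> crop -> enlarge -> img2img -> blend back), with stand-in models.
import numpy as np
import torch
from PIL import Image, ImageFilter
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

device = "cuda"
seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to(device)

def touch_up(image, seg_prompt="a face", gen_prompt="a photograph of a face"):
    # 1. Text-prompted segmentation heatmap, thresholded into a mask.
    inputs = seg_processor(text=[seg_prompt], images=[image], return_tensors="pt")
    with torch.no_grad():
        heat = torch.sigmoid(seg_model(**inputs).logits).squeeze().cpu().numpy()
    mask = Image.fromarray((heat > 0.4).astype(np.uint8) * 255).resize(image.size)

    # 2-3. Crop around the mask's bounding box and enlarge the crop.
    box = mask.getbbox()
    crop = image.crop(box).resize((512, 512), Image.LANCZOS)

    # 4. Re-generate the enlarged crop with a matching caption at moderate strength.
    enhanced = img2img(prompt=gen_prompt, image=crop, strength=0.4).images[0]
    enhanced = enhanced.resize((box[2] - box[0], box[3] - box[1]), Image.LANCZOS)

    # 5. Blend the enhanced crop back in, feathering the mask edges slightly.
    out = image.copy()
    soft_mask = mask.crop(box).filter(ImageFilter.GaussianBlur(4))
    out.paste(enhanced, box, soft_mask)
    return out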

Qualitative Evaluation of Rejection Sampling

We see that automatic rejection sampling using PickScore dramatically improves the performance of XGen-Image, producing a larger gap than any of the differences between XGen-Image and the StableDiffusion versions. Here, all PartiPrompts are evaluated in a single trial.

One More Collage


Conclusion

In this post we introduced XGen-Image-1, a text-to-image latent diffusion model trained to re-use several pretrained components. Our prototype largely follows the LDM/SD pipeline, but only trains at 256x256 pixel (32x32 latent) resolution, reducing compute cost. The XGen-Image prototype performs similarly to Stable Diffusion 1.5 and 2.1 in evaluation. We found that using PickScore to perform batched rejection sampling dramatically improved generations, as measured by both human evaluation and automatic metrics. We are excited to continue developing XGen-Image and to share our observations along the way.

Breakdown of Contributions

Bram Wallace: Code library, model training, inference pipeline

Akash Gokul: Early prototyping, help with coding, upsampler speed-ups

Dongxu Li: Data collection, formatting, and loading pipeline

Junnan Li: Advice on model training and help with the data pipeline

Nikhil Naik: Project planning, supervision, and management

We thank Srinath Reddy Meadusani and Lavanya Karanam for their help and support with computing infrastructure. We also thank Ran Xu, Ning Yu, Shu Zhang, and Kathy Baxter for their suggestions at different stages of the project. Finally, we thank Caiming Xiong and Silvio Savarese for their advice and support throughout the project.