
Get chroma to a functioning state #1


Merged: 1 commit merged into hameerabbasi:chroma on Jun 10, 2025

Conversation

@iddl commented on Jun 9, 2025

This is a series of fixes to @hameerabbasi's chroma branch to get the code working on diffusers.

How to get the model to run
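
This assumes a diffusers build that includes the changes from this branch, e.g. pip install git+https://github.com/hameerabbasi/diffusers.git@chroma (adjust the fork path if yours differs).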

import torch
from PIL import Image
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, FluxPipeline, AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizer
from huggingface_hub import hf_hub_download

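# Download the Chroma single-file checkpoint and load it as a Flux transformer
# (the "chroma" variant comes from this branch)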
transformer_file = hf_hub_download(repo_id="lodestones/Chroma", filename="chroma-unlocked-v30.safetensors")
transformer = FluxTransformer2DModel.from_single_file(
    transformer_file,
    variant="chroma",
)

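# The VAE is reused from FLUX.1-schnell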
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae"
)

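# T5 tokenizer and encoder; the encoder stays on CPU and prompt embeddings are computed manually below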
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
t5_encoder = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2").to("cpu")


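# All tokenizer/text_encoder slots are None because the prompt embeddings are passed to the pipeline directly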
pipe = FluxPipeline(
    transformer=transformer,
    vae=vae,
    tokenizer=None,
    tokenizer_2=None,
    text_encoder=None,
    text_encoder_2=None,
    scheduler=FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True, base_shift=0.5, max_shift=1.15, use_beta_sigmas=True),
    variant="chroma"
)

pipe.to("cuda", dtype=torch.bfloat16)

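# Embed a prompt with T5, trimming padding so exactly one padding token is kept after the prompt tokens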
def embed_prompt(prompt):
    prompt_tokens = t5_tokenizer(prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    prompt_embeds = t5_encoder(prompt_tokens.input_ids, attention_mask=prompt_tokens.attention_mask)[0]
    max_len = min(prompt_tokens.attention_mask.sum() + 1, 512)
    prompt_embeds = prompt_embeds[:, :max_len]  # Truncate to prompt length + 1 (i.e. leave one padding token at the end)
    return prompt_embeds


seed = 1030
prompt = "Extreme close-up photograph of a single tiger eye, direct frontal view. The iris is very detailed and the pupil resembling a dark void. The word \"Chroma\" is across the lower portion of the image in large white stylized letters, with brush strokes resembling those made with Japanese calligraphy. Each strand of the thick fur is highly detailed and distinguishable. Natural lighting to capture authentic eye shine and depth."
negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
prompt_embeds = embed_prompt(prompt)
negative_prompt_embeds = embed_prompt(negative_prompt)

image = pipe(
    prompt_embeds=prompt_embeds.to(device="cuda", dtype=torch.bfloat16),
    negative_prompt_embeds=negative_prompt_embeds.to(device="cuda", dtype=torch.bfloat16),
    
    # Zeros for CLIP:
    pooled_prompt_embeds=torch.zeros(1, 768).to(torch.float16),
    negative_pooled_prompt_embeds=torch.zeros(1, 768).to(torch.float16),
    
    num_inference_steps=26,
    true_cfg_scale=4.0,
    guidance_scale=0,
    width=1024,
    height=1024,
    generator=torch.Generator().manual_seed(seed),
).images[0]

image
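
If you want to save the result instead of just displaying it, the usual PIL call works (the filename is only an example):

image.save("chroma_v30_sample.png")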

[Screenshot: sample output generated with the snippet above, 2025-06-09]

@hameerabbasi merged commit e95ac9d into hameerabbasi:chroma on Jun 10, 2025
@DN6 mentioned this pull request on Jun 13, 2025