UNet2DModel for Digit Image Generation
Model Details
- Model Name: UNet2DModel
- Task: Digit Image Generation
- Model Type: Generative Model (class-conditional diffusion denoiser)
- Dataset: MNIST (Handwritten Digit Images)
- Output Image Size: 32x32 pixels
- Image Color: Black and White (Grayscale)
Model Description
The model is the denoising network of a class-conditional diffusion model for digit image generation. It was trained on the MNIST dataset of handwritten digits. Given a class label from 0 to 9, it generates a black-and-white 32x32 image of that digit.
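```python
from diffusers import UNet2DModel

# Model architecture (this creates randomly initialized weights; trained weights are loaded with from_pretrained)
unet = UNet2DModel(
    in_channels=1,                          # grayscale input
    out_channels=1,                         # single-channel output, same shape as the input
    sample_size=32,                         # 32x32 images
    block_out_channels=(32, 64, 128, 256),  # channel width at each resolution level
    norm_num_groups=8,                      # number of groups in the GroupNorm layers
    num_class_embeds=10,                    # one learned embedding per digit class 0-9
)
```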
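To see what this configuration implies, the plain-Python sketch below works out the feature-map size and channel width at each level of the encoder and checks that every width is compatible with `norm_num_groups=8`. It assumes diffusers' default `UNet2DModel` behavior, in which every down block except the last halves the resolution and the class embedding has the same width as the timestep embedding, `4 * block_out_channels[0]`. `describe_unet` is a hypothetical helper for illustration, not part of diffusers.

```python
# Values copied from the UNet2DModel(...) configuration.
sample_size = 32
block_out_channels = (32, 64, 128, 256)
norm_num_groups = 8
num_class_embeds = 10


def describe_unet(sample_size, block_out_channels, norm_num_groups):
    """Hypothetical helper: list (resolution, channels) for each down-block level."""
    # GroupNorm splits a layer's channels into equal groups, so each width must divide evenly.
    for ch in block_out_channels:
        if ch % norm_num_groups != 0:
            raise ValueError(f"{ch} channels not divisible into {norm_num_groups} groups")
    levels = []
    res = sample_size
    for i, ch in enumerate(block_out_channels):
        levels.append((res, ch))
        # Assumption: every down block except the last one halves the spatial size.
        if i < len(block_out_channels) - 1:
            res //= 2
    return levels


levels = describe_unet(sample_size, block_out_channels, norm_num_groups)
# Assumption: the timestep (and class) embedding width is 4 * block_out_channels[0].
time_embed_dim = 4 * block_out_channels[0]

for res, ch in levels:
    print(f"{res}x{res} feature maps, {ch} channels")
print(f"class embedding table: {num_class_embeds} x {time_embed_dim}")
```

Under these assumptions the deepest level works on 4x4 feature maps with 256 channels, and each of the ten digit labels is mapped to a 128-dimensional vector that is added to the timestep embedding.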
Training
Limitations
- Single Channel: The model generates grayscale (single-channel) digit images and cannot produce color.
- Limited to Digits: The model was trained only on the ten MNIST digit classes and is not expected to generate other kinds of images or objects.
- Resolution: The model generates digit images with a fixed resolution of 32x32 pixels and may not perform well on tasks requiring higher-resolution images.
Ethical Considerations
- Bias: Generated samples reflect the handwriting styles represented in MNIST, so styles that are rare or absent in the dataset may be reproduced poorly.
- Potential Misuse: The generated digit images should not be used for any malicious or fraudulent purposes, including creating counterfeit documents or impersonating individuals.
- Privacy: The model does not store or process any personal or sensitive information.
Usage Examples
The following Python snippet loads the trained model and generates an image of a chosen digit with a DDPM scheduler.
```python
import torch
from diffusers import UNet2DModel, DDPMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
# Default DDPM schedule (1000 timesteps); assumed to match the schedule used in training
scheduler = DDPMScheduler()
unet = UNet2DModel.from_pretrained("gnokit/unet-mnist-32", use_safetensors=True, variant="fp16").to(device)

class_to_generate = 8  # digit to generate, 0-9
sample = torch.randn(1, 1, 32, 32, device=device)  # start from pure Gaussian noise
class_labels = torch.tensor([class_to_generate], device=device)

for t in scheduler.timesteps:
    # Predict the noise in the current sample, conditioned on the digit class
    with torch.no_grad():
        noise_pred = unet(sample, t, class_labels=class_labels).sample
    # Remove one step of noise: x_t -> x_{t-1}
    sample = scheduler.step(noise_pred, t, sample).prev_sample

image = sample.clip(-1, 1) * 0.5 + 0.5  # map [-1, 1] to [0, 1]; tensor of shape (1, 1, 32, 32)
```
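The sampling loop delegates each denoising step to `scheduler.step`. The NumPy sketch below shows what a standard DDPM reverse step computes for a noise-predicting model. It assumes `DDPMScheduler`'s defaults: 1000 training timesteps, a linear beta schedule from 0.0001 to 0.02, the predicted clean image clipped to [-1, 1], and the "fixed_small" posterior variance. `ddpm_step` is a hypothetical helper that mirrors the math, not the diffusers implementation itself.

```python
import numpy as np

# Assumed DDPMScheduler defaults: 1000 steps, linear betas from 1e-4 to 0.02.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)


def ddpm_step(eps_pred, t, x_t, rng):
    """Hypothetical helper: one reverse DDPM step x_t -> x_{t-1} for a noise-predicting model."""
    a_bar_t = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    beta_t = 1.0 - a_bar_t / a_bar_prev  # equals betas[t]

    # Estimate the clean image x_0 from the predicted noise, then clip to the data range.
    x0_pred = (x_t - np.sqrt(1.0 - a_bar_t) * eps_pred) / np.sqrt(a_bar_t)
    x0_pred = np.clip(x0_pred, -1.0, 1.0)

    # Mean of the posterior q(x_{t-1} | x_t, x_0).
    mean = (
        np.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t) * x0_pred
        + np.sqrt(1.0 - beta_t) * (1.0 - a_bar_prev) / (1.0 - a_bar_t) * x_t
    )
    if t == 0:
        return mean  # no noise is added on the final step

    # "fixed_small" variance: the posterior variance, floored for numerical safety.
    var = max((1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t, 1e-20)
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)


# Sanity check: if the "model" returns the exact noise used to corrupt x0,
# a step at t = 0 recovers x0.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 1, 32, 32))  # stand-in for a clean image
noise = rng.standard_normal(x0.shape)
t = 0
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise
x_prev = ddpm_step(noise, t, x_t, rng)
print(np.abs(x_prev - x0).max())  # floating-point error only
```

At t = 0 the step returns the clipped estimate of the clean image, so the final sample already lies in [-1, 1]; the usage code then maps it to [0, 1] with `* 0.5 + 0.5`.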
license: apache-2.0
datasets:
- mnist
pipeline_tag: unconditional-image-generation