


```python
# Operating system utilities for files and directories
import os
# Random number generation
import random
# Numerical operations
import numpy as np
# OpenCV for image and video processing
import cv2
# Python Imaging Library for image creation and drawing
from PIL import Image, ImageDraw, ImageFont
# PyTorch deep learning framework
import torch
# Dataset base class for custom datasets in PyTorch
from torch.utils.data import Dataset
# Image transformations
import torchvision.transforms as transforms
# Neural network building blocks
import torch.nn as nn
# Optimization algorithms
import torch.optim as optim
# Padding utility for variable-length sequences
from torch.nn.utils.rnn import pad_sequence
# Utility for saving tensors as image files
from torchvision.utils import save_image
# Plotting graphs and images
import matplotlib.pyplot as plt
# Rich output in IPython/Jupyter environments
from IPython.display import clear_output, display, HTML
# Encoding and decoding binary data as text
import base64
```
```python
# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)

# Number of videos to generate for the dataset
num_videos = 10000

# Number of frames per video (a 1-second clip)
frames_per_video = 10

# Size of each image in the dataset
img_size = (64, 64)

# Size of the shape (circle)
shape_size = 10
```
```python
# Define text prompts and the corresponding movements for circles
prompts_and_movements = [
    ("circle moving down", "circle", "down"),                                     # Move circle downward
    ("circle moving left", "circle", "left"),                                     # Move circle leftward
    ("circle moving right", "circle", "right"),                                   # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),         # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),       # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),           # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),     # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),                  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),                                     # Shrink circle
    ("circle expanding", "circle", "expand"),                                     # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),                  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),              # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),                # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),            # Zigzag circle horizontally
    ("circle moving up-left", "circle", "up_left"),                               # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),                         # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),                           # Move circle down-left
]
```
```python
# Render a single frame: a white image containing the shape at a position
# (or orientation) determined by the frame number and the movement direction
def create_image_with_moving_shape(size, frame_num, shape, direction):
    # Create a new RGB image with the specified size and a white background
    img = Image.new('RGB', size, color=(255, 255, 255))
    # Create a drawing context for the image
    draw = ImageDraw.Draw(img)
    # Calculate the center coordinates of the image
    center_x, center_y = size[0] // 2, size[1] // 2
    # Every movement starts from the center of the image
    position = (center_x, center_y)

    # Map each direction to either a position offset or an image transformation
    direction_map = {
        # Move downwards based on the frame number
        "down": (0, frame_num * 5 % size[1]),
        # Move to the left based on the frame number
        "left": (-frame_num * 5 % size[0], 0),
        # Move to the right based on the frame number
        "right": (frame_num * 5 % size[0], 0),
        # Move diagonally up and to the right
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Move diagonally down and to the left
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Move diagonally up and to the left
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Move diagonally down and to the right
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Rotate the image clockwise based on the frame number
        "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Rotate the image counter-clockwise based on the frame number
        "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Bounce vertically
        "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),
        # Bounce horizontally
        "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),
        # Zigzag vertically
        "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),
        # Zigzag horizontally
        "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),
        # Move up and to the right
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Move up and to the left
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Move down and to the right
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Move down and to the left
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
    }

    # Check whether the direction is in the direction map
    if direction in direction_map:
        # A tuple means the direction is a position offset
        if isinstance(direction_map[direction], tuple):
            # Shift the position by the offset
            position = tuple(np.add(position, direction_map[direction]))
        else:
            # Otherwise the direction is an image transformation (rotation)
            img = direction_map[direction]

    # Draw the circle at the computed position so the frame actually contains the shape
    # (the blue fill color is an arbitrary choice; any fixed color works)
    if shape == "circle":
        draw.ellipse(
            [position[0] - shape_size // 2, position[1] - shape_size // 2,
             position[0] + shape_size // 2, position[1] + shape_size // 2],
            fill=(0, 0, 255)
        )

    # Return the image as a NumPy array
    return np.array(img)
```
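To sanity-check the frame generator before building the full dataset, a single frame can be rendered and previewed with matplotlib (already imported above). This is only an illustrative check, not part of the original pipeline:

```python
# Render one frame of a circle moving down and preview it
sample_frame = create_image_with_moving_shape(img_size, frame_num=3, shape="circle", direction="down")
plt.imshow(sample_frame)
plt.axis('off')
plt.show()
```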
```python
# Generate the training videos
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)
    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)
    # Generate the frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape for the current frame number
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        # Save the generated frame as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)
```



```python
# Dataset class that pairs each saved frame with its text prompt
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Root directory of the generated dataset and an optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all video subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir)
                           if os.path.isdir(os.path.join(root_dir, d))]
        # Lists of frame paths and their corresponding prompts
        self.frame_paths = []
        self.prompts = []

        # Loop through each video directory
        for video_dir in self.video_dirs:
            # Collect the paths of all PNG frames in the video directory
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file for this video
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt once per frame so frames and prompts stay aligned
            self.prompts.extend([prompt] * len(frames))

    # Total number of samples (frames) in the dataset
    def __len__(self):
        return len(self.frame_paths)

    # Retrieve one frame and its prompt by index
    def __getitem__(self, idx):
        # Path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image with PIL
        image = Image.open(frame_path)
        # Prompt corresponding to the given index
        prompt = self.prompts[idx]
        # Apply the transform if one was provided
        if self.transform:
            image = self.transform(image)
        # Return the (transformed) image and its prompt
        return image, prompt
```
```python
# Simple word-embedding module for the text prompts
class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(TextEmbedding, self).__init__()
        # Embedding layer mapping word indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self, x):
        # Return the embedded representation of the input indices
        return self.embedding(x)
```
```python
class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator, self).__init__()
        # Fully connected layer that takes the noise vector and text embedding as input
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
        # Transposed convolutions that upsample 8x8 feature maps to a 64x64 image
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # 3 output channels for RGB images
        # Activation functions
        self.relu = nn.ReLU(True)  # ReLU for the hidden layers
        self.tanh = nn.Tanh()      # Tanh for the final output

    def forward(self, noise, text_embed):
        # Concatenate the noise and the text embedding along the feature dimension
        x = torch.cat((noise, text_embed), dim=1)
        # Fully connected layer followed by reshaping into a 4D tensor (N, 256, 8, 8)
        x = self.fc1(x).view(-1, 256, 8, 8)
        # Upsample through the transposed convolutions with ReLU activations
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        # Final layer with Tanh so the output values lie between -1 and 1 (for images)
        x = self.tanh(self.deconv3(x))
        return x
```
```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Convolutional layers that downsample the input image
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)     # 3 input channels (RGB), 64 output channels, 4x4 kernel, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)   # 64 -> 128 channels
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)  # 128 -> 256 channels
        # Fully connected layer for real/fake classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # 256x8x8 is the output size of the last convolution
        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU with negative slope 0.2
        self.sigmoid = nn.Sigmoid()                         # Sigmoid for the final probability

    def forward(self, input):
        # Pass the input through the convolutional layers with LeakyReLU activations
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        # Flatten the convolutional features
        x = x.view(-1, 256 * 8 * 8)
        # Fully connected layer with Sigmoid for binary classification
        x = self.sigmoid(self.fc1(x))
        return x
```
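A quick shape check clarifies how the two networks fit together: the generator's fully connected layer produces 256 feature maps of size 8×8, three stride-2 transposed convolutions upsample them to a 64×64 RGB image, and the discriminator's three stride-2 convolutions bring that back down to 256×8×8 before the final classification layer. A small illustrative check (not part of the original training code):

```python
# Throwaway models on the CPU, fed with dummy tensors to verify tensor shapes
_netG = Generator(text_embed_size=10)
_netD = Discriminator()
_noise = torch.randn(4, 100)   # batch of 4 noise vectors
_text = torch.randn(4, 10)     # batch of 4 dummy text embeddings
_fake = _netG(_noise, _text)
print(_fake.shape)             # torch.Size([4, 3, 64, 64])
print(_netD(_fake).shape)      # torch.Size([4, 1])
```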
```python
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a simple vocabulary from the text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Assign an index to every unique word
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10          # Size of the text embedding vector

def encode_text(prompt):
    # Encode a prompt as a tensor of word indices using the vocabulary
    return torch.tensor([vocab[word] for word in prompt.split()])

# Initialize the models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Text embedding model
netG = Generator(embed_size).to(device)                            # Generator
netD = Discriminator().to(device)                                  # Discriminator
criterion = nn.BCELoss().to(device)                                # Binary cross-entropy loss
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for the Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for the Generator
```
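The training loop below iterates over a `dataloader` that is never constructed in this excerpt. A minimal sketch of the missing piece, assuming a batch size of 16 and the usual ToTensor + Normalize preprocessing so that real frames land in the same [-1, 1] range as the generator's Tanh output:

```python
from torch.utils.data import DataLoader

# Convert frames to tensors and scale them to [-1, 1] (assumed preprocessing)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Wrap the generated dataset in a DataLoader (batch size 16 is an assumption)
dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
```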
```python
# Number of training epochs
num_epochs = 13

# Iterate over the epochs
for epoch in range(num_epochs):
    # Iterate over the batches of data
    for i, (data, prompts) in enumerate(dataloader):
        # Move the real frames to the device
        real_data = data.to(device)
        # Convert the batch of prompts to a plain list
        prompts = [prompt for prompt in prompts]

        # Update the Discriminator
        netD.zero_grad()                                # Zero the Discriminator gradients
        batch_size = real_data.size(0)                  # Current batch size
        labels = torch.ones(batch_size, 1).to(device)   # Labels for real data (ones)
        output = netD(real_data)                        # Forward pass of real data
        lossD_real = criterion(output, labels)          # Loss on real data
        lossD_real.backward()                           # Backward pass

        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Random noise
        text_embeds = torch.stack([
            text_embedding(encode_text(prompt).to(device)).mean(dim=0)
            for prompt in prompts
        ])                                                # Average the word embeddings of each prompt
        fake_data = netG(noise, text_embeds)              # Generate fake frames
        labels = torch.zeros(batch_size, 1).to(device)    # Labels for fake data (zeros)
        output = netD(fake_data.detach())                 # Detach so gradients do not flow back into the Generator
        lossD_fake = criterion(output, labels)            # Loss on fake data
        lossD_fake.backward()                             # Backward pass
        optimizerD.step()                                  # Update the Discriminator

        # Update the Generator
        netG.zero_grad()                                   # Zero the Generator gradients
        labels = torch.ones(batch_size, 1).to(device)      # "Real" labels to fool the Discriminator
        output = netD(fake_data)                           # Forward pass of the fake data
        lossG = criterion(output, labels)                  # Generator loss based on the Discriminator's response
        lossG.backward()                                   # Backward pass
        optimizerG.step()                                   # Update the Generator

    # Print the losses once per epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")
```
```
## OUTPUT ##
Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776
...
```
```python
# Save the Generator's state dictionary to 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')
# Save the Discriminator's state dictionary to 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')
```
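If inference is run in a fresh session, the saved state dictionaries have to be loaded back into newly constructed models before calling the generation code below; a short sketch of that step:

```python
# Rebuild the models and restore the saved weights (only needed in a new session)
netG = Generator(embed_size).to(device)
netD = Discriminator().to(device)
netG.load_state_dict(torch.load('generator.pth', map_location=device))
netD.load_state_dict(torch.load('discriminator.pth', map_location=device))
netG.eval()
netD.eval()
```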
```python
# Inference function: generate the frames of a video for a given text prompt
def generate_video(text_prompt, num_frames=10):
    # Create a directory for the generated frames, named after the prompt
    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)
    # Encode the text prompt into a single embedding vector
    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)

    # Generate the frames
    for frame_num in range(num_frames):
        # Random noise for this frame
        noise = torch.randn(1, 100).to(device)
        # Generate a fake frame with the Generator (no gradients needed at inference time)
        with torch.no_grad():
            fake_frame = netG(noise, text_embed)
        # Save the generated frame; normalize=True rescales the Tanh output into [0, 1] before saving
        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png', normalize=True)

# Generate a video for a specific text prompt
generate_video('circle moving up-right')

# Path to the folder containing the generated PNG frames
folder_path = 'generated_video_circle_moving_up-right'
# List all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]
# Sort the frames by name (they are numbered sequentially)
image_files.sort()

# Read each frame and append it to the frames list
frames = []
for image_file in image_files:
    image_path = os.path.join(folder_path, image_file)
    frame = cv2.imread(image_path)
    frames.append(frame)

# Convert the list of frames to a NumPy array for easier handling
frames = np.array(frames)
# Frame rate of the output video (frames per second)
fps = 10

# Create a video writer and write every frame to the video
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))
for frame in frames:
    out.write(frame)
# Release the video writer
out.release()
```
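The `base64` and `IPython.display` imports at the top of the script suggest the result was previewed inline in a notebook. One way to do that (an illustrative sketch, not necessarily the author's exact code) is to bundle the generated frames into a GIF with PIL and embed it as a data URI:

```python
# Convert the BGR frames read by OpenCV back to RGB PIL images
gif_frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)) for f in frames]
# Save them as an animated GIF (100 ms per frame = 10 fps)
gif_frames[0].save('generated_video.gif', save_all=True, append_images=gif_frames[1:], duration=100, loop=0)

# Embed the GIF in the notebook as a base64 data URI
with open('generated_video.gif', 'rb') as f:
    gif_b64 = base64.b64encode(f.read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{gif_b64}" />'))
```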
