PyTorch is an open-source machine learning library used for a wide variety of tasks such as deep learning, natural language processing (NLP), and computer vision. It provides a flexible platform to build machine learning models and comes with strong support for GPU acceleration, making it popular among researchers and developers.
It is complemented by the `torchvision` and `torchtext` libraries for various computer vision and NLP tasks.

How to use PyTorch to create a neural network for classifying MNIST digits:
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Define a simple neural network
class Net(nn.Module):
    """Two-layer fully connected classifier operating on flattened inputs.

    The default sizes match MNIST: 28x28 single-channel images and 10 digit
    classes, so ``Net()`` builds the same architecture as before.

    Args:
        in_features: Number of values per sample after flattening
            (``28 * 28`` by default).
        hidden_size: Width of the hidden fully connected layer (128 by default).
        num_classes: Number of output logits (10 by default).
    """

    def __init__(self, in_features=28 * 28, hidden_size=128, num_classes=10):
        super().__init__()
        # Remembered so forward() can flatten inputs to the width fc1 expects.
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_size)  # Fully connected layer 1
        self.fc2 = nn.Linear(hidden_size, num_classes)  # Fully connected layer 2 (one logit per class)

    def forward(self, x):
        """Return raw class scores (logits) for ``x``.

        ``x`` is flattened to shape ``(-1, in_features)``, so its total number
        of elements must be a multiple of ``in_features``. For example, an
        ``(N, 1, 28, 28)`` batch with the default sizes gives ``(N, 784)``.
        No softmax is applied to the output.
        """
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after transpose/permute), while reshape copies only if needed.
        x = x.reshape(-1, self.in_features)  # Flatten the image
        x = torch.relu(self.fc1(x))  # Apply ReLU to fc1
        return self.fc2(x)  # Final output (logits)
# Load dataset and preprocess.
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on the single channel;
# assuming ToTensor yields pixel values in [0, 1], inputs end up in [-1, 1].
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocess_steps)

trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,  # training split
    download=True,  # lets torchvision fetch the files into ./data when missing
    transform=transform,
)
# Mini-batches of 64 samples; shuffle=True reshuffles the order every epoch.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
)

# Model, loss and optimizer: cross-entropy on Net's raw logits, plain SGD at lr=0.01.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
# Training loop: a fixed number of passes (epochs) of mini-batch SGD over trainloader.
NUM_EPOCHS = 2
batches_per_epoch = len(trainloader)
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()  # gradients accumulate across backward() calls, so clear them first
        loss = criterion(net(inputs), labels)  # forward pass and loss in one expression
        loss.backward()  # backward pass
        optimizer.step()  # apply the parameter update
        running_loss += loss.item()  # accumulate as a plain Python float
    # Report the mean per-batch loss for the epoch just finished.
    epoch_loss = running_loss / batches_per_epoch
    print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
print("Finished Training")
- `torchvision`: Contains datasets, models, and transforms for computer vision.
- `torchtext`: For NLP models and datasets.
- `torchaudio`: For audio and speech processing.