Easy Porting: NVIDIA to AMD Guide

Introduction

With both AMD and NVIDIA now established as leading options for AI compute, a common question is how much the software needed to run on each differs. In practice, real-world workloads can run on both types of hardware with little to no code changes, and we're excited to demonstrate this today.

We'll start by training an image classifier on the CIFAR-10 dataset in PyTorch on both NVIDIA and AMD.

Learn more about the CIFAR-10 dataset here.

We'll then move on to a more practical use-case: fine-tuning an LLM, and serving Llama 3.1 8B with accelerated inference.

Learn more about Llama 3.1 here.


Training an Image Classification Model

To start, you're going to need to install PyTorch locally, along with requests (used by the inference script later). Install the build that matches your hardware. For AMD (ROCm):

pip install requests
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2/
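
For NVIDIA, the equivalent install points at a CUDA wheel index instead (cu121 below is an assumption; match it to your CUDA version):

pip install requests
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121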

This will be the only difference in process for this tutorial.

Next, navigate to the directory you'd like to set this tutorial up in. From there, create the following Python script:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# 'cuda' works on both vendors: PyTorch's ROCm build maps it to AMD GPUs
device = torch.device('cuda')

# A small CNN: three conv/pool blocks feeding a two-layer classifier
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Train for num_epochs, printing test accuracy after each epoch
def train_and_evaluate(model, train_loader, test_loader, num_epochs=10, learning_rate=0.01):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.2f}%')

    return model

def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Normalize RGB channels from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = SimpleCNN()
trained_model = train_and_evaluate(model, train_loader, test_loader)

model_save_path = 'cifar10_cnn_model.pth'
save_model(trained_model, model_save_path)

# Final evaluation pass over the test set
trained_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = trained_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Final Test Accuracy: {100 * correct / total:.2f}%')

This script downloads the dataset, applies the transforms, then trains and evaluates a CNN that reaches around 80% test accuracy. The trained model is saved to model_save_path, which you can change as needed.

You'll notice that at the top, we set our computation device via device = torch.device('cuda'). In PyTorch's ROCm build, 'cuda' actually points to AMD GPUs, so there's no need to change any of your scripts.
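
You can verify this from a Python shell; the same torch.cuda calls work on either build:

import torch

print(torch.cuda.is_available())      # True on both CUDA and ROCm builds
print(torch.cuda.get_device_name(0))  # reports an NVIDIA or AMD device name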

Next, create the following inference script:
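
A minimal version looks like this (image_url below is a placeholder; the SimpleCNN definition must match the training script exactly so the saved weights load):

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

device = torch.device('cuda')

# CIFAR-10 class names, in label order
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Same architecture as the training script, so the state dict keys match
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# Load the weights saved by the training script
model = SimpleCNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn_model.pth', map_location=device))
model.eval()

# Fetch the image to classify (image_url is a placeholder -- point it at any image)
image_url = 'https://example.com/some_image.jpg'
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).convert('RGB')

# Apply the same preprocessing as training; CIFAR-10 images are 32x32
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs, 1)

print(f'Predicted class: {classes[predicted.item()]}')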

This script loads the model generated by the training script, then classifies the image at image_url into one of the ten CIFAR-10 categories.

That's it!


Fine-Tuning LLMs

For this section of the tutorial, we'll fine-tune Facebook's OPT-350m model. We'll begin by setting up the dependencies that significantly speed up LLM training.

This tutorial assumes the following prerequisites; if you're using different versions, adjust your commands accordingly.

  • Linux (Ubuntu)

  • CUDA 12.1 or ROCm 6.2

Begin by installing the needed dependencies.
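
The Hugging Face stack below is an assumption based on the training script that follows; the PyTorch index URL is the only vendor-specific piece. On NVIDIA (CUDA 12.1):

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install trl transformers datasets accelerate

On AMD (ROCm 6.2):

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2
pip install trl transformers datasets accelerate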

Notice that, as above, this is the only difference between the two training processes.

From there, create the following script in a subfolder you'd like to work in.
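
A minimal sketch uses TRL's SFTTrainer (trl is an assumption here, and argument placement varies between trl releases; newer versions move dataset_text_field and max_seq_length into SFTConfig):

from datasets import load_dataset
from trl import SFTTrainer

# Load the IMDB reviews dataset from the Hugging Face Hub
dataset = load_dataset("imdb", split="train")

# Fine-tune OPT-350m on the raw review text
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()

# Save the fine-tuned weights for the inference script below
trainer.save_model("./opt-350m-imdb")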

This script trains Facebook's OPT-350m model on the IMDB review dataset and saves the result for later inference. To run inference, use the following script:
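
A minimal sketch (the ./opt-350m-imdb path is an assumption; match it to wherever the training script saved the model):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cuda')  # maps to AMD GPUs on ROCm builds, as before

# Path the training script saved to; if the tokenizer wasn't saved alongside
# the weights, load it from "facebook/opt-350m" instead
model_path = "./opt-350m-imdb"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
model.eval()

# Generate a continuation in the style of the fine-tuning data
prompt = "This movie was"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))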


Accelerated Inference for Llama 3.1 (and other HF Models)

For this section of the tutorial, we're going to use vLLM, a framework for accelerated LLM inference and serving.

More information on vLLM here.

We're going to serve Llama 3.1 8B Instruct through Docker containers. We'll start by pulling the images and serving the endpoints from there. Note that since the Llama models are gated, we'll have to log in through huggingface-cli to use them.

Note that you'll need your Hugging Face API token for both methods.
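
On NVIDIA, vLLM publishes a prebuilt image; pass your token through the environment (the image tag and flags follow vLLM's documentation; adjust them to your versions):

docker run --runtime nvidia --gpus all \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=<your_token>" \
    -p 8000:8000 --ipc=host \
    vllm/vllm-openai:latest \
    --model meta-llama/Llama-3.1-8B-Instruct

On AMD, vLLM provides a ROCm Dockerfile to build instead (the flags below follow vLLM's ROCm instructions; adjust them to your setup):

git clone https://github.com/vllm-project/vllm.git
cd vllm
docker build -f Dockerfile.rocm -t vllm-rocm .
docker run -it --network=host --ipc=host \
    --device /dev/kfd --device /dev/dri --group-add video \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=<your_token>" \
    vllm-rocm \
    python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct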

In a separate terminal, you can now query the endpoints!
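
For example, against vLLM's OpenAI-compatible completions endpoint:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0
    }'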
