# Binary image classification using a CNN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report


In [None]:
# 3-layer CNN
class CustomCNN(nn.Module):
    """Small convolutional network for binary image classification.

    Three conv -> ReLU -> 2x2 max-pool stages (3 -> 32 -> 64 -> 128 channels)
    feed a flatten -> 512-unit hidden layer -> dropout(0.5) -> single sigmoid
    output, so ``forward`` returns one probability per image.

    Args:
        input_size: Spatial size of the input images. Either a single int for
            square images or a ``(height, width)`` pair. Defaults to 224,
            which reproduces the original ``128 * 28 * 28`` classifier input,
            so checkpoints saved with the default remain loadable.

    Raises:
        ValueError: If the input is so small that three poolings would leave
            no spatial extent.
    """

    def __init__(self, input_size=224):
        super().__init__()

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # The 3x3/padding-1 convolutions keep the spatial size; each of the
        # three 2x2/stride-2 poolings halves it (rounding down).
        feat_h, feat_w = height // 8, width // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small: each side must be at least 8"
            )

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feat_h * feat_w, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
            nn.Sigmoid()  # Binary classification output
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid outputs for image batch ``x``."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


In [None]:
# Preprocessing pipelines. Both splits are resized to 224x224, converted to
# tensors and normalized with mean 0.5 / std 0.5; only the training split
# additionally gets a random horizontal flip as augmentation.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Image folders for each split; every sample goes through the matching
# transform pipeline when it is read.
train_dataset = datasets.ImageFolder('dataset/train', transform=train_transforms)
test_dataset = datasets.ImageFolder('dataset/test', transform=test_transforms)

# Batches of 32. Only the training data is reshuffled; the test order stays
# fixed.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)


In [None]:
# Class names as discovered by ImageFolder in the training directory.
# They are later indexed with the integer label / thresholded prediction
# (0 or 1), so exactly two classes are expected.
class_names = train_dataset.classes
print(f"Classes: {class_names}")

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and move its parameters onto the chosen device
model = CustomCNN()
model = model.to(device)

# The model ends in a sigmoid, so plain binary cross-entropy is applied to
# its probabilities; Adam updates all model parameters.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

print(device)

In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    """Train ``model`` and print the mean loss and accuracy after every epoch.

    Args:
        model: Network whose output has shape ``(batch, 1)`` with values that
            ``criterion`` accepts and that can be thresholded at 0.5.
        train_loader: Iterable yielding ``(inputs, labels)`` batches; labels
            are cast to float and reshaped to ``(batch, 1)``.
        criterion: Loss called as ``criterion(outputs, labels)``; its value is
            treated as a per-sample mean when accumulating the epoch loss.
        optimizer: Optimizer already bound to the model's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is reported via ``print``.
    """
    # Follow the model's own device instead of relying on a module-level
    # global, so the function also works for models placed elsewhere.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float().unsqueeze(1)  # Reshape for BCE loss

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Undo the per-batch mean so uneven batch sizes are weighted correctly.
            running_loss += loss.item() * batch_size
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += batch_size

        # Average over the samples actually processed (robust to drop_last or
        # custom samplers); the metrics are undefined for an empty loader.
        epoch_loss = running_loss / total if total else float("nan")
        epoch_acc = correct / total if total else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")



In [None]:
def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the overall accuracy.

    Outputs are thresholded at 0.5 to obtain hard 0/1 predictions.

    Args:
        model: Network whose output has shape ``(batch, 1)``.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(all_labels, all_preds)``: two lists holding one length-1
        NumPy array (value 0.0 or 1.0) per evaluated sample, in loader order.
    """
    # Follow the model's own device instead of relying on a module-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float().unsqueeze(1)
            outputs = model(inputs)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Accuracy is undefined for an empty loader; report NaN rather than crash.
    accuracy = correct / total if total else float("nan")
    print(f"Test Accuracy: {accuracy:.4f}")
    return all_labels, all_preds


In [None]:
def plot_confusion_matrix(labels, preds, class_names):
    """Plot a confusion-matrix heatmap and print a classification report.

    Args:
        labels: True class indices, one per sample (scalars or length-1
            arrays; values are flattened and cast to int).
        preds: Predicted class indices in the same format as ``labels``.
        class_names: Display names; position ``i`` names class index ``i``.
    """
    # evaluate_model returns one length-1 float array per sample; flatten to
    # a 1-D integer vector so sklearn sees plain class indices.
    y_true = np.asarray(labels).ravel().astype(int)
    y_pred = np.asarray(preds).ravel().astype(int)

    # Pin the label set explicitly. Otherwise a class that never occurs in
    # the data shrinks the matrix and makes classification_report reject
    # target_names because the lengths no longer match.
    class_ids = list(range(len(class_names)))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
    report = classification_report(y_true, y_pred, labels=class_ids, target_names=class_names)
    print("\nClassification Report:\n", report)


In [None]:
def visualize_predictions(model, test_loader, class_names, num_images=6):
    """Show the first ``num_images`` test images with true and predicted class.

    Images are laid out in a two-column grid. Each title reads
    ``True: <name>, Pred: <name>``, where the prediction is the model output
    thresholded at 0.5.

    Args:
        model: Network whose output has shape ``(batch, 1)``.
        test_loader: Iterable yielding ``(inputs, labels)`` batches of
            channel-first images normalized with mean 0.5 / std 0.5.
        class_names: Display names indexed by class index.
        num_images: Maximum number of images to draw; nothing is drawn when
            this is not positive.
    """
    if num_images <= 0:
        return

    # Follow the model's own device instead of relying on a module-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    n_cols = 2
    # Round up so an odd num_images (including 1) still gets enough grid cells.
    n_rows = (num_images + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(10, 6))
    images_so_far = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            preds = (outputs > 0.5).long().view(-1).cpu()
            images = inputs.cpu()
            for i in range(images.size(0)):
                images_so_far += 1
                ax = fig.add_subplot(n_rows, n_cols, images_so_far)
                ax.axis('off')
                ax.set_title(f'True: {class_names[int(labels[i])]}, Pred: {class_names[int(preds[i])]}')
                # CHW -> HWC for matplotlib
                img = images[i].numpy().transpose((1, 2, 0))
                img = np.clip(img * 0.5 + 0.5, 0, 1)  # Undo normalization
                ax.imshow(img)
                if images_so_far == num_images:
                    return


In [None]:
# CAUTION: training this model can take from 45 min to an hour.

# Train the model
# train_model(model, train_loader, criterion, optimizer, num_epochs=5)

In [None]:
# Save the model
# torch.save(model.state_dict(), "model_epochs5.pth")
# print("Saved PyTorch Model State to model_epochs5.pth")

In [None]:
# Load the saved model
model = CustomCNN().to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load("model_epochs5.pth", map_location=device, weights_only=True)
)

In [None]:
# Evaluate the model: prints the test accuracy and returns the per-sample
# true labels and thresholded predictions.
true_labels, pred_labels = evaluate_model(model, test_loader)

# Plot confusion matrix and classification report
plot_confusion_matrix(true_labels, pred_labels, class_names)

# Visualize sample predictions (first images yielded by the test loader)
visualize_predictions(model, test_loader, class_names)
