{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "5d31b5dc", "metadata": {}, "outputs": [], "source": [ "NUM_WORKERS = 1\n", "BATCH_SIZE = 1\n", "EPOCHS = 10\n", "LEARNING_RATE = 0.002\n", "EMBED_DIM = 300\n", "WINDOW = 2" ] }, { "cell_type": "code", "execution_count": null, "id": "e674452f", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import time" ] }, { "cell_type": "code", "execution_count": null, "id": "ed399515", "metadata": {}, "outputs": [], "source": [ "# Device configuration\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(device)" ] }, { "cell_type": "markdown", "id": "aadad342", "metadata": {}, "source": [ "# Environment\n", "We are running this on the SCC using the `python3/3.10.12` and `pytorch/1.13.1` modules. " ] }, { "cell_type": "markdown", "id": "1ba2b3cb", "metadata": {}, "source": [ "# Corpus for training\n", "\"The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia.\"\n", "https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/\n", "\n", "This is available to use from the Torchtext package.\n", "\n", "The `pytorch/1.13.1` module comes with a version of `torchtext`, but apparently this requires `torchdata` to run. According to:
\n", "https://github.com/pytorch/data#installation
\n", "for PyTorch 1.13.1 we should use Torchdata 0.5.1\n", "\n", "`pip install --no-cache-dir --prefix=/projectnb/jbrcs/word2vec/packages torchdata==0.5.1`" ] }, { "cell_type": "code", "execution_count": null, "id": "993b347f", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from torchtext.datasets import WikiText103\n", "\n", "train, valid, test = WikiText103(root='./data', split= ('train', 'valid', 'test'))" ] }, { "cell_type": "code", "execution_count": null, "id": "f8bd8e9f", "metadata": {}, "outputs": [], "source": [ "valid" ] }, { "cell_type": "markdown", "id": "f84cd619", "metadata": {}, "source": [ "This is a \"DataPipe\" which allows us to stream in data efficiently to our program. We can make an iterator using the DataPipe if we want:" ] }, { "cell_type": "code", "execution_count": null, "id": "da63b4db", "metadata": {}, "outputs": [], "source": [ "myiter = iter(valid)\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))" ] }, { "cell_type": "markdown", "id": "9eee1f77", "metadata": {}, "source": [ "We want to build a \"vocabulary\" of all the words in our corpus. To do this we need to \"tokenize\" the corpus into individual words. We could write our own simple tokenizer by splitting on spaces, etc; instead we will use one included with Torchtext:
\n", "https://pytorch.org/text/stable/data_utils.html#get-tokenizer\n", "\n", "Then we build our vocabulary, assigning an index to each word:
\n", "https://pytorch.org/text/stable/vocab.html#build-vocab-from-iterator" ] }, { "cell_type": "code", "execution_count": null, "id": "c71bc32b", "metadata": {}, "outputs": [], "source": [ "from torchtext.vocab import build_vocab_from_iterator\n", "from torchtext.data.utils import get_tokenizer\n", "\n", "# This tokenizer converts words to token ID numbers\n", "tokenizer = get_tokenizer('basic_english')\n", "\n", "remove_titles = lambda x: x[:2]!=' ='\n", "remove_short_sequences = lambda x: len(x)>2*WINDOW\n", "\n", "valid = valid.filter(remove_titles)\n", "valid = valid.map(tokenizer)\n", "valid = valid.filter(remove_short_sequences)\n", "\n", "vocab_valid = build_vocab_from_iterator(valid, min_freq=1, specials=[\"\"])\n", "VOCAB_SIZE = len(vocab_valid)\n", "valid = valid.map(vocab_valid.lookup_indices)" ] }, { "cell_type": "code", "execution_count": null, "id": "8ed465dd", "metadata": {}, "outputs": [], "source": [ "print(f\"Number of words in vocab: {VOCAB_SIZE}\")\n", "\n", "# Example of getting ID using vocab\n", "print(\n", " vocab_valid['cool'])\n", "\n", "# We can tokenize a whole sentence like so:\n", "vocab_valid.lookup_indices(tokenizer('Josh is a cool guy'))" ] }, { "cell_type": "code", "execution_count": null, "id": "37ae23b6", "metadata": {}, "outputs": [], "source": [ "# Using a dataloader allows us to efficiently work with large amounts of data\n", "loader_valid = DataLoader(\n", " dataset=valid,\n", " batch_size=BATCH_SIZE,\n", " shuffle=True,\n", " drop_last=True,\n", " num_workers=NUM_WORKERS,\n", " collate_fn=lambda x: x)" ] }, { "cell_type": "code", "execution_count": null, "id": "36ec416e", "metadata": {}, "outputs": [], "source": [ "def make_context(tokens, window):\n", " # This gives the context words within the window around the target word. For example:\n", " # the quick BROWN fox jumps\n", " # If the target word is \"brown\". For a window size of 2 the context words are:\n", " # \"the, quick, fox, jumps\"\n", " \n", " # The first/last \"window\" amount of words have their left/right context window instersect with the start/end\n", " # We still want 2*window words for our context, so we take the next closest words by\n", " # padding the start and end of a document with a \"window\" amount of words. For example:\n", " # fox jumps | the quick brown fox jumps\n", " # Where the \"|\" indicates the end of the padded portion. So if our window size is still 2 and the target\n", " # is \"quick\" then we only have one word in the left window in th original sequence, but with the padded one:\n", " # fox jumps | the QUICK brown fox jumps\n", " # Our left window now extends into the padded region. The effect is as if our left window size is one and\n", " # the right window size is three, but it makes it easier to code.\n", " \n", " padded = (tokens[window+1 : 2*window+1]\n", " + tokens\n", " + tokens[-(2*window+1) : -(window+1)])\n", "\n", " data = []\n", " for i in range(window, len(padded) - window):\n", " context = padded[i-window : i] + padded[i+1 : i+1+window]\n", " data.append(context)\n", "\n", " return torch.tensor([tokens]), torch.tensor(data)" ] }, { "cell_type": "code", "execution_count": null, "id": "063757bf", "metadata": {}, "outputs": [], "source": [ "class CBOW(torch.nn.Module):\n", " def __init__(self, vocab_size, embedding_dim):\n", " super(CBOW, self).__init__()\n", " \n", " self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", " self.linear = nn.Linear(embedding_dim, vocab_size)\n", " self.activation_function = nn.LogSoftmax(dim = -1)\n", " \n", " self.embeddings.weight.data.uniform_(-1,1)\n", "\n", " def forward(self, inputs):\n", " embeds = torch.sum(self.embeddings(inputs), dim=1)\n", " out = self.linear(embeds)\n", " out = self.activation_function(out)\n", " return out\n", "\n", " def get_word_emdedding(self, word):\n", " word = torch.tensor([word_to_ix[word]])\n", " return self.embeddings(word).view(1,-1)" ] }, { "cell_type": "code", "execution_count": null, "id": "03dcc99f", "metadata": {}, "outputs": [], "source": [ "for idx, batch in enumerate(loader_valid):\n", " for sample in batch:\n", " print('OOOOOOO')\n", " print(sample)\n", " if doc_id==0:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "id": "dc9b8022", "metadata": {}, "outputs": [], "source": [ "model = CBOW(VOCAB_SIZE, EMBED_DIM).to(device)\n", "loss_function = nn.NLLLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n", "\n", "latch = True\n", "#TRAINING\n", "mytime = time.time()\n", "for epoch in range(EPOCHS):\n", " total_loss = 0\n", " for idx, batch in enumerate(loader_valid):\n", " targets, context = make_context(batch, WINDOW)\n", " targets = torch.tensor(targets, dtype=torch.long).to(device)\n", " context = torch.tensor(context, dtype=torch.long).to(device)\n", " log_probs = model(context)\n", " total_loss += loss_function(log_probs, targets)\n", " if idx%100==0:\n", " print(idx, total_loss.item())\n", " print(epoch, total_loss.item(), time.time()-mytime)\n", " \n", " #optimize at the end of each epoch\n", " optimizer.zero_grad()\n", " total_loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": null, "id": "108a3c69", "metadata": {}, "outputs": [], "source": [ "#TESTING\n", "# widely caught using lobster pots\n", "#Atlantic Ocean , Mediterranean Sea and\n", "#widely caught lobster pots\n", "context = 'Atlantic Ocean Sea and'\n", "context_vector = v1.lookup_indices(tokenizer(context))\n", "a = model(torch.tensor(context_vector, dtype=torch.long).to(device))\n", "val,ind = torch.topk(a,10)\n", "v1.lookup_tokens(ind.tolist()[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "fe3c3d54", "metadata": {}, "outputs": [], "source": [ "def mysim(a,b):\n", " aa = model.embeddings(torch.tensor(v1[a], dtype=torch.long).to(device))\n", " bb = model.embeddings(torch.tensor(v1[b], dtype=torch.long).to(device))\n", " sim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)\n", " return sim(aa,bb).item()\n", "\n", "mysim('lobster','josh')\n", " " ] }, { "cell_type": "markdown", "id": "3d05fe9b", "metadata": {}, "source": [ "# TODO\n", "* Create a counter from collections to get word frequency\n", "* Use the counter to create a vocab\n", "* Use the word frequency for subsampling to improve performance (use subsampling distribution from word2vec paper) by sampling common words less frequently" ] }, { "cell_type": "code", "execution_count": null, "id": "fab6ab40", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }