{ "cells": [ { "cell_type": "code", "execution_count": 34, "id": "5d31b5dc", "metadata": {}, "outputs": [], "source": [ "NUM_WORKERS = 1\n", "BATCH_SIZE = 5\n", "EPOCHS = 10\n", "LEARNING_RATE = 0.002\n", "EMBED_DIM = 300\n", "WINDOW = 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "e674452f", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import time" ] }, { "cell_type": "code", "execution_count": 3, "id": "ed399515", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "# Device configuration\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(device)" ] }, { "cell_type": "markdown", "id": "aadad342", "metadata": {}, "source": [ "# Environment\n", "We are running this on the SCC using the `python3/3.10.12` and `pytorch/1.13.1` modules. " ] }, { "cell_type": "markdown", "id": "1ba2b3cb", "metadata": {}, "source": [ "# Corpus for training\n", "\"The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia.\"\n", "https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/\n", "\n", "This is available to use from the Torchtext package.\n", "\n", "The `pytorch/1.13.1` module comes with a version of `torchtext`, but apparently this requires `torchdata` to run. According to:
\n", "https://github.com/pytorch/data#installation
\n", "for PyTorch 1.13.1 we should use Torchdata 0.5.1\n", "\n", "`pip install --no-cache-dir --prefix=/projectnb/jbrcs/word2vec/packages torchdata==0.5.1`" ] }, { "cell_type": "code", "execution_count": 25, "id": "993b347f", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from torchtext.datasets import WikiText103\n", "\n", "train, valid, test = WikiText103(root='./data', split= ('train', 'valid', 'test'))" ] }, { "cell_type": "code", "execution_count": 26, "id": "f8bd8e9f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ShardingFilterIterDataPipe" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid" ] }, { "cell_type": "markdown", "id": "f84cd619", "metadata": {}, "source": [ "This is a \"DataPipe\" which allows us to stream in data efficiently to our program. We can make an iterator using the DataPipe if we want:" ] }, { "cell_type": "code", "execution_count": 18, "id": "da63b4db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1788, 868, 2, 167, 14, 1, 1110, 1803, 60, 758, 1803, 2, 24, 8, 815, 4, 7428, 1803, 23, 1, 378, 1449, 1593, 2, 3017, 616, 5, 1180, 4, 1, 282, 616, 3, 26, 24, 4142, 1830, 7, 1, 122, 1803, 2, 351, 3, 2858, 3, 26, 85, 4973, 7, 8, 1121, 4, 933, 2404, 20, 595, 6, 21, 5, 8, 3015, 4, 277, 8259, 20, 338, 6219, 21, 2, 5, 1885, 8, 7501, 1179, 4, 4783, 3, 6, 509, 2, 1, 3008, 49, 1155, 2, 88, 1215, 1803, 924, 12, 5803, 3, 8414, 3850, 6, 1, 1529, 2, 2508, 4891, 31, 49, 778, 16, 1, 6014, 18, 87, 7, 8, 120, 98, 12530, 51, 14281, 6212, 3, 1788, 868, 24, 8, 1923, 5973, 1780, 2, 5, 24, 4020, 1653, 517, 1803, 5215, 2, 744, 114, 1, 99, 3354, 3]\n", "[1788, 868, 24, 8, 196, 11262, 2, 19, 8, 1049, 1121, 87, 7, 933, 10800, 20, 595, 6, 21, 5, 9485, 87, 7, 156, 50, 277, 8259, 20, 337, 50, 338, 6219, 21, 2, 132, 1, 3008, 1653, 6, 1803, 5215, 49, 2146, 665, 50, 1439, 2404, 20, 216, 50, 331, 6, 21, 217, 5, 16318, 238, 17, 3, 17, 
227, 50, 54, 17, 3, 17, 54, 2268, 20, 58, 17, 3, 17, 156, 50, 138, 17, 3, 17, 216, 6219, 21, 3, 258, 65, 7558, 2, 3008, 48, 8, 2970, 7845, 31, 40, 2285, 15178, 6, 286, 7, 4973, 2, 6, 8, 923, 199, 11700, 20, 13700, 21, 3, 32, 85, 3038, 107, 427, 8, 120, 18, 822, 3008, 2, 34, 11351, 7, 637, 652, 58, 50, 54, 121, 18, 1800, 2183, 3]\n", "[1, 44, 1179, 4, 14191, 24, 1753, 19, 8, 196, 2, 10274, 1179, 4, 4783, 3, 1, 1800, 36, 24, 1, 5824, 2, 5, 47, 6579, 13859, 143, 18, 11261, 8750, 1, 65, 24, 1, 7572, 2, 31, 47, 3457, 2698, 4884, 2, 5, 24, 143, 18, 1578, 60, 15754, 1, 8750, 3, 2146, 2, 1, 294, 5758, 24, 1, 5824, 2, 5, 1, 694, 24, 1, 7572, 3]\n", "[1, 7845, 24, 1785, 1155, 894, 2, 19, 9157, 15, 10950, 2, 5, 2357, 706, 3, 1, 924, 2210, 939, 19, 3008, 88, 1284, 38, 5803, 3, 32, 3850, 229, 2, 6, 509, 2, 1, 924, 8692, 10266, 24, 2195, 7, 8, 14462, 520, 2, 34, 1, 520, 24, 1451, 87, 16, 1, 3336, 4, 5803, 2, 6533, 1, 924, 8692, 3]\n" ] } ], "source": [ "myiter = iter(valid)\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))\n", "print(repr(next(myiter)))" ] }, { "cell_type": "markdown", "id": "9eee1f77", "metadata": {}, "source": [ "We want to build a \"vocabulary\" of all the words in our corpus. To do this we need to \"tokenize\" the corpus into individual words. We could write our own simple tokenizer by splitting on spaces, etc; instead we will use one included with Torchtext:
\n", "https://pytorch.org/text/stable/data_utils.html#get-tokenizer\n", "\n", "Then we build our vocabulary, assigning an index to each word:
\n", "https://pytorch.org/text/stable/vocab.html#build-vocab-from-iterator" ] },
{ "cell_type": "code", "execution_count": 27, "id": "f3749e3b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ShardingFilterIterDataPipe" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid" ] },
{ "cell_type": "code", "execution_count": 28, "id": "c71bc32b", "metadata": {}, "outputs": [], "source": [ "from torchtext.vocab import build_vocab_from_iterator\n", "from torchtext.data.utils import get_tokenizer\n", "\n", "# The 'basic_english' tokenizer lowercases text and splits it into word and punctuation tokens;\n", "# the vocabulary built below is what maps each token to an integer ID\n", "tokenizer = get_tokenizer('basic_english')\n", "\n", "# WikiText section headings look like ' = Title = ', so skip lines that start with ' ='\n", "remove_titles = lambda x: x[:2]!=' ='\n", "# make_context needs at least 2*WINDOW+1 tokens per line to build its padding\n", "remove_short_sequences = lambda x: len(x)>2*WINDOW\n", "\n", "valid = valid.filter(remove_titles)\n", "valid = valid.map(tokenizer)\n", "valid = valid.filter(remove_short_sequences)\n", "\n", "vocab_valid = build_vocab_from_iterator(valid, min_freq=1, specials=[\"<unk>\"])\n", "# Map out-of-vocabulary words to the <unk> ID instead of raising an error\n", "vocab_valid.set_default_index(vocab_valid[\"<unk>\"])\n", "VOCAB_SIZE = len(vocab_valid)\n", "valid = valid.map(vocab_valid.lookup_indices)" ] },
{ "cell_type": "code", "execution_count": 29, "id": "8ed465dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of words in vocab: 16502\n", "4819\n" ] }, { "data": { "text/plain": [ "[13028, 24, 8, 4819, 3331]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(f\"Number of words in vocab: {VOCAB_SIZE}\")\n", "\n", "# Example of getting a word's ID from the vocab\n", "print(vocab_valid['cool'])\n", "\n", "# We can tokenize and look up a whole sentence like so:\n", "vocab_valid.lookup_indices(tokenizer('Josh is a cool guy'))" ] },
{ "cell_type": "code", "execution_count": 35, "id": "37ae23b6", "metadata": {}, "outputs": [], "source": [ "# Using a DataLoader allows us to efficiently work with large amounts of data\n", "loader_valid = DataLoader(\n", "    dataset=valid,\n", "    batch_size=BATCH_SIZE,\n", "    shuffle=True,\n", "    drop_last=True,\n", "    num_workers=NUM_WORKERS,\n", "    # Keep each batch as a plain list of documents, since they differ in length\n", "    collate_fn=lambda x: x)" ] },
{ "cell_type": "code", "execution_count": 36, "id": "36ec416e", "metadata": {}, "outputs": [], "source": [ "def make_context(tokens, window):\n", "    # This gives the context words within the window around the target word. For example:\n", "    #     the quick BROWN fox jumps\n", "    # If the target word is \"brown\", then for a window size of 2 the context words are:\n", "    #     \"the, quick, fox, jumps\"\n", "    #\n", "    # For the first/last `window` words, the left/right context window runs past the start/end\n", "    # of the document. We still want 2*window context words, so we take the next closest words\n", "    # instead by padding the start and end of the document with `window` words each. For example:\n", "    #     fox jumps | the quick brown fox jumps\n", "    # where \"|\" marks the end of the padded portion. With a window size of 2 and the target\n", "    # \"quick\", there is only one word to its left in the original sequence, but with the padding:\n", "    #     fox jumps | the QUICK brown fox jumps\n", "    # the left window extends into the padded region. The effect is as if the left window size\n", "    # were one and the right window size three, but it is easier to code.\n", "    # Assumes len(tokens) >= 2*window+1, which the remove_short_sequences filter guarantees.\n", "\n", "    padded = (tokens[window+1 : 2*window+1]\n", "              + tokens\n", "              + tokens[-(2*window+1) : -(window+1)])\n", "\n", "    data = []\n", "    for i in range(window, len(padded) - window):\n", "        context = padded[i-window : i] + padded[i+1 : i+1+window]\n", "        data.append(context)\n", "\n", "    # Targets: shape (len(tokens),), as NLLLoss expects; contexts: shape (len(tokens), 2*window)\n", "    return torch.tensor(tokens), torch.tensor(data)" ] },
{ "cell_type": "code", "execution_count": 37, "id": "063757bf", "metadata": {}, "outputs": [], "source": [ "class CBOW(torch.nn.Module):\n", "    def __init__(self, vocab_size, embedding_dim):\n", "        super(CBOW, self).__init__()\n", "\n", "        self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", "        self.linear = nn.Linear(embedding_dim, vocab_size)\n", "        self.activation_function = nn.LogSoftmax(dim = -1)\n", "\n", "        self.embeddings.weight.data.uniform_(-1,1)\n", "\n", "    def forward(self, inputs):\n", "        # inputs: (batch, 2*window) context IDs; summing their embeddings gives (batch, embedding_dim)\n", "        embeds = torch.sum(self.embeddings(inputs), dim=1)\n", "        out = self.linear(embeds)\n", "        # Log-probabilities over the vocabulary: (batch, vocab_size)\n", "        out = self.activation_function(out)\n", "        return out\n", "\n", "    def get_word_embedding(self, word, vocab):\n", "        # Look up the word's ID in the given vocab (e.g. vocab_valid); returns a (1, embedding_dim) tensor\n", "        word = torch.tensor([vocab[word]], device=self.embeddings.weight.device)\n", "        return self.embeddings(word).view(1,-1)" ] },
{ "cell_type": "code", "execution_count": null, "id": "03dcc99f", "metadata": {}, "outputs": [], "source": [ "# Sanity check: print the documents in the first batch, then stop\n", "for idx, batch in enumerate(loader_valid):\n", "    for sample in batch:\n", "        print('OOOOOOO')\n", "        print(sample)\n", "    if idx==0:\n", "        break" ] },
{ "cell_type": "code", "execution_count": null, "id": "dc9b8022", "metadata": {}, "outputs": [], "source": [ "model = CBOW(VOCAB_SIZE, EMBED_DIM).to(device)\n", "loss_function = nn.NLLLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n", "\n", "#TRAINING\n", "mytime = time.time()\n", "for epoch in range(EPOCHS):\n", "    total_loss = 0\n", "    for idx, batch in enumerate(loader_valid):\n", "        # Documents in a batch have different lengths, so build (target, context) pairs\n", "        # per document and concatenate them into a single set of training examples\n", "        pairs = [make_context(doc, WINDOW) for doc in batch]\n", "        targets = torch.cat([t for t, _ in pairs]).to(device)\n", "        context = torch.cat([c for _, c in pairs]).to(device)\n", "        log_probs = model(context)\n", "        loss = loss_function(log_probs, targets)\n", "\n", "        # Optimize after every batch: summing the loss over a whole epoch would keep every\n", "        # batch's computation graph in memory and take only one gradient step per epoch\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        total_loss += loss.item()\n", "        if idx%100==0:\n", "            print(idx, total_loss)\n", "    print(epoch, total_loss, time.time()-mytime)" ] },
{ "cell_type": "code", "execution_count": null, "id": "108a3c69", "metadata": {}, "outputs": [], "source": [ "#TESTING\n", "# Phrases to try as context:\n", "# widely caught using lobster pots\n", "#Atlantic Ocean , Mediterranean Sea and\n", "#widely caught lobster pots\n", "context = 'Atlantic Ocean Sea and'\n", "context_vector = vocab_valid.lookup_indices(tokenizer(context))\n", "# The model expects a batch dimension, so the input has shape (1, number of context words)\n", "with torch.no_grad():\n", "    a = model(torch.tensor([context_vector], dtype=torch.long).to(device))\n", "# The 10 most likely target words for this context\n", "val,ind = torch.topk(a,10)\n", "vocab_valid.lookup_tokens(ind.tolist()[0])" ] },
{ "cell_type": "code", "execution_count": null, "id": "fe3c3d54", "metadata": {}, "outputs": [], "source": [ "def mysim(a,b):\n", "    # Cosine similarity between the learned embeddings of words a and b\n", "    aa = model.embeddings(torch.tensor(vocab_valid[a], dtype=torch.long).to(device))\n", "    bb = model.embeddings(torch.tensor(vocab_valid[b], dtype=torch.long).to(device))\n", "    sim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)\n", "    return sim(aa,bb).item()\n", "\n", "mysim('lobster','josh')" ] },
{ "cell_type": "markdown", "id": "3d05fe9b", "metadata": {}, "source": [ "# TODO\n", "* Create a `collections.Counter` to get word frequencies\n", "* Use the counter to create a vocab\n", "* Use the word frequencies for subsampling to improve performance (using the subsampling distribution from the word2vec paper), so that common words are sampled less frequently" ] },
{ "cell_type": "code", "execution_count": null, "id": "fab6ab40", "metadata": {}, "outputs": [], "source": [] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }