{
"cells": [
{
"cell_type": "code",
"execution_count": 34,
"id": "5d31b5dc",
"metadata": {},
"outputs": [],
"source": [
"NUM_WORKERS = 1\n",
"BATCH_SIZE = 5\n",
"EPOCHS = 10\n",
"LEARNING_RATE = 0.002\n",
"EMBED_DIM = 300\n",
"WINDOW = 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e674452f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ed399515",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"# Device configuration\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "aadad342",
"metadata": {},
"source": [
"# Environment\n",
"We are running this on the SCC using the `python3/3.10.12` and `pytorch/1.13.1` modules. "
]
},
{
"cell_type": "markdown",
"id": "1ba2b3cb",
"metadata": {},
"source": [
"# Corpus for training\n",
"\"The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia.\"\n",
"https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/\n",
"\n",
"It is available through the Torchtext package as `torchtext.datasets.WikiText103`.\n",
"\n",
"The `pytorch/1.13.1` module comes with a version of `torchtext`, but its dataset loaders apparently also require `torchdata`. According to the installation notes at\n",
"\n",
"https://github.com/pytorch/data#installation\n",
"\n",
"PyTorch 1.13.1 should be paired with `torchdata` 0.5.1, which we install into a project directory:\n",
"\n",
"`pip install --no-cache-dir --prefix=/projectnb/jbrcs/word2vec/packages torchdata==0.5.1`"
]
},
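{
"cell_type": "markdown",
"id": "c3a91f01",
"metadata": {},
"source": [
"Because `--prefix` installs packages outside Python's default search path, the kernel has to be told where to look before `torchtext.datasets` can import `torchdata`. The next cell is a minimal sketch of one way to do that. It assumes pip's usual Linux layout under the prefix (`lib/pythonX.Y/site-packages`) and reuses the prefix from the command above. Exporting `PYTHONPATH` before starting Jupyter should work just as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f02",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from importlib import metadata\n",
"\n",
"# Prefix used in the pip command above\n",
"PREFIX = '/projectnb/jbrcs/word2vec/packages'\n",
"\n",
"# Assumption: pip --prefix on Linux installs into <prefix>/lib/pythonX.Y/site-packages\n",
"SITE_PACKAGES = os.path.join(\n",
"    PREFIX, 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages')\n",
"\n",
"# Put the prefix ahead of the module's own packages so its torchdata is found first\n",
"if SITE_PACKAGES not in sys.path:\n",
"    sys.path.insert(0, SITE_PACKAGES)\n",
"\n",
"# Sanity check: report which torchdata version Python can now see\n",
"try:\n",
"    print('torchdata', metadata.version('torchdata'))\n",
"except metadata.PackageNotFoundError:\n",
"    print('torchdata not found; check PREFIX')"
]
},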
{
"cell_type": "code",
"execution_count": 25,
"id": "993b347f",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torchtext.datasets import WikiText103\n",
"\n",
"train, valid, test = WikiText103(root='./data', split=('train', 'valid', 'test'))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "f8bd8e9f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ShardingFilterIterDataPipe"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid"
]
},
{
"cell_type": "markdown",
"id": "f84cd619",
"metadata": {},
"source": [
"This is a \"DataPipe\", which streams data into our program lazily instead of loading the whole dataset into memory at once. We can make an iterator from the DataPipe to look at the first few items:"
]
},
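{
"cell_type": "markdown",
"id": "c3a91f03",
"metadata": {},
"source": [
"A DataPipe is lazy: `filter` and `map` only describe the processing, and nothing is read until the pipe is iterated. The next cell mimics the `filter`/`map` chain used later in this notebook with plain-Python generators, so the behaviour can be seen without downloading WikiText. The toy lines and the `str.split` tokenizer are stand-ins, not the real dataset or Torchtext's tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f04",
"metadata": {},
"outputs": [],
"source": [
"# Toy lines in the WikiText style; section titles look like ' = Title = '\n",
"toy_lines = [' = Heading = ', ' The quick brown fox jumps over the lazy dog . ', ' Too short . ']\n",
"TOY_WINDOW = 2\n",
"\n",
"# Each step wraps the previous one in a generator, so nothing is processed yet\n",
"# Like remove_titles:\n",
"toy_pipe = (line for line in toy_lines if line[:2] != ' =')\n",
"# Stand-in tokenizer (NOT the same as get_tokenizer('basic_english')):\n",
"toy_pipe = (line.lower().split() for line in toy_pipe)\n",
"# Like remove_short_sequences:\n",
"toy_pipe = (toks for toks in toy_pipe if len(toks) > 2 * TOY_WINDOW)\n",
"\n",
"# The work happens only when items are pulled out of the pipe\n",
"toy_docs = list(toy_pipe)\n",
"print(toy_docs)"
]
},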
{
"cell_type": "code",
"execution_count": 18,
"id": "da63b4db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1788, 868, 2, 167, 14, 1, 1110, 1803, 60, 758, 1803, 2, 24, 8, 815, 4, 7428, 1803, 23, 1, 378, 1449, 1593, 2, 3017, 616, 5, 1180, 4, 1, 282, 616, 3, 26, 24, 4142, 1830, 7, 1, 122, 1803, 2, 351, 3, 2858, 3, 26, 85, 4973, 7, 8, 1121, 4, 933, 2404, 20, 595, 6, 21, 5, 8, 3015, 4, 277, 8259, 20, 338, 6219, 21, 2, 5, 1885, 8, 7501, 1179, 4, 4783, 3, 6, 509, 2, 1, 3008, 49, 1155, 2, 88, 1215, 1803, 924, 12, 5803, 3, 8414, 3850, 6, 1, 1529, 2, 2508, 4891, 31, 49, 778, 16, 1, 6014, 18, 87, 7, 8, 120, 98, 12530, 51, 14281, 6212, 3, 1788, 868, 24, 8, 1923, 5973, 1780, 2, 5, 24, 4020, 1653, 517, 1803, 5215, 2, 744, 114, 1, 99, 3354, 3]\n",
"[1788, 868, 24, 8, 196, 11262, 2, 19, 8, 1049, 1121, 87, 7, 933, 10800, 20, 595, 6, 21, 5, 9485, 87, 7, 156, 50, 277, 8259, 20, 337, 50, 338, 6219, 21, 2, 132, 1, 3008, 1653, 6, 1803, 5215, 49, 2146, 665, 50, 1439, 2404, 20, 216, 50, 331, 6, 21, 217, 5, 16318, 238, 17, 3, 17, 227, 50, 54, 17, 3, 17, 54, 2268, 20, 58, 17, 3, 17, 156, 50, 138, 17, 3, 17, 216, 6219, 21, 3, 258, 65, 7558, 2, 3008, 48, 8, 2970, 7845, 31, 40, 2285, 15178, 6, 286, 7, 4973, 2, 6, 8, 923, 199, 11700, 20, 13700, 21, 3, 32, 85, 3038, 107, 427, 8, 120, 18, 822, 3008, 2, 34, 11351, 7, 637, 652, 58, 50, 54, 121, 18, 1800, 2183, 3]\n",
"[1, 44, 1179, 4, 14191, 24, 1753, 19, 8, 196, 2, 10274, 1179, 4, 4783, 3, 1, 1800, 36, 24, 1, 5824, 2, 5, 47, 6579, 13859, 143, 18, 11261, 8750, 1, 65, 24, 1, 7572, 2, 31, 47, 3457, 2698, 4884, 2, 5, 24, 143, 18, 1578, 60, 15754, 1, 8750, 3, 2146, 2, 1, 294, 5758, 24, 1, 5824, 2, 5, 1, 694, 24, 1, 7572, 3]\n",
"[1, 7845, 24, 1785, 1155, 894, 2, 19, 9157, 15, 10950, 2, 5, 2357, 706, 3, 1, 924, 2210, 939, 19, 3008, 88, 1284, 38, 5803, 3, 32, 3850, 229, 2, 6, 509, 2, 1, 924, 8692, 10266, 24, 2195, 7, 8, 14462, 520, 2, 34, 1, 520, 24, 1451, 87, 16, 1, 3336, 4, 5803, 2, 6533, 1, 924, 8692, 3]\n"
]
}
],
"source": [
"myiter = iter(valid)\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))"
]
},
{
"cell_type": "markdown",
"id": "9eee1f77",
"metadata": {},
"source": [
"We want to build a \"vocabulary\" of all the words in our corpus. To do this we first need to \"tokenize\" the corpus into individual words. We could write our own simple tokenizer that splits on spaces and punctuation; instead we will use one included with Torchtext:\n",
"\n",
"https://pytorch.org/text/stable/data_utils.html#get-tokenizer\n",
"\n",
"Then we build our vocabulary, assigning an index to each word:\n",
"\n",
"https://pytorch.org/text/stable/vocab.html#build-vocab-from-iterator"
]
},
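{
"cell_type": "markdown",
"id": "c3a91f05",
"metadata": {},
"source": [
"Roughly, `build_vocab_from_iterator` counts every token, drops those below `min_freq`, and assigns indices with the special tokens first and the remaining tokens ordered from most to least frequent. The next cell imitates that with `collections.Counter` on a toy corpus. `build_toy_vocab` is a hypothetical helper, and the tie-breaking rule (alphabetical among equally frequent tokens) is an assumption about Torchtext's internals rather than a documented guarantee."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f06",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def build_toy_vocab(docs, min_freq=1, specials=('<unk>',)):\n",
"    # Hypothetical helper: count how often each token appears across all documents\n",
"    counts = Counter(tok for doc in docs for tok in doc)\n",
"    # Most frequent first; ties broken alphabetically (assumed to match Torchtext)\n",
"    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))\n",
"    itos = list(specials) + [tok for tok, n in ordered if n >= min_freq and tok not in specials]\n",
"    stoi = {tok: i for i, tok in enumerate(itos)}\n",
"    return stoi, itos\n",
"\n",
"toy_corpus = [['the', 'quick', 'brown', 'fox'], ['the', 'lazy', 'dog', 'the', 'fox']]\n",
"toy_stoi, toy_itos = build_toy_vocab(toy_corpus)\n",
"print(toy_itos)\n",
"\n",
"# Look up IDs, sending out-of-vocabulary words such as 'cool' to <unk>\n",
"print([toy_stoi.get(tok, toy_stoi['<unk>']) for tok in ['the', 'cool', 'fox']])"
]
},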
{
"cell_type": "code",
"execution_count": 27,
"id": "f3749e3b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ShardingFilterIterDataPipe"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "c71bc32b",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torchtext.data.utils import get_tokenizer\n",
"\n",
"# This tokenizer lowercases a line and splits it into word and punctuation tokens (strings)\n",
"tokenizer = get_tokenizer('basic_english')\n",
"\n",
"remove_titles = lambda x: x[:2]!=' ='\n",
"remove_short_sequences = lambda x: len(x)>2*WINDOW\n",
"\n",
"valid = valid.filter(remove_titles)\n",
"valid = valid.map(tokenizer)\n",
"valid = valid.filter(remove_short_sequences)\n",
"\n",
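"vocab_valid = build_vocab_from_iterator(valid, min_freq=1, specials=[\"<unk>\"])\n",
"# Send words that are not in the vocabulary to the <unk> index instead of raising an error\n",
"vocab_valid.set_default_index(vocab_valid[\"<unk>\"])\n",
"VOCAB_SIZE = len(vocab_valid)\n",
"valid = valid.map(vocab_valid.lookup_indices)"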
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "8ed465dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of words in vocab: 16502\n",
"4819\n"
]
},
{
"data": {
"text/plain": [
"[13028, 24, 8, 4819, 3331]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(f\"Number of words in vocab: {VOCAB_SIZE}\")\n",
"\n",
"# Example of getting ID using vocab\n",
"print(vocab_valid['cool'])\n",
"\n",
"# We can tokenize a whole sentence like so:\n",
"vocab_valid.lookup_indices(tokenizer('Josh is a cool guy'))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "37ae23b6",
"metadata": {},
"outputs": [],
"source": [
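"# A DataLoader groups the documents into batches of BATCH_SIZE. collate_fn=lambda x: x\n",
"# keeps each batch as a plain list of variable-length documents instead of stacking them\n",
"# into one tensor (which would fail because the documents have different lengths)\n",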
"loader_valid = DataLoader(\n",
" dataset=valid,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" drop_last=True,\n",
" num_workers=NUM_WORKERS,\n",
" collate_fn=lambda x: x)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "36ec416e",
"metadata": {},
"outputs": [],
"source": [
"def make_context(tokens, window):\n",
"    # Build (target, context) pairs for ONE document, given as a list of token IDs.\n",
"    # The context words are those within the window around the target word. For example, in\n",
"    #   the quick BROWN fox jumps\n",
"    # if the target word is 'brown' and the window size is 2, the context words are\n",
"    #   the, quick, fox, jumps\n",
"    #\n",
"    # For the first/last `window` words, the left/right window runs past the start/end of the document.\n",
"    # We still want 2*window context words, so we take the next closest words instead by\n",
"    # padding the start and end of the document with `window` words each. For example:\n",
"    #   fox jumps | the quick brown fox jumps\n",
"    # where '|' marks the end of the padded portion. With a window size of 2 and the target\n",
"    # 'quick', the original sequence has only one word to the left, but in the padded one:\n",
"    #   fox jumps | the QUICK brown fox jumps\n",
"    # the left window extends into the padded region. The effect is as if the left window\n",
"    # had size one and the right window size three, but it is simpler to code.\n",
"    # This requires len(tokens) >= 2*window + 1, which remove_short_sequences guarantees.\n",
"\n",
"    padded = (tokens[window+1 : 2*window+1]\n",
"              + tokens\n",
"              + tokens[-(2*window+1) : -(window+1)])\n",
"\n",
"    data = []\n",
"    for i in range(window, len(padded) - window):\n",
"        context = padded[i-window : i] + padded[i+1 : i+1+window]\n",
"        data.append(context)\n",
"\n",
"    # targets: shape (len(tokens),); context: shape (len(tokens), 2*window)\n",
"    return torch.tensor(tokens, dtype=torch.long), torch.tensor(data, dtype=torch.long)"
]
},
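{
"cell_type": "markdown",
"id": "c3a91f07",
"metadata": {},
"source": [
"To see what the padding does, the next cell runs the same indexing logic on the `the quick brown fox jumps` example from the comments above, with strings instead of token IDs and plain lists instead of tensors. `context_windows` is a hypothetical torch-free copy of `make_context`, written only for illustration. Each target should come out with four context words taken from the rest of the sentence, never the target itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f08",
"metadata": {},
"outputs": [],
"source": [
"def context_windows(tokens, window):\n",
"    # Hypothetical torch-free copy of make_context: same padding trick, plain lists out\n",
"    padded = (tokens[window+1 : 2*window+1]\n",
"              + tokens\n",
"              + tokens[-(2*window+1) : -(window+1)])\n",
"    return [padded[i-window : i] + padded[i+1 : i+1+window]\n",
"            for i in range(window, len(padded) - window)]\n",
"\n",
"toy_sentence = 'the quick brown fox jumps'.split()\n",
"for target, ctx in zip(toy_sentence, context_windows(toy_sentence, 2)):\n",
"    print(target, '->', ctx)"
]
},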
{
"cell_type": "code",
"execution_count": 37,
"id": "063757bf",
"metadata": {},
"outputs": [],
"source": [
"class CBOW(torch.nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim):\n",
"        super(CBOW, self).__init__()\n",
"\n",
"        self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
"        self.linear = nn.Linear(embedding_dim, vocab_size)\n",
"        self.activation_function = nn.LogSoftmax(dim=-1)\n",
"\n",
"        self.embeddings.weight.data.uniform_(-1, 1)\n",
"\n",
"    def forward(self, inputs):\n",
"        # inputs: (batch, 2*window) context IDs -> summed embeddings: (batch, embedding_dim)\n",
"        embeds = torch.sum(self.embeddings(inputs), dim=1)\n",
"        out = self.linear(embeds)\n",
"        # log-probabilities over the vocabulary: (batch, vocab_size)\n",
"        out = self.activation_function(out)\n",
"        return out\n",
"\n",
"    def get_word_embedding(self, word, vocab):\n",
"        # Look up the embedding of a single word string, using the vocab the model was trained with\n",
"        word = torch.tensor([vocab[word]], device=self.embeddings.weight.device)\n",
"        return self.embeddings(word).view(1, -1)"
]
},
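{
"cell_type": "markdown",
"id": "c3a91f09",
"metadata": {},
"source": [
"A quick way to check the CBOW forward pass is to push a tiny batch through the same layers and look at the shapes. The next cell re-creates those layers with a toy vocabulary of 10 words and 8-dimensional embeddings (both arbitrary, so it runs instantly). The random context IDs stand in for real `make_context` output."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f10",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Toy sizes chosen only so this runs instantly\n",
"toy_vocab_size, toy_dim, toy_window = 10, 8, 2\n",
"toy_embeddings = nn.Embedding(toy_vocab_size, toy_dim)\n",
"toy_linear = nn.Linear(toy_dim, toy_vocab_size)\n",
"\n",
"# 3 targets, each with 2*window random context IDs standing in for make_context output\n",
"toy_context = torch.randint(0, toy_vocab_size, (3, 2 * toy_window))\n",
"\n",
"toy_embedded = toy_embeddings(toy_context)  # (3, 4, 8): one vector per context word\n",
"toy_summed = toy_embedded.sum(dim=1)        # (3, 8): one vector per target\n",
"toy_log_probs = torch.log_softmax(toy_linear(toy_summed), dim=-1)  # (3, 10)\n",
"print(toy_embedded.shape, toy_summed.shape, toy_log_probs.shape)\n",
"\n",
"# Each row is a log-probability distribution over the vocabulary, so it exponentiates to sum 1\n",
"print(toy_log_probs.exp().sum(dim=-1))"
]
},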
{
"cell_type": "code",
"execution_count": 23,
"id": "03dcc99f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OOOOOOO\n",
"[1, 3808, 143, 18, 349, 9, 362, 5, 14532, 12, 11, 288, 3, 1, 7628, 581, 16, 1, 11975, 4, 699, 18, 417, 25, 6421, 7, 9480, 112, 5, 25, 143, 14, 4451, 2, 851, 19, 94, 16299, 1695, 7628, 3, 1, 236, 8709, 25, 15626, 19, 3966, 5, 2479, 22, 8, 8016, 15, 10828, 4426, 51, 8, 323, 4, 10710, 15, 12047, 1, 4451, 3]\n"
]
},
{
"ename": "NameError",
"evalue": "name 'doc_id' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[23], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOOOOOOO\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(sample)\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mdoc_id\u001b[49m\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"\u001b[0;31mNameError\u001b[0m: name 'doc_id' is not defined"
]
}
],
"source": [
"# Print the documents in the first batch to check what the DataLoader yields\n",
"for idx, batch in enumerate(loader_valid):\n",
"    for sample in batch:\n",
"        print('OOOOOOO')\n",
"        print(sample)\n",
"    if idx == 0:\n",
"        break"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "dc9b8022",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "expected sequence of length 145 at dim 2 (got 74)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[38], line 11\u001b[0m\n\u001b[1;32m 9\u001b[0m total_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(loader_valid):\n\u001b[0;32m---> 11\u001b[0m targets, context \u001b[38;5;241m=\u001b[39m \u001b[43mmake_context\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mWINDOW\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m targets \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(targets, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mlong)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 13\u001b[0m context \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(context, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mlong)\u001b[38;5;241m.\u001b[39mto(device)\n",
"Cell \u001b[0;32mIn[36], line 26\u001b[0m, in \u001b[0;36mmake_context\u001b[0;34m(tokens, window)\u001b[0m\n\u001b[1;32m 23\u001b[0m context \u001b[38;5;241m=\u001b[39m padded[i\u001b[38;5;241m-\u001b[39mwindow : i] \u001b[38;5;241m+\u001b[39m padded[i\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m : i\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;241m+\u001b[39mwindow]\n\u001b[1;32m 24\u001b[0m data\u001b[38;5;241m.\u001b[39mappend(context)\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m, torch\u001b[38;5;241m.\u001b[39mtensor(data)\n",
"\u001b[0;31mValueError\u001b[0m: expected sequence of length 145 at dim 2 (got 74)"
]
}
],
"source": [
"model = CBOW(VOCAB_SIZE, EMBED_DIM).to(device)\n",
"loss_function = nn.NLLLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"# TRAINING\n",
"mytime = time.time()\n",
"for epoch in range(EPOCHS):\n",
"    total_loss = 0.0\n",
"    for idx, batch in enumerate(loader_valid):\n",
"        # Each batch is a list of documents of different lengths, so build the\n",
"        # (target, context) pairs per document, flatten the targets to 1-D, and concatenate\n",
"        pairs = [make_context(doc, WINDOW) for doc in batch]\n",
"        targets = torch.cat([t.view(-1) for t, _ in pairs]).to(device)\n",
"        context = torch.cat([c for _, c in pairs]).to(device)\n",
"\n",
"        # Optimize after every batch: summing the loss graph over a whole epoch and\n",
"        # stepping once would keep every batch's activations in memory\n",
"        optimizer.zero_grad()\n",
"        log_probs = model(context)\n",
"        loss = loss_function(log_probs, targets)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        total_loss += loss.item()\n",
"        if idx % 100 == 0:\n",
"            print(idx, total_loss)\n",
"    print(epoch, total_loss, time.time() - mytime)"
]
},
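{
"cell_type": "markdown",
"id": "c3a91f11",
"metadata": {},
"source": [
"The model ends in `LogSoftmax` and is trained with `NLLLoss`. Together these compute the usual cross-entropy between the predicted distribution and the true target word. The next cell checks that equivalence numerically on random logits (the shapes are arbitrary). This is also why `nn.CrossEntropyLoss` applied to the raw `linear` outputs would be an equivalent choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f12",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"toy_logits = torch.randn(4, 6)            # 4 examples, vocabulary of 6 words\n",
"toy_targets = torch.tensor([0, 3, 5, 1])  # index of the true word for each example\n",
"\n",
"toy_nll = nn.NLLLoss()(torch.log_softmax(toy_logits, dim=-1), toy_targets)\n",
"toy_ce = nn.CrossEntropyLoss()(toy_logits, toy_targets)\n",
"# By hand: the mean over the batch of -log p(true word)\n",
"toy_manual = -torch.log_softmax(toy_logits, dim=-1)[torch.arange(4), toy_targets].mean()\n",
"print(toy_nll.item(), toy_ce.item(), toy_manual.item())"
]
},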
{
"cell_type": "code",
"execution_count": null,
"id": "108a3c69",
"metadata": {},
"outputs": [],
"source": [
"#TESTING\n",
"# widely caught using lobster pots\n",
"#Atlantic Ocean , Mediterranean Sea and\n",
"#widely caught lobster pots\n",
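"context = 'Atlantic Ocean Sea and'\n",
"context_vector = vocab_valid.lookup_indices(tokenizer(context))\n",
"# The model expects a batch dimension: shape (1, number of context words)\n",
"with torch.no_grad():\n",
"    a = model(torch.tensor([context_vector], dtype=torch.long).to(device))\n",
"# Indices of the 10 most probable target words\n",
"val, ind = torch.topk(a, 10)\n",
"vocab_valid.lookup_tokens(ind.tolist()[0])"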
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe3c3d54",
"metadata": {},
"outputs": [],
"source": [
"def mysim(a, b):\n",
"    # Cosine similarity between the learned embeddings of words a and b\n",
"    aa = model.embeddings(torch.tensor(vocab_valid[a], dtype=torch.long).to(device))\n",
"    bb = model.embeddings(torch.tensor(vocab_valid[b], dtype=torch.long).to(device))\n",
"    sim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)\n",
"    return sim(aa, bb).item()\n",
"\n",
"mysim('lobster', 'josh')"
]
},
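{
"cell_type": "markdown",
"id": "c3a91f13",
"metadata": {},
"source": [
"Beyond comparing two words, a common sanity check is to list a word's nearest neighbours by cosine similarity over the whole embedding matrix. The next cell is a minimal sketch on a hand-made 3-dimensional matrix, and `nearest_words` is a hypothetical helper. With the trained model you would presumably call it as `nearest_words('lobster', vocab_valid.get_itos(), model.embeddings.weight.detach())`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f14",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def nearest_words(word, itos, weights, k=3):\n",
"    # Hypothetical helper: cosine similarity between one word and every row of the embedding matrix\n",
"    stoi = {w: i for i, w in enumerate(itos)}\n",
"    query = weights[stoi[word]].unsqueeze(0)           # (1, dim)\n",
"    sims = F.cosine_similarity(query, weights, dim=1)  # (vocab_size,) via broadcasting\n",
"    sims[stoi[word]] = -2.0                            # exclude the word itself\n",
"    top = torch.topk(sims, k).indices.tolist()\n",
"    return [itos[i] for i in top]\n",
"\n",
"# Toy vocabulary and embeddings, made up for illustration\n",
"sea_itos = ['lobster', 'crab', 'shrimp', 'car']\n",
"sea_weights = torch.tensor([[1.0, 0.9, 0.0],\n",
"                            [0.9, 1.0, 0.1],\n",
"                            [0.8, 0.7, 0.2],\n",
"                            [0.0, 0.1, 1.0]])\n",
"print(nearest_words('lobster', sea_itos, sea_weights, k=2))"
]
},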
{
"cell_type": "markdown",
"id": "3d05fe9b",
"metadata": {},
"source": [
"# TODO\n",
"* Use `collections.Counter` to count word frequencies\n",
"* Use the counter to build the vocab\n",
"* Use the word frequencies to subsample common words, i.e. sample them less often (using the subsampling distribution from the word2vec paper), which speeds up training and improves the vectors of rarer words"
]
},
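{
"cell_type": "markdown",
"id": "c3a91f15",
"metadata": {},
"source": [
"For the subsampling item: the word2vec paper discards each occurrence of a word $w$ with probability $P(w) = 1 - \\sqrt{t / f(w)}$, where $f(w)$ is the fraction of all tokens that are $w$ and $t$ is a chosen threshold (the paper suggests values around $10^{-5}$). The next cell is a minimal sketch with `collections.Counter`. `discard_probs` and `subsample` are hypothetical helpers, and because the toy corpus has only 10 tokens, a much larger `t` is used so the effect is visible."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a91f16",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"from collections import Counter\n",
"\n",
"def discard_probs(docs, t=1e-5):\n",
"    # Hypothetical helper: f(w) is the fraction of all tokens that are w\n",
"    counts = Counter(tok for doc in docs for tok in doc)\n",
"    total = sum(counts.values())\n",
"    # P(discard w) = 1 - sqrt(t / f(w)), clipped at 0 for words rarer than t\n",
"    return {w: max(0.0, 1.0 - math.sqrt(t / (n / total))) for w, n in counts.items()}\n",
"\n",
"def subsample(doc, probs, rng=random):\n",
"    # Hypothetical helper: keep each token with probability 1 - P(discard)\n",
"    return [tok for tok in doc if rng.random() >= probs[tok]]\n",
"\n",
"toy_sub_corpus = [['the', 'lobster', 'the', 'crab', 'the', 'sea'], ['the', 'pots', 'the', 'ocean']]\n",
"# t = 0.1 is far larger than in the paper, only so the effect shows up in a 10-token corpus\n",
"toy_probs = discard_probs(toy_sub_corpus, t=0.1)\n",
"print(toy_probs)\n",
"print(subsample(toy_sub_corpus[0], toy_probs, random.Random(0)))"
]
},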
{
"cell_type": "code",
"execution_count": null,
"id": "fab6ab40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}