{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5d31b5dc",
"metadata": {},
"outputs": [],
"source": [
"NUM_WORKERS = 1\n",
"BATCH_SIZE = 1\n",
"EPOCHS = 10\n",
"LEARNING_RATE = 0.002\n",
"EMBED_DIM = 300\n",
"WINDOW = 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e674452f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed399515",
"metadata": {},
"outputs": [],
"source": [
"# Device configuration\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "aadad342",
"metadata": {},
"source": [
"# Environment\n",
"We are running this on the SCC using the `python3/3.10.12` and `pytorch/1.13.1` modules. "
]
},
{
"cell_type": "markdown",
"id": "1ba2b3cb",
"metadata": {},
"source": [
"# Corpus for training\n",
"\"The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia.\"\n",
"https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/\n",
"\n",
"This is available to use from the Torchtext package.\n",
"\n",
"The `pytorch/1.13.1` module comes with a version of `torchtext`, but apparently this requires `torchdata` to run. According to:
\n",
"https://github.com/pytorch/data#installation
\n",
"for PyTorch 1.13.1 we should use Torchdata 0.5.1\n",
"\n",
"`pip install --no-cache-dir --prefix=/projectnb/jbrcs/word2vec/packages torchdata==0.5.1`"
]
},
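{
"cell_type": "markdown",
"id": "0ad71e3b",
"metadata": {},
"source": [
"Since the package was installed with `--prefix`, Python has to be told where to find it before `torchtext` tries to import `torchdata`. A minimal sketch, assuming pip's default prefix layout of `lib/python3.10/site-packages` under that directory (adjust the path if your layout differs):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9c2b7a",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Assumed location of the --prefix install from the pip command above;\n",
"# pip's default layout is lib/python<X.Y>/site-packages under the prefix\n",
"PKG_DIR = '/projectnb/jbrcs/word2vec/packages/lib/python3.10/site-packages'\n",
"if PKG_DIR not in sys.path:\n",
"    sys.path.insert(0, PKG_DIR)"
]
},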
{
"cell_type": "code",
"execution_count": null,
"id": "993b347f",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torchtext.datasets import WikiText103\n",
"\n",
"train, valid, test = WikiText103(root='./data', split= ('train', 'valid', 'test'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8bd8e9f",
"metadata": {},
"outputs": [],
"source": [
"valid"
]
},
{
"cell_type": "markdown",
"id": "f84cd619",
"metadata": {},
"source": [
"This is a \"DataPipe\" which allows us to stream in data efficiently to our program. We can make an iterator using the DataPipe if we want:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da63b4db",
"metadata": {},
"outputs": [],
"source": [
"myiter = iter(valid)\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))\n",
"print(repr(next(myiter)))"
]
},
{
"cell_type": "markdown",
"id": "9eee1f77",
"metadata": {},
"source": [
"We want to build a \"vocabulary\" of all the words in our corpus. To do this we need to \"tokenize\" the corpus into individual words. We could write our own simple tokenizer by splitting on spaces, etc; instead we will use one included with Torchtext:
\n",
"https://pytorch.org/text/stable/data_utils.html#get-tokenizer\n",
"\n",
"Then we build our vocabulary, assigning an index to each word:
\n",
"https://pytorch.org/text/stable/vocab.html#build-vocab-from-iterator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c71bc32b",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torchtext.data.utils import get_tokenizer\n",
"\n",
"# This tokenizer converts words to token ID numbers\n",
"tokenizer = get_tokenizer('basic_english')\n",
"\n",
"remove_titles = lambda x: x[:2]!=' ='\n",
"remove_short_sequences = lambda x: len(x)>2*WINDOW\n",
"\n",
"valid = valid.filter(remove_titles)\n",
"valid = valid.map(tokenizer)\n",
"valid = valid.filter(remove_short_sequences)\n",
"\n",
"vocab_valid = build_vocab_from_iterator(valid, min_freq=1, specials=[\"\"])\n",
"VOCAB_SIZE = len(vocab_valid)\n",
"valid = valid.map(vocab_valid.lookup_indices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ed465dd",
"metadata": {},
"outputs": [],
"source": [
"print(f\"Number of words in vocab: {VOCAB_SIZE}\")\n",
"\n",
"# Example of getting ID using vocab\n",
"print(\n",
" vocab_valid['cool'])\n",
"\n",
"# We can tokenize a whole sentence like so:\n",
"vocab_valid.lookup_indices(tokenizer('Josh is a cool guy'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37ae23b6",
"metadata": {},
"outputs": [],
"source": [
"# Using a dataloader allows us to efficiently work with large amounts of data\n",
"loader_valid = DataLoader(\n",
" dataset=valid,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" drop_last=True,\n",
" num_workers=NUM_WORKERS,\n",
" collate_fn=lambda x: x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36ec416e",
"metadata": {},
"outputs": [],
"source": [
"def make_context(tokens, window):\n",
" # This gives the context words within the window around the target word. For example:\n",
" # the quick BROWN fox jumps\n",
" # If the target word is \"brown\". For a window size of 2 the context words are:\n",
" # \"the, quick, fox, jumps\"\n",
" \n",
" # The first/last \"window\" amount of words have their left/right context window instersect with the start/end\n",
" # We still want 2*window words for our context, so we take the next closest words by\n",
" # padding the start and end of a document with a \"window\" amount of words. For example:\n",
" # fox jumps | the quick brown fox jumps\n",
" # Where the \"|\" indicates the end of the padded portion. So if our window size is still 2 and the target\n",
" # is \"quick\" then we only have one word in the left window in th original sequence, but with the padded one:\n",
" # fox jumps | the QUICK brown fox jumps\n",
" # Our left window now extends into the padded region. The effect is as if our left window size is one and\n",
" # the right window size is three, but it makes it easier to code.\n",
" \n",
" padded = (tokens[window+1 : 2*window+1]\n",
" + tokens\n",
" + tokens[-(2*window+1) : -(window+1)])\n",
"\n",
" data = []\n",
" for i in range(window, len(padded) - window):\n",
" context = padded[i-window : i] + padded[i+1 : i+1+window]\n",
" data.append(context)\n",
"\n",
" return torch.tensor([tokens]), torch.tensor(data)"
]
},
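{
"cell_type": "markdown",
"id": "7c1d9e4f",
"metadata": {},
"source": [
"A quick sanity check of `make_context` on a toy sequence of seven fake token IDs: every position should come back as a target paired with exactly `2*WINDOW` context IDs, with the first and last positions drawing part of their context from the padded region."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e6a0c8",
"metadata": {},
"outputs": [],
"source": [
"# Toy check of make_context with the notebook's WINDOW of 2\n",
"toy_tokens = [0, 1, 2, 3, 4, 5, 6]\n",
"toy_targets, toy_context = make_context(toy_tokens, WINDOW)\n",
"print(toy_targets.shape)   # expected: torch.Size([7])\n",
"print(toy_context.shape)   # expected: torch.Size([7, 4])\n",
"print(toy_context[0])      # context of the first token, partly drawn from the padding"
]
},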
{
"cell_type": "code",
"execution_count": null,
"id": "063757bf",
"metadata": {},
"outputs": [],
"source": [
"class CBOW(torch.nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim):\n",
" super(CBOW, self).__init__()\n",
" \n",
" self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
" self.linear = nn.Linear(embedding_dim, vocab_size)\n",
" self.activation_function = nn.LogSoftmax(dim = -1)\n",
" \n",
" self.embeddings.weight.data.uniform_(-1,1)\n",
"\n",
" def forward(self, inputs):\n",
" embeds = torch.sum(self.embeddings(inputs), dim=1)\n",
" out = self.linear(embeds)\n",
" out = self.activation_function(out)\n",
" return out\n",
"\n",
" def get_word_emdedding(self, word):\n",
" word = torch.tensor([word_to_ix[word]])\n",
" return self.embeddings(word).view(1,-1)"
]
},
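{
"cell_type": "markdown",
"id": "9d4f7a21",
"metadata": {},
"source": [
"A minimal shape check of the forward pass (on CPU, with a throwaway model): a batch of context-ID rows should come back as one row of log-probabilities over the vocabulary per target."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e8b3c6d",
"metadata": {},
"outputs": [],
"source": [
"# Throwaway model and a small random batch of context indices, just to check shapes\n",
"check_model = CBOW(VOCAB_SIZE, EMBED_DIM)\n",
"fake_context = torch.randint(0, VOCAB_SIZE, (5, 2*WINDOW))  # 5 targets, 2*WINDOW context IDs each\n",
"print(check_model(fake_context).shape)  # expected: torch.Size([5, VOCAB_SIZE]) of log-probabilities\n",
"del check_model"
]
},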
{
"cell_type": "code",
"execution_count": null,
"id": "03dcc99f",
"metadata": {},
"outputs": [],
"source": [
"for idx, batch in enumerate(loader_valid):\n",
" for sample in batch:\n",
" print('OOOOOOO')\n",
" print(sample)\n",
" if doc_id==0:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc9b8022",
"metadata": {},
"outputs": [],
"source": [
"model = CBOW(VOCAB_SIZE, EMBED_DIM).to(device)\n",
"loss_function = nn.NLLLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"latch = True\n",
"#TRAINING\n",
"mytime = time.time()\n",
"for epoch in range(EPOCHS):\n",
" total_loss = 0\n",
" for idx, batch in enumerate(loader_valid):\n",
" targets, context = make_context(batch, WINDOW)\n",
" targets = torch.tensor(targets, dtype=torch.long).to(device)\n",
" context = torch.tensor(context, dtype=torch.long).to(device)\n",
" log_probs = model(context)\n",
" total_loss += loss_function(log_probs, targets)\n",
" if idx%100==0:\n",
" print(idx, total_loss.item())\n",
" print(epoch, total_loss.item(), time.time()-mytime)\n",
" \n",
" #optimize at the end of each epoch\n",
" optimizer.zero_grad()\n",
" total_loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "108a3c69",
"metadata": {},
"outputs": [],
"source": [
"#TESTING\n",
"# widely caught using lobster pots\n",
"#Atlantic Ocean , Mediterranean Sea and\n",
"#widely caught lobster pots\n",
"context = 'Atlantic Ocean Sea and'\n",
"context_vector = v1.lookup_indices(tokenizer(context))\n",
"a = model(torch.tensor(context_vector, dtype=torch.long).to(device))\n",
"val,ind = torch.topk(a,10)\n",
"v1.lookup_tokens(ind.tolist()[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe3c3d54",
"metadata": {},
"outputs": [],
"source": [
"def mysim(a,b):\n",
" aa = model.embeddings(torch.tensor(v1[a], dtype=torch.long).to(device))\n",
" bb = model.embeddings(torch.tensor(v1[b], dtype=torch.long).to(device))\n",
" sim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)\n",
" return sim(aa,bb).item()\n",
"\n",
"mysim('lobster','josh')\n",
" "
]
},
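{
"cell_type": "markdown",
"id": "a6f1d02e",
"metadata": {},
"source": [
"Beyond a single pair, we can rank the whole vocabulary by cosine similarity against one word's embedding. This is a sketch using the trained embedding matrix and `vocab_valid` from above; `nearest` is just a helper written here, not part of Torchtext."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c7e9b54",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"def nearest(word, k=10):\n",
"    # Cosine similarity of one word's embedding against every row of the embedding matrix\n",
"    emb = model.embeddings.weight.detach()           # (VOCAB_SIZE, EMBED_DIM)\n",
"    query = emb[vocab_valid[word]].unsqueeze(0)      # (1, EMBED_DIM)\n",
"    sims = F.cosine_similarity(query, emb, dim=1)    # (VOCAB_SIZE,)\n",
"    vals, idxs = torch.topk(sims, k + 1)             # k+1 because the query word itself is usually the top match\n",
"    return vocab_valid.lookup_tokens(idxs.tolist()[1:])\n",
"\n",
"nearest('lobster')"
]
},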
{
"cell_type": "markdown",
"id": "3d05fe9b",
"metadata": {},
"source": [
"# TODO\n",
"* Create a counter from collections to get word frequency\n",
"* Use the counter to create a vocab\n",
"* Use the word frequency for subsampling to improve performance (use subsampling distribution from word2vec paper) by sampling common words less frequently"
]
},
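{
"cell_type": "markdown",
"id": "c8b05d17",
"metadata": {},
"source": [
"A possible sketch for the subsampling item above, using the discard probability from the word2vec paper: each occurrence of a word `w` is dropped with probability `1 - sqrt(t / f(w))`, where `f(w)` is the word's relative frequency and `t` is a small threshold. The `Counter`-based frequency table and the `t = 1e-5` value here are assumptions, not something already set up in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d91e4f6a",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"import random\n",
"\n",
"def build_keep_probs(token_lists, t=1e-5):\n",
"    # Keep probability per word: sqrt(t / f(w)), capped at 1 so rare words are always kept\n",
"    counts = Counter(tok for toks in token_lists for tok in toks)\n",
"    total = sum(counts.values())\n",
"    return {tok: min(1.0, (t / (c / total)) ** 0.5) for tok, c in counts.items()}\n",
"\n",
"def subsample(tokens, keep_probs):\n",
"    # Drop frequent words probabilistically before building contexts\n",
"    return [tok for tok in tokens if random.random() < keep_probs[tok]]"
]
},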
{
"cell_type": "code",
"execution_count": null,
"id": "fab6ab40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}