# TODO
# https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp
# compile with PyTorch 2+
# SGD -> momentum
# other optimizers
# negative sampling (embedding for second layer, by giving indices for target and negative samples)
# compute perplexity based on subsampling, to compare models with different subsampling more fairly
# evaluate the trained word embeddings against a word-similarity dataset,
#   e.g. the Stanford Rare Word Similarity dataset

# Benchmark notes:
# 8192 samples
# V100 1 89 | 4 55 | 8 51 | 12+ OoM
# 8192 batches * 8 samples
# V100 1e-3 290
# A100 350 samples/s (32 batch, t=1e-3, 300 dim, 2 window, Sparse, default weight init)

RESUME = False
CHECKPOINT_FILE = "EMBED10.pt"
CHECKPOINT = True
NUM_WORKERS = 1
BATCH_SIZE = 64  # Optimal: CPU-16, A40-16, A100-32, V100-8
EPOCHS = 50
LEARNING_RATE = 0.002
EMBED_DIM = 300
WINDOW = 5
SUBSAMPLE_T = 0.001
SPARSE = False
NAME = "EMBED10_log"
TOT_SAMPLES = 786200
FREQ = 0.1
LOG_INTERVAL = int(FREQ * TOT_SAMPLES / BATCH_SIZE)
DATA_DIR = "/projectnb/jbrcs/word2vec/data/"

import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtext.datasets import WikiText103
from torchtext.vocab import vocab as build_vocab
from torchtext.data.utils import get_tokenizer
from collections import Counter
from math import sqrt
from random import random, seed

torch.manual_seed(13373435)
seed(14921986)


def clean_data(data, tokenizer, window):
    # Drop WikiText section headings (lines starting with " =")
    remove_titles = lambda x: x[:2] != ' ='
    data = data.filter(remove_titles)
    data = data.map(tokenizer)
    # Keep only sequences long enough to form at least one full context window
    remove_short_sequences = lambda x: len(x) > 2 * window
    data = data.filter(remove_short_sequences)
    return data


def preprocess_WikiText103(directory, tokenizer, vocab, word_freq, window, sub_t):
    train, valid, test = WikiText103(root=directory, split=('train', 'valid', 'test'))
    [train, valid, test] = [clean_data(x, tokenizer, window) for x in [train, valid, test]]
    all_data = train.concat(valid).concat(test)
    # all_data = valid.concat(test)  # leave out train for speed
    if vocab is None:
        word_freq = Counter()
        for text in all_data:
            word_freq.update(text)
        vocab = build_vocab(word_freq, min_freq=1, specials=["<unk>"])
        vocab.set_default_index(vocab["<unk>"])
    # Subsampling, see:
    # https://github.com/tmikolov/word2vec/blob/20c129af10659f7c50e86e3be406df663beff438/word2vec.c#L407
    threshold = word_freq.total() * sub_t
    probs = lambda x: (sqrt(x / threshold) + 1) * threshold / x
    # Map each token id (not word string) to its subsampling keep probability
    subsample_probs = dict((vocab[k], probs(v)) for (k, v) in word_freq.items())
    subsample_probs[vocab["<unk>"]] = -1  # we never want to sample unknowns
    [train, valid, test] = [x.map(vocab.lookup_indices) for x in [train, valid, test]]
    loader_train = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, collate_fn=lambda x: x)
    loader_valid = DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, collate_fn=lambda x: x)
    loader_test = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, collate_fn=lambda x: x)
    return vocab, word_freq, subsample_probs, loader_train, loader_valid, loader_test


def make_context(batch, window):
    data = []
    for tokens in batch:
        # Pad both ends with nearby in-sequence tokens so every original token gets a full-size context
        tokens = (tokens[window + 1 : 2 * window + 1]
                  + tokens
                  + tokens[-(2 * window + 1) : -(window + 1)])
        for i in range(window, len(tokens) - window):
            context = tokens[i - window : i] + tokens[i + 1 : i + 1 + window]
            data.append(context)
    return torch.tensor(data)
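
# --- Sketch (not used by the training run below): negative-sampling variant from the TODO list ---
# A minimal sketch of how the "negative sampling" TODO item might look, assuming a second
# embedding table indexed by target and negative-sample ids. The class name CBOWNegativeSampling,
# the num_negatives parameter, and the (context, targets, negatives) calling convention are
# illustrative assumptions, not part of the model actually trained in this file.
class CBOWNegativeSampling(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_negatives=5):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embedding_dim, sparse=SPARSE)   # context side
        self.out_embed = nn.Embedding(vocab_size, embedding_dim, sparse=SPARSE)  # output side
        self.num_negatives = num_negatives

    def forward(self, context, targets, negatives):
        # context: (batch, 2*window) ids, targets: (batch,) ids,
        # negatives: (batch, num_negatives) ids drawn from a noise distribution
        hidden = self.in_embed(context).sum(dim=1)                  # (batch, dim)
        pos_score = (self.out_embed(targets) * hidden).sum(dim=1)   # (batch,)
        neg_score = torch.bmm(self.out_embed(negatives),            # (batch, num_negatives)
                              hidden.unsqueeze(2)).squeeze(2)
        # Maximize the sigmoid score of the true target, minimize it for the negatives
        return -(torch.nn.functional.logsigmoid(pos_score)
                 + torch.nn.functional.logsigmoid(-neg_score).sum(dim=1)).mean()
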
class CBOW(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=SPARSE)
        self.linear = nn.Linear(embedding_dim, vocab_size, bias=False)
        torch.nn.init.uniform_(self.linear.weight, a=-0.003, b=0.003)
        torch.nn.init.normal_(self.embeddings.weight, mean=0.0, std=10)  # 0.033 marginal
        # torch.nn.init.normal_(self.linear.bias, mean=0.0, std=0.03)

    def forward(self, inputs):
        # Sum the context embeddings, then project back to vocabulary logits
        embeds = torch.sum(self.embeddings(inputs), dim=1)
        out = self.linear(embeds)
        return out


def measure_accuracy(output, targets):
    return (output.argmax(dim=1) == targets).float().sum()


def flatten(lol):
    return [x for xs in lol for x in xs]


def validation(model, loader_valid, subsample_probs, window):
    hits = 0
    evals = 0
    model.eval()
    with torch.inference_mode():
        for idx, batch in enumerate(loader_valid):
            batch = subsample(batch, subsample_probs, window)
            if len(batch) == 0:  # mirror the training-loop guard against empty batches
                continue
            targets = torch.tensor(flatten(batch)).to(device)
            context = make_context(batch, window).to(device)
            output = model(context)
            hits += measure_accuracy(output, targets)
            evals += len(targets)
    accuracy = (hits / evals).item()
    model.train()
    return accuracy


def subsample(batch, subsample_probs, window):
    lol = []
    for sample in batch:
        sub = [x for x in sample if subsample_probs[x] > random()]
        if len(sub) > 2 * window:  # make sure the subsampled result is still long enough
            lol.append(sub)
    return lol


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} for computation")

print("Preprocessing")
tokenizer = get_tokenizer('basic_english')
vocab = None
word_freq = None
if RESUME:
    if device.type == 'cpu':
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(CHECKPOINT_FILE)
    vocab = checkpoint['vocab']
    word_freq = checkpoint['word_freq']
vocab, word_freq, subsample_probs, loader_train, loader_valid, loader_test = \
    preprocess_WikiText103(DATA_DIR, tokenizer, vocab, word_freq, WINDOW, SUBSAMPLE_T)
VOCAB_SIZE = len(vocab)

print("Sending model to device")
model = CBOW(VOCAB_SIZE, EMBED_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
if RESUME:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

f = open(NAME + ".txt", "a")
f.write(f"{BATCH_SIZE}, {LEARNING_RATE}, {SUBSAMPLE_T}, {WINDOW}, {SPARSE}\n")

model.train()
total_samples = 0
start_time = time.time()
for epoch in range(EPOCHS):
    for idx, batch in enumerate(loader_train):
        total_samples += len(batch)
        batch = subsample(batch, subsample_probs, WINDOW)
        if len(batch) == 0:  # possible for large windows and short samples
            continue
        targets = torch.tensor(flatten(batch)).to(device)
        context = make_context(batch, WINDOW).to(device)
        output = model(context)
        loss = criterion(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % LOG_INTERVAL == 0:
            acc = measure_accuracy(output, targets).item() / len(targets)
            f.write(f"{idx}, {total_samples}, {loss.item():.4f}, {acc:.3f}, {time.time() - start_time:.1f}\n")
            f.flush()
    valid_accuracy = validation(model, loader_valid, subsample_probs, WINDOW)
    f.write(f"Epoch: {epoch}, {total_samples}, {valid_accuracy}, {time.time() - start_time}\n")
    f.flush()
    if CHECKPOINT:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'vocab': vocab,
            'word_freq': word_freq
        }, CHECKPOINT_FILE)
f.close()
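
# --- Sketch (not run here): word-similarity evaluation from the TODO list ---
# One possible way to score the trained embeddings against a word-similarity dataset such as
# the Stanford Rare Word dataset. The file name "rw.txt", its tab-separated
# (word1, word2, human_score) layout, and the availability of scipy are assumptions;
# out-of-vocabulary words fall back to the <unk> index set during preprocessing.
def word_similarity_eval(model, vocab, path="rw.txt"):
    from scipy.stats import spearmanr
    emb = model.embeddings.weight.detach().cpu()
    model_scores, human_scores = [], []
    with open(path) as fh:
        for line in fh:
            w1, w2, score = line.split("\t")[:3]
            i, j = vocab[w1.lower()], vocab[w2.lower()]
            cos = torch.nn.functional.cosine_similarity(emb[i], emb[j], dim=0)
            model_scores.append(cos.item())
            human_scores.append(float(score))
    # Spearman rank correlation between model similarities and human judgements
    corr, _pvalue = spearmanr(model_scores, human_scores)
    return corr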