{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1c0de01",
   "metadata": {},
   "source": [
    "# Tweet similarity to example tweets\n",
    "\n",
    "Embeds the example tweets (the first rows of the input CSV) and all tweets with a\n",
    "sentence-transformers model, computes the cosine similarity of every tweet to every\n",
    "example, and writes the input table with one similarity column appended per example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40f8cabd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sentence_transformers import SentenceTransformer, util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d2e4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives in this cell.\n",
    "\n",
    "# The input table is read without a header row, so columns are addressed by position.\n",
    "INPUT_CSV = './data/biden.csv'\n",
    "OUTPUT_CSV = 'biden2.csv'\n",
    "\n",
    "# Path to the folder that contains the config.json file.\n",
    "# This uses the following HF model:\n",
    "# https://huggingface.co/sentence-transformers/all-mpnet-base-v2\n",
    "# Set the MPNET_MODEL_DIR environment variable to use a different local copy.\n",
    "MODEL_DIR = os.environ.get('MPNET_MODEL_DIR', '/projectnb/jbrcs/tweet2/models/mpnet')\n",
    "\n",
    "TEXT_COL = 5    # column that holds the tweet text\n",
    "N_EXAMPLES = 5  # the first N_EXAMPLES rows are the example tweets\n",
    "TOP_K = 5       # number of most similar tweets to inspect per example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa548b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept untouched for the rest of the notebook so that every later cell can be re-run safely.\n",
    "data_raw = pd.read_csv(INPUT_CSV, header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e99a71e",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = data_raw[TEXT_COL]\n",
    "\n",
    "# In current *.csv files the example tweets start with some number:\n",
    "# drop that first whitespace-separated token and keep the rest.\n",
    "examples_text = text.iloc[:N_EXAMPLES].str.split(n=1).str[1]\n",
    "\n",
    "# Examples first, then the remaining tweets, on a fresh 0..N-1 index,\n",
    "# so that position i here corresponds to row i of data_raw.\n",
    "tweets_text = pd.concat([examples_text, text.iloc[N_EXAMPLES:]], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4f6b8ca",
   "metadata": {},
   "source": [
    "## Embed and compare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56bd6624",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer(MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0634dc33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate sentence embeddings for examples and tweets.\n",
    "# Plain lists are passed so that encoding does not depend on the pandas index.\n",
    "example_embeddings = model.encode(examples_text.tolist())\n",
    "tweet_embeddings = model.encode(tweets_text.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19dbc735",
   "metadata": {},
   "outputs": [],
   "source": [
    "# similarity[i, j] is the cosine similarity between example i and tweet j.\n",
    "similarity = util.cos_sim(example_embeddings, tweet_embeddings)\n",
    "\n",
    "# topk raises if k exceeds the number of tweets, so clamp it for small inputs.\n",
    "top_k = min(TOP_K, similarity.shape[1])\n",
    "vals, inds = similarity.topk(top_k, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a7c9db",
   "metadata": {},
   "source": [
    "## Inspect one example\n",
    "\n",
    "The examples are themselves part of `tweets_text`, so each example is listed among its own top matches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69788e18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wow. Joe Biden resigned? Awesome.\n",
      "tensor([1.0000, 0.8618, 0.8504, 0.7893, 0.7893])\n",
      "['Wow. Joe Biden resigned? Awesome.' 'Joe Biden has resigned'\n",
      " 'Biden resigned ?? Did I miss something?' 'Biden resigned as POTUS 😊'\n",
      " 'Biden resigned as POTUS 😊']\n"
     ]
    }
   ],
   "source": [
    "# Index of the example to inspect (0 .. number of examples - 1).\n",
    "num = 2\n",
    "print(examples_text[num])\n",
    "print(vals[num, :])\n",
    "print(tweets_text[inds[num].numpy()].to_numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6b8daec",
   "metadata": {},
   "source": [
    "## Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce434dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of scores per input row, one column per example.\n",
    "assert similarity.shape == (len(examples_text), len(data_raw))\n",
    "\n",
    "# Written to a new name instead of rebinding the input frame, so re-running this\n",
    "# cell does not append the similarity columns a second time.\n",
    "data_scored = pd.concat([data_raw, pd.DataFrame(similarity.numpy().T)], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ddb71b",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_scored.to_csv(OUTPUT_CSV, index=False, header=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}