From 13a1d5ff48d9c4dc341df9070bb73736c0801a57 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 12 Nov 2021 21:12:09 -0800 Subject: [PATCH] sequester gross details about database instantiation in the filesystem away from the scripts --- arxiv_daemon.py | 9 ++++----- aslite/db.py | 26 ++++++++++++++++++++++++++ compute.py | 4 ++-- serve.py | 30 ++++++++++++++---------------- 4 files changed, 46 insertions(+), 23 deletions(-) diff --git a/arxiv_daemon.py b/arxiv_daemon.py index 0f6876e..c00bfdc 100644 --- a/arxiv_daemon.py +++ b/arxiv_daemon.py @@ -1,7 +1,7 @@ """ This script is intended to wake up every 30 min or so (eg via cron), it checks for any new arxiv papers via the arxiv API and stashes -them into a sqlite database papers.db +them into a sqlite database. """ import sys @@ -11,7 +11,7 @@ import logging import argparse from aslite.arxiv import get_response, parse_response -from aslite.db import SqliteDict, CompressedSqliteDict +from aslite.db import get_papers_db, get_metas_db if __name__ == '__main__': @@ -25,9 +25,8 @@ if __name__ == '__main__': # query string of papers to look for q = 'cat:cs.CV+OR+cat:cs.LG+OR+cat:cs.CL+OR+cat:cs.AI+OR+cat:cs.NE+OR+cat:cs.RO' - # flag='c': default mode, open for read/write, creating the db/table if necessary. - pdb = CompressedSqliteDict('papers.db', tablename='papers', flag='c', autocommit=True) - mdb = SqliteDict('papers.db', tablename='metas', flag='c', autocommit=True) + pdb = get_papers_db(flag='c', autocommit=True) + mdb = get_metas_db(flag='c', autocommit=True) prevn = len(pdb) def store(p): diff --git a/aslite/db.py b/aslite/db.py index 11bc0fd..39a7e97 100644 --- a/aslite/db.py +++ b/aslite/db.py @@ -19,3 +19,29 @@ class CompressedSqliteDict(SqliteDict): return pickle.loads(zlib.decompress(bytes(obj))) super().__init__(*args, **kwargs, encode=encode, decode=decode) + +# ----------------------------------------------------------------------------- + +""" +some docs to self: +flag='c': default mode, open for read/write, and creating the db/table if necessary +flag='r': open for read-only +""" + +PAPERS_DB_FILE = 'papers.db' # stores info about papers, and also their lighter-weight metadata +DICT_DB_FILE = 'dict.db' # stores account-relevant info, like which tags exist for which papers + +def get_papers_db(flag='r', autocommit=True): + assert flag in ['r', 'c'] + pdb = CompressedSqliteDict(PAPERS_DB_FILE, tablename='papers', flag=flag, autocommit=autocommit) + return pdb + +def get_metas_db(flag='r', autocommit=True): + assert flag in ['r', 'c'] + mdb = SqliteDict(PAPERS_DB_FILE, tablename='metas', flag=flag, autocommit=autocommit) + return mdb + +def get_tags_db(flag='r', autocommit=True): + assert flag in ['r', 'c'] + ddb = CompressedSqliteDict(DICT_DB_FILE, tablename='tags', flag=flag, autocommit=autocommit) + return ddb diff --git a/compute.py b/compute.py index 986c1fe..5460455 100644 --- a/compute.py +++ b/compute.py @@ -10,7 +10,7 @@ import argparse import numpy as np from sklearn.feature_extraction.text import TfidfVectorizer -from aslite.db import SqliteDict, CompressedSqliteDict +from aslite.db import get_papers_db # ----------------------------------------------------------------------------- @@ -31,7 +31,7 @@ if __name__ == '__main__': norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=True, max_df=args.max_df, min_df=args.min_df) - pdb = CompressedSqliteDict('papers.db', tablename='papers', flag='r') + pdb = get_papers_db(flag='r') def make_corpus(): for p, d in pdb.items(): diff --git a/serve.py b/serve.py index c003c81..4a54be6 100644 --- a/serve.py +++ b/serve.py @@ -18,7 +18,7 @@ from flask import Flask, request, redirect, url_for from flask import render_template from flask import g # global session-level object -from aslite.db import SqliteDict, CompressedSqliteDict +from aslite.db import get_papers_db, get_metas_db, get_tags_db # ----------------------------------------------------------------------------- # TODO: user accounts / password login are necessary... @@ -30,19 +30,19 @@ def get_tags(): if not hasattr(g, '_tags'): user = 'root' # root for now, the only default user print("reading tags for user %s" % (user, )) - with CompressedSqliteDict('dict.db', tablename='tags', flag='r') as dict_db: - tags_dict = dict_db[user] if user in dict_db else {} + with get_tags_db() as tags_db: + tags_dict = tags_db[user] if user in tags_db else {} g._tags = tags_dict return g._tags def get_papers(): if not hasattr(g, '_pdb'): - g._pdb = CompressedSqliteDict('papers.db', tablename='papers', flag='r') + g._pdb = get_papers_db() return g._pdb def get_metas(): if not hasattr(g, '_mdb'): - g._mdb = SqliteDict('papers.db', tablename='metas', flag='r') + g._mdb = get_metas_db() return g._mdb def render_pids(pids): @@ -231,14 +231,14 @@ def search(): @app.route('/add//') def add(pid=None, tag=None): user = 'root' - with CompressedSqliteDict('dict.db', tablename='tags', flag='c') as dict_db: + with get_tags_db(flag='c') as tags_db: # create the user if we don't know about them yet with an empty library - if not user in dict_db: - dict_db[user] = {} + if not user in tags_db: + tags_db[user] = {} # fetch the user library object - d = dict_db[user] + d = tags_db[user] # add the paper to the tag if tag not in d: @@ -246,8 +246,7 @@ def add(pid=None, tag=None): d[tag].add(pid) # write back to database - dict_db[user] = d - dict_db.commit() + tags_db[user] = d print("added paper %s to tag %s for user %s" % (pid, tag, user)) return "ok: " + str(d) # return back the user library for debugging atm @@ -255,12 +254,12 @@ def add(pid=None, tag=None): @app.route('/del/') def delete_tag(tag=None): user = 'root' - with CompressedSqliteDict('dict.db', tablename='tags', flag='c') as dict_db: + with get_tags_db() as tags_db: - if user not in dict_db: + if user not in tags_db: return "user does not have a library" - d = dict_db[user] + d = tags_db[user] if tag not in d: return "user does not have this tag" @@ -269,8 +268,7 @@ def delete_tag(tag=None): del d[tag] # write back to database - dict_db[user] = d - dict_db.commit() + tags_db[user] = d print("deleted tag %s for user %s" % (tag, user)) return "ok: " + str(d) # return back the user library for debugging atm