Files
arxiv-sanity-lite/serve.py
T
2021-11-14 21:49:34 -08:00

299 lines
8.9 KiB
Python

"""
Flask server backend
ideas:
- allow delete of tags
- unify all different pages into single search filter sort interface
- special single-image search just for paper similarity
"""
import time
import pickle
from random import shuffle
import numpy as np
from sklearn import svm
from flask import Flask, request, redirect, url_for
from flask import render_template
from flask import g # global session-level object
from aslite.db import get_papers_db, get_metas_db, get_tags_db
# -----------------------------------------------------------------------------
# TODO: user accounts / password login are necessary; until then every
# function below that needs a user operates on the hard-coded user 'root'
app = Flask(__name__) # the flask application that all routes below register on
RET_NUM = 100 # maximum number of papers rendered into a single results page
def get_tags():
    """
    Return the tag library of the current user as a dict of tag -> paper ids.

    The library is read from the tags database at most once per flask `g`
    lifetime and memoized on `g._tags`. A user that is not yet present in
    the database gets an empty dict.
    """
    # fast path: already loaded during this request
    if hasattr(g, '_tags'):
        return g._tags

    user = 'root' # root for now, the only default user
    print("reading tags for user %s" % (user, ))
    with get_tags_db() as db:
        if user in db:
            library = db[user]
        else:
            library = {}

    # memoize so later calls skip the database
    g._tags = library
    return library
def get_papers():
    """
    Return the papers database handle, opening it on first use.

    The handle is stored on `g._pdb` so that repeated calls share a single
    connection; close_connection() closes it again.
    """
    if hasattr(g, '_pdb'):
        return g._pdb
    g._pdb = get_papers_db()
    return g._pdb
def get_metas():
    """
    Return the metas database handle, opening it on first use.

    The handle is stored on `g._mdb` so that repeated calls share a single
    connection; close_connection() closes it again.
    """
    if hasattr(g, '_mdb'):
        return g._mdb
    g._mdb = get_metas_db()
    return g._mdb
def render_pids(pids):
    """
    Turn a sequence of paper ids into the list of dicts the templates consume.

    Each dict carries the paper's id, title, time string, comma-joined
    author names and arxiv tag terms, the summary, and 'utags': the names
    of the user's tags that contain the paper. 'weight' is a 0.0
    placeholder that callers overwrite with their own score.
    """
    papers_db = get_papers()
    user_tags = get_tags()

    def describe(pid):
        # one database lookup per paper, then flatten into template fields
        record = papers_db[pid]
        mine = [name for name, members in user_tags.items() if pid in members]
        author_names = ', '.join(a['name'] for a in record['authors'])
        arxiv_terms = ', '.join(t['term'] for t in record['tags'])
        return {
            'weight': 0.0,
            'id': record['_id'],
            'title': record['title'],
            'time': record['_time_str'],
            'authors': author_names,
            'tags': arxiv_terms,
            'utags': mine,
            'summary': record['summary'],
        }

    return [describe(pid) for pid in pids]
def random_rank():
    """
    Return all paper ids in random order, paired with all-zero scores.

    Returns a tuple (pids, scores) of two equally long lists.
    """
    pids = list(get_papers().keys())
    shuffle(pids) # in-place shuffle
    return pids, [0] * len(pids)
def time_rank():
    """
    Return all paper ids ordered by their '_time' meta field, largest first.

    Returns a tuple (pids, scores) where each score is the difference
    between now and the paper's '_time', expressed in days (this assumes
    '_time' is a unix timestamp in seconds, like time.time()).
    """
    mdb = get_metas()
    newest_first = sorted(mdb.items(), key=lambda kv: kv[1]['_time'], reverse=True)
    now = time.time()
    pids, scores = [], []
    for pid, meta in newest_first:
        pids.append(pid)
        scores.append((now - meta['_time'])/60/60/24) # time delta in days
    return pids, scores
def svm_rank(tags=None, pid=None):
    """
    Rank all papers with a linear SVM trained on a set of positive papers.

    tags: one tag, a few comma-separated tags, or 'all' for every tag the
          user has; the papers under those tags form the positive set.
    pid:  a specific paper id to use as the single positive, for a kind of
          nearest neighbor search. Takes precedence over `tags` when both
          are given.

    Returns a tuple (pids, scores): paper ids sorted by decreasing SVM
    decision value, and the matching decision values scaled by 100.
    Returns ([], []) when none of the requested positives is present in
    the feature matrix.

    Raises AssertionError when neither `tags` nor `pid` is given.
    """
    assert (tags is not None) or (pid is not None)

    # load all of the features; the context manager closes the file handle
    # NOTE: pickle can execute arbitrary code on load, so features.p must
    # only ever come from a trusted source
    with open('features.p', 'rb') as f:
        features = pickle.load(f)
    x, all_pids = features['x'], features['pids']
    n = x.shape[0]
    ptoi = {p: i for i, p in enumerate(all_pids)} # paper id -> row index of x

    # collect the paper ids that should be labeled positive
    if pid is not None:
        positives = [pid]
    else:
        tags_db = get_tags()
        tags_filter_to = tags_db.keys() if tags == 'all' else set(tags.split(','))
        positives = [
            p
            for tag, tagged in tags_db.items() if tag in tags_filter_to
            for p in tagged
        ]

    # construct the label vector; ids without features (e.g. a tagged paper
    # that is missing from features.p) are skipped instead of raising KeyError
    y = np.zeros(n, dtype=np.float32)
    for p in positives:
        ix = ptoi.get(p)
        if ix is not None:
            y[ix] = 1.0

    if y.sum() == 0:
        return [], [] # there are no positives?

    # classify
    clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
    clf.fit(x, y)
    s = clf.decision_function(x)

    # most confidently positive papers first
    sortix = np.argsort(-s)
    ranked = [all_pids[ix] for ix in sortix]
    scores = [100*float(s[ix]) for ix in sortix]
    return ranked, scores
def default_context(papers, **kwargs):
    """
    Build the base template context shared by the rendered pages.

    papers: the list of paper dicts to show.
    kwargs: accepted but currently unused.

    Returns a dict with 'papers', 'tags' (one {'name', 'n'} entry per user
    tag followed by the pseudo-tag {'name': 'all'}) and 'gvars', a dict of
    default page-level variables that callers may overwrite.
    """
    # the available tags with their paper counts, plus the catch-all entry
    tag_entries = [{'name':name, 'n':len(members)} for name, members in get_tags().items()]
    tag_entries.append({'name': 'all'})

    return {
        'papers': papers,
        'tags': tag_entries,
        # various other globals
        'gvars': {
            'search_query': '',
            'time_filter': '',
            'message': 'default_message',
        },
    }
# -----------------------------------------------------------------------------
@app.teardown_request
def close_connection(error=None):
    """
    Close the database handles that get_papers() / get_metas() stored on `g`.

    Registered as a flask teardown_request handler; `error` is accepted
    but not inspected.
    """
    # papers db first, then metas db; each only if it was actually opened
    for attr in ('_pdb', '_mdb'):
        if hasattr(g, attr):
            getattr(g, attr).close()
# -----------------------------------------------------------------------------
@app.route('/', methods=['GET'])
def main():
    """
    Render the main page: a ranked, optionally filtered list of papers.

    GET options:
      rank:        tags|pid|time|random, how to rank the papers (default time)
      tags:        tags to rank by when rank == 'tags' (default 'all')
      pid:         paper id to find nearest neighbors to when rank == 'pid'
      time_filter: only keep papers newer than this many days; a value
                   that is not an integer is ignored
      skip_have:   'yes' hides papers that already carry one of the user's tags

    Raises ValueError when `rank` is not one of the known rank types.
    """
    # GET options
    opt_rank = request.args.get('rank', 'time') # rank type. tags|pid|time|random
    opt_tags = request.args.get('tags', 'all') # tags to rank by if opt_rank == 'tags'
    opt_pid = request.args.get('pid', None) # pid to find nearest neighbors to
    opt_time_filter = request.args.get('time_filter', '') # number of days to filter by
    opt_skip_have = request.args.get('skip_have', 'no') # hide papers we already have?

    # rank papers: by tags, by pid, by time, by random
    if opt_rank == 'tags':
        # only pass the tags, so that a stray pid in the query string
        # cannot silently override the requested tag ranking
        pids, scores = svm_rank(tags=opt_tags)
    elif opt_rank == 'pid':
        if opt_pid is not None:
            pids, scores = svm_rank(pid=opt_pid)
        else:
            # no pid was supplied: fall back to ranking by tags
            pids, scores = svm_rank(tags=opt_tags)
    elif opt_rank == 'time':
        pids, scores = time_rank()
    elif opt_rank == 'random':
        pids, scores = random_rank()
    else:
        raise ValueError("opt_rank %s is not a thing" % (opt_rank, ))

    # filter by time
    if opt_time_filter:
        try:
            max_days = int(opt_time_filter)
        except ValueError:
            # the value comes straight from the query string; treat garbage
            # as "no filter" instead of failing the whole request
            opt_time_filter = ''
        else:
            pdb = get_papers()
            tnow = time.time()
            deltat = max_days*60*60*24 # allowed time delta in seconds
            keep = [i for i, pid in enumerate(pids) if (tnow - pdb[pid]['_time']) < deltat]
            pids, scores = [pids[i] for i in keep], [scores[i] for i in keep]

    # optionally hide papers we already have
    if opt_skip_have == 'yes':
        tags_db = get_tags()
        have = set().union(*tags_db.values())
        keep = [i for i, pid in enumerate(pids) if pid not in have]
        pids, scores = [pids[i] for i in keep], [scores[i] for i in keep]

    # crop; slicing already clamps to the list length
    pids = pids[:RET_NUM]

    # render the papers and attach their scores; zip stops at the cropped
    # papers list, so the uncropped scores stay aligned
    papers = render_pids(pids)
    for p, score in zip(papers, scores):
        p['weight'] = float(score)

    context = default_context(papers)
    context['gvars']['rank'] = opt_rank
    context['gvars']['tags'] = opt_tags
    context['gvars']['time_filter'] = opt_time_filter
    return render_template('index.html', **context)
@app.route("/search", methods=['GET'])
def search():
    """
    Render the papers that match the space-separated search query `q`.

    Every paper is scored by counting case-insensitive occurrences of each
    query word: author names weigh 5, the title 10, the summary 1. Papers
    with a positive score are shown best first, cropped to RET_NUM. An
    empty query redirects to the main page.
    """
    q = request.args.get('q', '') # get the search request
    if not q:
        return redirect(url_for('main')) # if someone just hits enter with empty field

    terms = q.lower().strip().split() # split by spaces

    def hits(text):
        # total number of occurrences of all query words in the text
        return sum(text.lower().count(term) for term in terms)

    scored = []
    for pid, p in get_papers().items():
        author_names = ' '.join([a['name'] for a in p['authors']])
        score = 5.0 * hits(author_names) + 10.0 * hits(p['title']) + 1.0 * hits(p['summary'])
        if score > 0:
            scored.append((score, pid))

    # best matches first
    scored = sorted(scored, reverse=True)
    pids = [pid for _, pid in scored][:RET_NUM] # crop if needed

    papers = render_pids(pids)
    for paper, (score, _) in zip(papers, scored):
        paper['weight'] = score

    context = default_context(papers)
    context['gvars']['search_query'] = q
    return render_template('index.html', **context)
@app.route('/add/<pid>/<tag>')
def add(pid=None, tag=None):
    """
    Add the paper `pid` to the tag `tag` of the user and persist the change.

    The user and the tag are created on the fly when they do not exist yet.
    The tag name 'all' is rejected: svm_rank() and default_context() use it
    as the pseudo-tag meaning "every tag", so a real tag with that name
    could never be selected on its own.

    Returns a short status string; on success it contains the user's whole
    tag library for debugging.
    """
    if tag == 'all':
        return "error, cannot add the protected tag 'all'"

    user = 'root'
    with get_tags_db(flag='c') as tags_db:

        # fetch the user library object, starting from an empty library
        # if we don't know about this user yet
        d = tags_db[user] if user in tags_db else {}

        # add the paper to the tag
        if tag not in d:
            d[tag] = set()
        d[tag].add(pid)

        # write back to database
        tags_db[user] = d

    print("added paper %s to tag %s for user %s" % (pid, tag, user))
    return "ok: " + str(d) # return back the user library for debugging atm
@app.route('/sub/<pid>/<tag>')
def sub(pid=None, tag=None):
    """
    Remove the paper `pid` from the tag `tag` of the user and persist the change.

    Returns a short status string. When the user, the tag, or the paper
    inside the tag does not exist, nothing is written and the string says
    what was missing; on success it contains the user's whole tag library
    for debugging.
    """
    user = 'root'
    with get_tags_db(flag='c') as tags_db:

        # if the user doesn't have any tags, there is nothing to do
        if user not in tags_db:
            # the backslash is escaped so the literal is a valid string;
            # the resulting text is unchanged
            return "user has no library of tags ¯\\_(ツ)_/¯"

        # fetch the user library object
        d = tags_db[user]

        # remove the paper from the tag
        if tag not in d:
            return "user doesn't have the tag %s" % (tag, )
        if pid not in d[tag]:
            # set.remove would raise KeyError here; report it instead
            return "user doesn't have paper %s in tag %s" % (pid, tag)
        d[tag].remove(pid)

        # write back to database
        tags_db[user] = d

    print("removed from paper %s the tag %s for user %s" % (pid, tag, user))
    return "ok: " + str(d) # return back the user library for debugging atm
@app.route('/del/<tag>')
def delete_tag(tag=None):
    """
    Delete the tag `tag`, with all papers filed under it, from the user's library.

    Returns a short status string: an explanation when the user or the tag
    does not exist, otherwise the remaining tag library for debugging.
    """
    user = 'root'
    with get_tags_db(flag='c') as tags_db:

        # nothing to delete from if the user has no library at all
        if not (user in tags_db):
            return "user does not have a library"

        library = tags_db[user]
        if not (tag in library):
            return "user does not have this tag"

        # drop the tag and persist the shrunken library
        library.pop(tag)
        tags_db[user] = library

    print("deleted tag %s for user %s" % (tag, user))
    return "ok: " + str(library) # return back the user library for debugging atm