Skip to content

Commit

Permalink
Optimizations (#132)
Browse files Browse the repository at this point in the history
* optimizations

* bump version
  • Loading branch information
jbukhari authored Nov 23, 2020
1 parent 6f145ff commit 9527bfd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 49 deletions.
101 changes: 53 additions & 48 deletions dlx/marc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
'''
'''
"""dlx.marc"""

import re, json, time
import re, json
from datetime import datetime
from warnings import warn
from xml.etree import ElementTree as XML
Expand Down Expand Up @@ -39,9 +38,11 @@ def __init__(self, rtype, tag, code):
### Decorators

class Decorators():
def check_connection(method):
def check_connected(method):
def wrapper(*args, **kwargs):
DB.check_connection()
if not DB.connected:
raise Exception('Must be connected to DB before exececuting this function')

return method(*args, **kwargs)

return wrapper
Expand All @@ -59,7 +60,7 @@ class MarcSet():
# constructors

@classmethod
@Decorators.check_connection
@Decorators.check_connected
def from_query(cls, *args, **kwargs):
"""Instantiates a MarcSet object from a Pymongo database query.
Expand Down Expand Up @@ -90,6 +91,7 @@ def from_query(cls, *args, **kwargs):
self.query_params = [args, kwargs]
Marc = self.record_class
ac = kwargs.pop('auth_control', False)

self.records = map(lambda r: Marc(r, auth_control=ac), self.handle.find(*args, **kwargs))

return self
Expand Down Expand Up @@ -214,7 +216,7 @@ def to_xml(self, xref_prefix=''):

def to_mrk(self):
return '\n'.join([r.to_mrk() for r in self.records])

def to_str(self):
return '\n'.join([r.to_str() for r in self.records])

Expand Down Expand Up @@ -264,23 +266,23 @@ def max_id(cls):
return max_dict.get('_id') or 0

@classmethod
@Decorators.check_connection
@Decorators.check_connected
def handle(cls):
return DB.bibs if cls.__name__ == 'Bib' else DB.auths

@classmethod
def match_id(cls, id):
def match_id(cls, idx):
"""
Deprecated
"""

warn('dlx.marc.Marc.match_id() is deprecated. Use dlx.marc.Marc.from_id() instead')

return cls.find_one(filter={'_id' : id})
return cls.find_one(filter={'_id' : idx})

@classmethod
def from_id(cls, id):
return cls.from_query({'_id' : id})
def from_id(cls, idx, *args, **kwargs):
return cls.from_query({'_id' : idx}, *args, **kwargs)

@classmethod
def match_ids(cls, *ids, **kwargs):
Expand Down Expand Up @@ -367,7 +369,7 @@ def count_documents(cls, *args, **kwargs):

# Instance methods

def __init__(self, doc={}, *, auth_control=True, **kwargs):
def __init__(self, doc={}, *, auth_control=False, **kwargs):
self.id = int(doc['_id']) if '_id' in doc else None
self.updated = doc['updated'] if 'updated' in doc else None
self.user = doc['user'] if 'user' in doc else None
Expand All @@ -382,7 +384,7 @@ def controlfields(self):
def datafields(self):
return list(filter(lambda x: x.tag[:2] != '00', sorted(self.fields, key=lambda x: x.tag)))

def parse(self, doc, *, auth_control=True):
def parse(self, doc, *, auth_control=False):
for tag in filter(lambda x: False if x in ('_id', 'updated', 'user') else True, doc.keys()):
if tag == '000':
self.leader = doc['000'][0]
Expand All @@ -393,7 +395,7 @@ def parse(self, doc, *, auth_control=True):
else:
for field in filter(lambda x: [s.get('xref') or s.get('value') for s in x.get('subfields')], doc[tag]):
self.fields.append(Datafield.from_dict(record_type=self.record_type, tag=tag, data=field, auth_control=auth_control))

#### "get"-type methods

def get_fields(self, *tags):
Expand All @@ -403,12 +405,13 @@ def get_fields(self, *tags):
return list(filter(lambda x: True if x.tag in tags else False, sorted(self.fields, key=lambda x: x.tag)))

def get_field(self, tag, place=0):
fields = self.get_fields(tag)

if len(fields) == 0:
if place == 0:
return next(filter(lambda x: True if x.tag == tag else False, self.fields), None)

try:
return self.get_fields(tag)[place]
except IndexError:
return
if len(fields) > place:
return fields[place]

def get_values(self, tag, *codes, **kwargs):
if tag[:2] == '00':
Expand All @@ -422,8 +425,8 @@ def get_value(self, tag, code=None, address=[0, 0], language=None):
if isinstance(field, Controlfield):
return field.value

if isinstance(field, Datafield):
sub = self.get_subfield(tag, code, address=address)
if isinstance(field, Datafield):
sub = field.get_subfield(code, place=address[1])

if sub:
if language:
Expand Down Expand Up @@ -546,7 +549,7 @@ def set_008(self):
date_tag, date_code = Config.date_field
pub_date = self.get_value(date_tag, date_code)
pub_year = pub_date[0:4].ljust(4, '|')
cat_date = time.strftime('%y%m%d')
cat_date = datetime.utcnow().strftime('%y%m%d')

self.set('008', None, cat_date + text[6] + pub_year + text[11:])

Expand Down Expand Up @@ -581,7 +584,7 @@ def validate(self):
msg = '{} in {} : {}'.format(e.message, str(list(e.path)), self.to_json())
raise jsonschema.exceptions.ValidationError(msg)

@Decorators.check_connection
@Decorators.check_connected
def commit(self, user='admin'):
# clear the caches in case there is a new auth value
if isinstance(self, Auth):
Expand Down Expand Up @@ -846,7 +849,9 @@ def to_xml_raw(self, *tags, language=None, xref_prefix=''):
xref = None

for sub in field.subfields:
if not sub.value:
val = sub.value

if not val:
continue

if hasattr(sub, 'xref'):
Expand All @@ -859,7 +864,7 @@ def to_xml_raw(self, *tags, language=None, xref_prefix=''):
subnode.text = sub.translated(language)
continue

subnode.text = sub.value
subnode.text = val

if xref:
subnode = XML.SubElement(node, 'subfield')
Expand Down Expand Up @@ -977,43 +982,38 @@ class Auth(Marc):
_langcache = {}

@classmethod
@Decorators.check_connection
def lookup(cls, xref, code, language=None):
cache, langcache = Auth._cache, Auth._langcache

if language:
cached = langcache.get(xref, {}).get(code, {}).get(language, None)
cached = Auth._langcache.get(xref, {}).get(code, {}).get(language, None)
else:
cached = cache.get(xref, {}).get(code, None)

cached = Auth._cache.get(xref, {}).get(code, None)
if cached:
return cached

label_tags = Config.auth_heading_tags()
label_tags += Config.auth_language_tags() if language else []
auth = Auth.find_one({'_id': xref}, projection=dict.fromkeys(label_tags, 1))
auth = Auth.from_query({'_id': xref}, projection=dict.fromkeys(label_tags, 1))
value = auth.heading_value(code, language) if auth else None

if language:
langcache[xref] = {code: {language: value}}
Auth._langcache[xref] = {code: {language: value}}
else:
if xref in cache:
cache[xref].update({code: value})
if xref in Auth._cache:
Auth._cache[xref][code] = value
else:
cache[xref] = {code: value}
Auth._cache[xref] = {code: value}

return value

@classmethod
@Decorators.check_connection
def xlookup(cls, tag, code, value, *, record_type):
auth_tag = Config.authority_source_tag(record_type, tag, code)
xcache = Auth._xcache

if auth_tag is None:
return

cached = xcache.get(value, {}).get(auth_tag, {}).get(code, None)
cached = Auth._xcache.get(value, {}).get(auth_tag, {}).get(code, None)

if cached:
return cached
Expand All @@ -1022,12 +1022,11 @@ def xlookup(cls, tag, code, value, *, record_type):
auths = AuthSet.from_query(query.compile(), projection={'_id': 1})
xrefs = [r.id for r in list(auths)]

xcache.setdefault(value, {}).setdefault(auth_tag, {})[code] = xrefs
Auth._xcache.setdefault(value, {}).setdefault(auth_tag, {})[code] = xrefs

return xrefs

@classmethod
@Decorators.check_connection
def partial_lookup(cls, tag, code, string, *, record_type, limit=25):
"""Returns a list of tuples containing the authority-controlled values
that match the given string
Expand Down Expand Up @@ -1077,11 +1076,17 @@ def partial_lookup(cls, tag, code, string, *, record_type, limit=25):

def __init__(self, doc={}, **kwargs):
self.record_type = 'auth'
self._heading_field = None
super().__init__(doc, **kwargs)

@property
def heading_field(self):
return next(filter(lambda field: field.tag[0:1] == '1', self.fields), None)
if self._heading_field:
return self._heading_field

self._heading_field = next(filter(lambda field: field.tag[0:1] == '1', self.fields), None)

return self._heading_field

def heading_value(self, code, language=None):
if language:
Expand Down Expand Up @@ -1172,7 +1177,7 @@ def __eq__(self, other):
return self.to_dict() == other.to_dict()

@classmethod
def from_dict(cls, *, record_type, tag, data, auth_control=True):
def from_dict(cls, *, record_type, tag, data, auth_control=False):
self = cls()
self.record_type = record_type
self.tag = tag
Expand All @@ -1181,7 +1186,7 @@ def from_dict(cls, *, record_type, tag, data, auth_control=True):
self.ind2 = data['indicators'][1]

assert len(data['subfields']) > 0

for sub in data['subfields']:
if 'xref' in sub:
if auth_control:
Expand Down Expand Up @@ -1239,7 +1244,7 @@ def get_subfields(self, code):
return filter(lambda x: x.code == code, self.subfields)

def get_subfield(self, code, place=None):
if place is None:
if place is None or place == 0:
return next(self.get_subfields(code), None)

for i, sub in enumerate(self.get_subfields(code)):
Expand Down Expand Up @@ -1354,7 +1359,7 @@ def to_mrk(self, language=None):
else:
value = sub.value

string += ''.join(['${}{}'.format(sub.code, sub.value)])
string += ''.join(['${}{}'.format(sub.code, value)])

return string

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = '1.2.2'
version = '1.2.3'

import sys
from setuptools import setup, find_packages
Expand Down

0 comments on commit 9527bfd

Please sign in to comment.