Skip to content

Commit

Permalink
Fix get_paginated_resource
Browse files Browse the repository at this point in the history
  • Loading branch information
Rusty Brooks committed Mar 12, 2020
1 parent 953d4f1 commit 8fca8d8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 34 deletions.
57 changes: 29 additions & 28 deletions OTXv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, api_key, proxy=None, proxy_https=None, server="https://otx.al
self.request_session = None
self.headers = {
'X-OTX-API-KEY': self.key,
'User-Agent': user_agent or 'OTX Python {}/1.5.7'.format(project),
'User-Agent': user_agent or 'OTX Python {}/1.5.8'.format(project),
'Content-Type': 'application/json'
}

Expand All @@ -120,6 +120,24 @@ def session(self):

return self.request_session

def now(self):
return pytz.utc.localize(datetime.datetime.utcnow())

@classmethod
def fix_date(cls, date_str):
if date_str is None:
return None

if isinstance(date_str, datetime.datetime):
dt = date_str
else:
dt = dateutil.parser.parse(date_str) if date_str else None

if dt and dt.tzinfo is None:
dt = pytz.utc.localize(dt)

return dt

@classmethod
def handle_response_errors(cls, response):
def _response_json():
Expand Down Expand Up @@ -215,7 +233,7 @@ def create_pulse(self, **kwargs):
:param industries(list of strings) list of industries related to pulse
:param malware_families(list of strings) list of malware families related to pulse
:param attack_ids(list of strings) list of ATT&CK ids related to pulse
:return: request body response
:raises BadRequest (400) On failure, BadRequest will be raised containing the invalid fields.
Expand All @@ -237,7 +255,7 @@ def create_pulse(self, **kwargs):
'tags': kwargs.get('tags', []),
'references': kwargs.get('references', []),
'indicators': kwargs.get('indicators', []),
'group_ids': kwargs.get('group_ids', []),
'group_ids': kwargs.get('group_ids', []),
'adversary': kwargs.get('adversary'),
'targeted_countries': kwargs.get('targeted_countries', []),
'industries': kwargs.get('industries', []),
Expand Down Expand Up @@ -404,7 +422,7 @@ def search_pulses(self, query, max_results=25):
:param max_results: Limit the number of pulses returned in response
:return: All pulses matching `query`
"""
search_pulses_url = self.create_url(SEARCH_PULSES, q=query, page=1, limit=20)
search_pulses_url = self.create_url(SEARCH_PULSES, q=query, page=1, limit=25)
return self._get_paginated_resource(search_pulses_url, max_results=max_results)

def search_users(self, query, max_results=25):
Expand All @@ -430,7 +448,6 @@ def _get_paginated_resource(self, url=SUBSCRIBED, max_results=25):
additional_fields = {}
while next_page_url and len(results) < max_results:
json_data = self.get(next_page_url)
max_results -= len(json_data.get('results'))
for r in json_data.pop("results"):
results.append(r)
next_page_url = json_data.pop("next")
Expand Down Expand Up @@ -552,12 +569,14 @@ def add_or_update_pulse_indicators(self, pulse_id, indicators):
indicators_to_update = []

for indicator in indicators:
indicator = copy.deepcopy(indicator)
if indicator['indicator'] in current_indicators:
new_indicator = copy.deepcopy(indicator)
new_indicator.update({
'id': current_indicators[indicator['indicator']]['id'],
})
indicators_to_update.append(new_indicator)
indicator = copy.deepcopy(indicator)
indicator['id'] = current_indicators[indicator['indicator']]['id']
if 'expiration' in indicator and 'is_active' not in indicator:
if self.fix_date(indicator['expiration']) > self.now():
indicator['is_active'] = 1
indicators_to_update.append(indicator)
else:
indicators_to_add.append(indicator)

Expand Down Expand Up @@ -776,9 +795,6 @@ def __init__(self, api_key, cache_dir=None, max_age=None, *args, **kwargs):

self.load_data()

def now(self):
return pytz.utc.localize(datetime.datetime.utcnow())

def load_data(self):
datfile = os.path.join(self.cache_dir, 'data.json')

Expand Down Expand Up @@ -929,21 +945,6 @@ def find_pulses(self, return_type='pulse_id', author_names=None, modified_since=
else:
raise Exception("return_type should be one of ['pulse_id', 'pulse']")

@classmethod
def fix_date(cls, dtstr):
if dtstr is None:
return None

if isinstance(dtstr, datetime.datetime):
dt = dtstr
else:
dt = dateutil.parser.parse(dtstr) if dtstr else None

if dt and dt.tzinfo is None:
dt = pytz.utc.localize(dt)

return dt

# FIXME this is unordered...
def getall(self, modified_since=None, author_name=None, iter=False, limit=None, max_page=None, max_items=None):
if iter:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

setup(
name='OTXv2',
version='1.5.7',
version='1.5.8',
description='AlienVault OTX API',
author='AlienVault Team',
author_email='[email protected]',
url='https://github.com/AlienVault-Labs/OTX-Python-SDK',
download_url='https://github.com/AlienVault-Labs/OTX-Python-SDK/tarball/1.5.7',
download_url='https://github.com/AlienVault-Labs/OTX-Python-SDK/tarball/1.5.8',
py_modules=['OTXv2', 'IndicatorTypes','patch_pulse'],
install_requires=['requests', 'python-dateutil', 'pytz']
)
9 changes: 5 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def test_add_or_update_indicators(self):
# add some indicators and update others. omitted indicators should stay same
indicators = [
{'indicator': "two.com", 'type': 'domain'}, # no change
{'indicator': "three.com", 'type': 'domain'}, # new indicator
{'indicator': "three.com", 'type': 'domain'}, # new indicator
{'indicator': "one.com", 'type': 'domain', 'title': 'one.com title'}, # change title
]
self.otx.add_or_update_pulse_indicators(pulse_id, indicators)
Expand Down Expand Up @@ -566,12 +566,13 @@ def test_add_or_update_indicators(self):
self.assertEqual(expected, actual)

# set a new expiration
new_expiration = (datetime.datetime.utcnow().replace(microsecond=0) + datetime.timedelta(days=14)).isoformat()
indicators = [
{'indicator': u'8.8.8.8', 'is_active': 1, 'expiration': '2020-01-01T00:00:00'},
{'indicator': u'8.8.8.8', 'expiration': new_expiration},
]
self.otx.add_or_update_pulse_indicators(pulse_id, indicators)
expected = [
{'indicator': u'8.8.8.8', 'type': u'IPv4', 'expiration': '2020-01-01T00:00:00', 'is_active': 1, 'title': u''},
{'indicator': u'8.8.8.8', 'type': u'IPv4', 'expiration': new_expiration, 'is_active': 1, 'title': u''},
{'indicator': u'[email protected]', 'type': u'email', 'expiration': None, 'is_active': 1, 'title': u''},
{'indicator': u'[email protected]', 'type': u'email', 'expiration': None, 'is_active': 1, 'title': u''},
{'indicator': u'one.com', 'type': u'domain', 'expiration': None, 'is_active': 1, 'title': u'one.com title'},
Expand Down Expand Up @@ -881,7 +882,7 @@ def _test_backoff(self):

def test_user_agent(self):
o = OTXv2(self.api_key, server=ALIEN_DEV_SERVER, project='foo')
self.assertEqual(o.headers['User-Agent'], 'OTX Python foo/1.5.7')
self.assertEqual(o.headers['User-Agent'], 'OTX Python foo/1.5.8')

o = OTXv2(self.api_key, server=ALIEN_DEV_SERVER, user_agent='foo')
self.assertEqual(o.headers['User-Agent'], 'foo')
Expand Down

0 comments on commit 8fca8d8

Please sign in to comment.