Skip to content

Commit

Permalink
Merge pull request #83 from v1tal3/development
Browse files Browse the repository at this point in the history
v1.3.6 update
  • Loading branch information
matt852 authored Jul 4, 2018
2 parents a5c0bcc + cc2d8c1 commit 7d3df1d
Show file tree
Hide file tree
Showing 40 changed files with 450 additions and 239 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*.log
*.db
*.vscode
*.idea
app/log/
log/*
scripts_bank/logs/*
Expand Down
17 changes: 10 additions & 7 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from flask_sqlalchemy import SQLAlchemy
from flask_bootstrap import Bootstrap
from flask_script import Manager
from data_handler import DataHandler
from log_handler import LogHandler
from ssh_handler import SSHHandler
from celery import Celery
from .data_handler import DataHandler
from .log_handler import LogHandler
from .ssh_handler import SSHHandler


app = Flask(__name__, instance_relative_config=True)
Expand All @@ -24,9 +23,13 @@

sshhandler = SSHHandler()

# Celery
celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'], backend=app.config['CELERY_RESULT_BACKEND'])
celery.conf.update(app.config)
# Errors blueprint
from app.errors import bp as errors_bp
app.register_blueprint(errors_bp)

# Authentication blueprint
from app.auth import bp as auth_bp
app.register_blueprint(auth_bp, url_prefix='/auth')

from app import views, models

Expand Down
5 changes: 5 additions & 0 deletions app/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flask import Blueprint

bp = Blueprint('auth', __name__)

from app.auth import routes
11 changes: 11 additions & 0 deletions app/auth/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from flask_wtf import FlaskForm
from wtforms.fields import StringField, PasswordField, SubmitField
from wtforms.validators import DataRequired


class LoginForm(FlaskForm):
"""User login form."""

user = StringField('Username', validators=[DataRequired()], render_kw={'autofocus': True})
pw = PasswordField('Password', validators=[DataRequired()])
submit_button = SubmitField('Login')
50 changes: 50 additions & 0 deletions app/auth/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from app import app, logger, sshhandler
from app.auth import bp
from flask import redirect, request, render_template, session, url_for
from app.auth.forms import LoginForm
from app.scripts_bank.redis_logic import storeUserInRedis, deleteUserInRedis


@bp.route('/login', methods=['GET', 'POST'])
def login():
"""Login page for user to save credentials."""
form = LoginForm()
if request.method == 'POST':
# If page is accessed via form POST submission
if form.validate_on_submit():
# Validate form
if 'USER' in session:
# If user already stored in session variable, return home page
return redirect(url_for('viewHosts'))
else:
# Try to save user credentials in Redis. Return index if fails
try:
if storeUserInRedis(request.form['user'], request.form['pw']):
logger.write_log('logged in')
return redirect(url_for('viewHosts'))
except:
logger.write_log('failed to store user data in Redis when logging in')
# Return login page if accessed via GET request
return render_template('auth/login.html', title='Login with SSH credentials', form=form)


@bp.route('/logout')
def logout():
"""Disconnect all SSH sessions by user."""
sshhandler.disconnectAllSSHSessions()
try:
currentUser = session['USER']
deleteUserInRedis()
logger.write_log('deleted user %s data stored in Redis' % (currentUser))
session.pop('USER', None)
logger.write_log('deleted user %s as stored in session variable' % (currentUser), user=currentUser)
u = session['UUID']
session.pop('UUID', None)
logger.write_log('deleted UUID %s for user %s as stored in session variable' % (u, currentUser), user=currentUser)
u = None
except KeyError:
logger.write_log('Exception thrown on logout.')
return redirect(url_for('index'))
logger.write_log('logged out')

return redirect(url_for('index'))
6 changes: 3 additions & 3 deletions app/device_classes/device_definitions/base_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def run_ssh_command(self, command, activeSession):
activeSession.exit_config_mode()
# Try to retrieve command results again
try:
result = self.run_ssh_command('show ip interface brief', activeSession)
result = self.run_ssh_command(command, activeSession)
# If command still failed, return nothing
if "Invalid input detected" in result:
return self.cleanup_ios_output('', '')
return ''
except:
# If failure to access SSH channel or run command, return nothing
return self.cleanup_ios_output('', '')
return ''

# Return command output
return result
Expand Down
6 changes: 6 additions & 0 deletions app/device_classes/device_definitions/cisco/cisco_asa.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def pull_device_uptime(self, activeSession):
uptime = x.split(' ', 2)[2]
return uptime

def pull_device_poe_status(self, activeSession):
"""Retrieve PoE status for all interfaces."""
# Return empty result - unsupported on ASA
return {}

def pull_host_interfaces(self, activeSession):
"""Retrieve list of interfaces on device."""
# result = self.run_ssh_command('show interface ip brief', activeSession)
Expand Down Expand Up @@ -107,6 +112,7 @@ def clean_interface_description(self, x):
def cleanup_asa_output(self, asaOutput):
"""Clean up returned ASA output from 'show ip interface brief'."""
data = []
interface = {}
# Used to set if we're on the first loop or not
notFirstLoop = False

Expand Down
35 changes: 34 additions & 1 deletion app/device_classes/device_definitions/cisco/cisco_ios.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def pull_interface_mac_addresses(self, activeSession):
# Split line on commas
x = line.split(',')
# Remove empty fields from string, specifically if first field is empty (1-2 digit vlan causes this)
x = filter(None, x)
x = list(filter(None, x))
if x:
y = {}
y['vlan'] = x[0].strip()
Expand Down Expand Up @@ -143,6 +143,39 @@ def pull_device_uptime(self, activeSession):
output = x.split(' ', 3)[-1]
return output

def pull_device_poe_status(self, activeSession): # TODO - WRITE TEST FOR
"""Retrieve PoE status for all interfaces."""
status = {}
command = 'show power inline | begin Interface'
result = self.get_cmd_output(command, activeSession)
checkStrings = ['Interface', 'Watts', '---']

# If output returned from command execution, parse output
if result:
for x in result:
# If any string from checkStrings in line, or line is blank, skip to next loop iteration
if any(y in x for y in checkStrings) or not x:
continue
line = x.split()

# Convert interface short abbreviation to long name
regExp = re.compile(r'[A-Z][a-z][0-9]\/')
if regExp.search(line[0]):
if line[0][0] == 'G':
line[0] = line[0].replace('Gi', 'GigabitEthernet')
elif line[0][0] == 'F':
line[0] = line[0].replace('Fa', 'FastEthernet')
elif line[0][0] == 'T':
line[0] = line[0].replace('Te', 'TenGigabitEthernet')
elif line[0][0] == 'E':
line[0] = line[0].replace('Eth', 'Ethernet')

# Line[0] is interface name
# Line[2] is operational status
status[line[0]] = line[2]
# Return dictionary with results
return status

def pull_host_interfaces(self, activeSession):
"""Retrieve list of interfaces on device."""
resultA = self.run_ssh_command('show ip interface brief', activeSession)
Expand Down
5 changes: 5 additions & 0 deletions app/device_classes/device_definitions/cisco/cisco_nxos.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def pull_device_uptime(self, activeSession):
output = x.split(' ', 3)[-1]
return output

def pull_device_poe_status(self, activeSession): # TODO - WRITE TEST FOR
"""Retrieve PoE status for all interfaces."""
# Return empty result - unsupported on NX-OS
return {}

def pull_host_interfaces(self, activeSession):
"""Retrieve list of interfaces on device."""
outputResult = ''
Expand Down
5 changes: 5 additions & 0 deletions app/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flask import Blueprint

bp = Blueprint('errors', __name__)

from app.errors import handlers
14 changes: 14 additions & 0 deletions app/errors/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from flask import render_template
from app.errors import bp


@bp.errorhandler(404)
def not_found_error(error):
"""Return 404 page on 404 error."""
return render_template('errors/404.html', error=error), 404


@bp.errorhandler(500)
def handle_500(error):
"""Return 500 page on 500 error."""
return render_template('errors/500.html', error=error), 500
7 changes: 0 additions & 7 deletions app/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,6 @@
from wtforms.validators import DataRequired, IPAddress


class LoginForm(FlaskForm):
"""User login form."""

user = StringField('Username', validators=[DataRequired()])
pw = PasswordField('Password', validators=[DataRequired()])


class LocalCredentialsForm(FlaskForm):
"""Local credentials form, on a per device basis."""

Expand Down
62 changes: 41 additions & 21 deletions app/scripts_bank/lib/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
try:
from urllib import urlopen # Python 2
except ImportError:
from urllib.parse import urlopen # Python 3
from urllib.request import urlopen # Python 3


class UserCredentials(object):
Expand Down Expand Up @@ -64,32 +64,52 @@ def isInteger(x):
def checkForVersionUpdate(config):
"""Check for NetConfig updates on GitHub."""
try:
# with urlopen(config['GH_MASTER_BRANCH_URL']) as a:
# masterConfig = a.read().decode('utf-8')
masterConfig = urlopen(config['GH_MASTER_BRANCH_URL'])
except IOError:
masterConfig = masterConfig.read().decode('utf-8')
# Reverse lookup as the VERSION variable should be close to the bottom
if masterConfig:
for x in masterConfig.splitlines():
if 'VERSION' in x:
x = x.split('=')
try:
# Strip whitespace and single quote mark
masterVersion = x[-1].strip().strip("'")
except IndexError:
continue
# Verify if current version matches most recent GitHub release
if masterVersion != config['VERSION']:
# Return False if the current version does not match the most recent version
return jsonify(status="False", masterVersion=masterVersion)
# If VERSION can't be found, successfully compared, or is identical, just return True
return jsonify(status="True")
else:
# Error when accessing URL. Default to True
return "True"
except IOError as e:
# Catch exception if unable to access URL, or access to internet is blocked/down. Default to True
return "True"
# Reverse lookup as the VERSION variable should be close to the bottom
for x in masterConfig:
if 'VERSION' in x:
x = x.split('=')
try:
# Strip whitespace and single quote mark
masterVersion = x[-1].strip().strip("'")
except IndexError:
continue
# Verify if current version matches most recent GitHub release
if masterVersion != config['VERSION']:
# Return False if the current version does not match the most recent version
return jsonify(status="False", masterVersion=masterVersion)
# If VERSION can't be found, successfully compared, or is identical, just return True
return jsonify(status="True")
except Exception as e:
# Return True for all other exceptions
return "True"


# Get current timestamp for when starting a script
def getCurrentTime():
currentTime = datetime.now()
return currentTime
"""Get current timestamp."""
currentTime = datetime.now()
return currentTime


# Returns time elapsed between current time and provided time in 'startTime'
def getScriptRunTime(startTime):
endTime = getCurrentTime() - startTime
return endTime
"""Calculate time elapsed since startTime was first measured."""
endTime = getCurrentTime() - startTime
return endTime


def interfaceReplaceSlash(x):
"""Replace all forward slashes in string 'x' with an underscore."""
x = x.replace('_', '/')
return x
32 changes: 15 additions & 17 deletions app/scripts_bank/redis_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ def generateSessionUUID():

def deleteUserInRedis():
"""Delete logged in user in Redis."""
saved_id = str(g.db.hget('users', session['USER']))
g.db.delete(str(saved_id))
saved_id = g.db.hget('users', session['USER'])
g.db.delete(saved_id)

# Delete any locally saved credentials tied to user
pattern = '*--' + str(session['USER'])
for key in g.db.hscan_iter('localusers', match=pattern):
# key[1] is the value we need to delete
g.db.delete(str(key[1]))
g.db.delete(str(saved_id))
g.db.delete(key[1])
g.db.delete(saved_id)


def resetUserRedisExpireTimer():
Expand All @@ -28,7 +28,7 @@ def resetUserRedisExpireTimer():
x is Redis key to reset timer on.
"""
try:
saved_id = str(g.db.hget('users', session['USER']))
saved_id = g.db.hget('users', session['USER'])
g.db.expire(saved_id, app.config['REDISKEYTIMEOUT'])
except:
pass
Expand All @@ -45,11 +45,11 @@ def storeUserInRedis(user, pw, privpw='', host=''):
# If user id doesn't exist, create new one with next available UUID
# Else reuse existing key,
# to prevent incrementing id each time the same user logs in
if str(g.db.hget('users', user)) == 'None':
# Create new user id, incrementing by 10
user_id = str(g.db.incrby('next_user_id', 10))
if g.db.hget('users', user):
user_id = g.db.hget('users', user)
else:
user_id = str(g.db.hget('users', user))
# Create new user id, incrementing by 10
user_id = g.db.incrby('next_user_id', 10)
g.db.hmset(user_id, dict(user=user, pw=pw))
g.db.hset('users', user, user_id)
# Set user info timer to auto expire and clear data
Expand All @@ -66,18 +66,16 @@ def storeUserInRedis(user, pw, privpw='', host=''):
# Key to save variable is host id, --, and username of logged in
# user
key = str(host.id) + "--" + str(session['USER'])
if str(g.db.hget('localusers', key)) == 'None':
# Create new host id, incrementing by 10
saved_id = str(g.db.incrby('next_user_id', 10))
if g.db.hget('localusers', key):
saved_id = g.db.hget('localusers', key)
else:
saved_id = str(g.db.hget('localusers', key))
# Create new host id, incrementing by 10
saved_id = g.db.incrby('next_user_id', 10)

if privpw:
g.db.hmset(saved_id, dict(user=user, localuser=session[
'USER'], pw=pw, privpw=privpw))
g.db.hmset(saved_id, dict(user=user, localuser=session['USER'], pw=pw, privpw=privpw))
else:
g.db.hmset(saved_id, dict(
user=user, localuser=session['USER'], pw=pw))
g.db.hmset(saved_id, dict(user=user, localuser=session['USER'], pw=pw))
g.db.hset('localusers', key, saved_id)
# Set user info timer to auto expire and clear data
g.db.expire(saved_id, app.config['REDISKEYTIMEOUT'])
Expand Down
Loading

0 comments on commit 7d3df1d

Please sign in to comment.