Skip to content

Commit

Permalink
unquote_and_split refactor (#67)
Browse files Browse the repository at this point in the history
* Move unquote_and_split from build_tar.py to SplitNameValuePairAtSeparator

* Add tests for SplitNameValuePairAtSeparator

* Rename argument from c to sep for readability
  • Loading branch information
j3parker authored and aiuto committed Jul 31, 2019
1 parent 6303028 commit 5223103
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 48 deletions.
48 changes: 6 additions & 42 deletions pkg/build_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from rules_pkg import archive
from absl import flags

from helpers import GetFlagValue
from helpers import GetFlagValue, SplitNameValuePairAtSeparator

flags.DEFINE_string('output', None, 'The output file, mandatory')
flags.mark_flag_as_required('output')
Expand Down Expand Up @@ -262,42 +262,6 @@ def add_deb(self, deb):
self.add_tar(tmpfile[1])
os.remove(tmpfile[1])


def unquote_and_split(arg, c):
"""Split a string at the first unquoted occurrence of a character.
Split the string arg at the first unquoted occurrence of the character c.
Here, in the first part of arg, the backslash is considered the
quoting character indicating that the next character is to be
added literally to the first part, even if it is the split character.
Args:
arg: the string to be split
c: the character at which to split
Returns:
The unquoted string before the separator and the string after the
separator.
"""
head = ''
i = 0
while i < len(arg):
if arg[i] == c:
return (head, arg[i + 1:])
elif arg[i] == '\\':
i += 1
if i == len(arg):
# dangling quotation symbol
return (head, '')
else:
head += arg[i]
else:
head += arg[i]
i += 1
# if we leave the loop, the character c was not found unquoted
return (head, '')


def main(unused_argv):
# Parse modes arguments
default_mode = None
Expand All @@ -308,7 +272,7 @@ def main(unused_argv):
mode_map = {}
if FLAGS.modes:
for filemode in FLAGS.modes:
(f, mode) = unquote_and_split(filemode, '=')
(f, mode) = SplitNameValuePairAtSeparator(filemode, '=')
if f[0] == '/':
f = f[1:]
mode_map[f] = int(mode, 8)
Expand All @@ -319,7 +283,7 @@ def main(unused_argv):
names_map = {}
if FLAGS.owner_names:
for file_owner in FLAGS.owner_names:
(f, owner) = unquote_and_split(file_owner, '=')
(f, owner) = SplitNameValuePairAtSeparator(file_owner, '=')
(user, group) = owner.split('.', 1)
if f[0] == '/':
f = f[1:]
Expand All @@ -330,7 +294,7 @@ def main(unused_argv):
ids_map = {}
if FLAGS.owners:
for file_owner in FLAGS.owners:
(f, owner) = unquote_and_split(file_owner, '=')
(f, owner) = SplitNameValuePairAtSeparator(file_owner, '=')
(user, group) = owner.split('.', 1)
if f[0] == '/':
f = f[1:]
Expand Down Expand Up @@ -368,7 +332,7 @@ def file_attributes(filename):
output.add_deb(deb)

for f in FLAGS.file:
(inf, tof) = unquote_and_split(f, '=')
(inf, tof) = SplitNameValuePairAtSeparator(f, '=')
output.add_file(inf, tof, **file_attributes(tof))
for f in FLAGS.empty_file:
output.add_empty_file(f, **file_attributes(f))
Expand All @@ -381,7 +345,7 @@ def file_attributes(filename):
for deb in FLAGS.deb:
output.add_deb(deb)
for link in FLAGS.link:
l = unquote_and_split(link, ':')
l = SplitNameValuePairAtSeparator(link, ':')
output.add_link(l[0], l[1])


Expand Down
35 changes: 34 additions & 1 deletion pkg/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,39 @@
import os
import sys

def SplitNameValuePairAtSeparator(arg, sep):
"""Split a string at the first unquoted occurrence of a character.
Split the string arg at the first unquoted occurrence of the character c.
Here, in the first part of arg, the backslash is considered the
quoting character indicating that the next character is to be
added literally to the first part, even if it is the split character.
Args:
arg: the string to be split
sep: the character at which to split
Returns:
The unquoted string before the separator and the string after the
separator.
"""
head = ''
i = 0
while i < len(arg):
if arg[i] == sep:
return (head, arg[i + 1:])
elif arg[i] == '\\':
i += 1
if i == len(arg):
# dangling quotation symbol
return (head, '')
else:
head += arg[i]
else:
head += arg[i]
i += 1
# if we leave the loop, the character sep was not found unquoted
return (head, '')

def GetFlagValue(flagvalue, strip=True):
"""Converts a raw flag string to a useable value.
Expand Down Expand Up @@ -53,4 +86,4 @@ def GetFlagValue(flagvalue, strip=True):

if strip:
return flagvalue.strip()
return flagvalue
return flagvalue
61 changes: 56 additions & 5 deletions pkg/tests/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,80 @@
import helpers


class HelpersTestCase(unittest.TestCase):
def test_getFlagValue_nonStripped(self):
class GetFlagValueTestCase(unittest.TestCase):
def testNonStripped(self):
self.assertEqual(helpers.GetFlagValue('value ', strip=False), 'value ')

def test_getFlagValue_Stripped(self):
def testStripped(self):
self.assertEqual(helpers.GetFlagValue('value ', strip=True), 'value')

def test_getFlagValue_nonStripped_fromFile(self):
def testNonStripped_fromFile(self):
with tempfile.NamedTemporaryFile() as f:
f.write('value ')
f.flush()
self.assertEqual(helpers.GetFlagValue(
'@{}'.format(f.name),
strip=False), 'value ')

def test_getFlagValue_Stripped_fromFile(self):
def testStripped_fromFile(self):
with tempfile.NamedTemporaryFile() as f:
f.write('value ')
f.flush()
self.assertEqual(helpers.GetFlagValue(
'@{}'.format(f.name),
strip=True), 'value')

class SplitNameValuePairAtSeparatorTestCase(unittest.TestCase):
def testNoSep(self):
key, val = helpers.SplitNameValuePairAtSeparator('abc', '=')

self.assertEqual(key, 'abc')
self.assertEqual(val, '')

def testNoSepWithEscape(self):
key, val = helpers.SplitNameValuePairAtSeparator('a\\=bc', '=')

self.assertEqual(key, 'a=bc')
self.assertEqual(val, '')

def testNoSepWithDanglingEscape(self):
key, val = helpers.SplitNameValuePairAtSeparator('abc\\', '=')

self.assertEqual(key, 'abc')
self.assertEqual(val, '')

def testHappyCase(self):
key, val = helpers.SplitNameValuePairAtSeparator('abc=xyz', '=')

self.assertEqual(key, 'abc')
self.assertEqual(val, 'xyz')

def testHappyCaseWithEscapes(self):
key, val = helpers.SplitNameValuePairAtSeparator('a\\=\\=b\\=c=xyz', '=')

self.assertEqual(key, 'a==b=c')
self.assertEqual(val, 'xyz')

def testStopsAtFirstSep(self):
key, val = helpers.SplitNameValuePairAtSeparator('a=b=c', '=')

self.assertEqual(key, 'a')
self.assertEqual(val, 'b=c')

def testDoesntUnescapeVal(self):
key, val = helpers.SplitNameValuePairAtSeparator('abc=x\\=yz\\', '=')

self.assertEqual(key, 'abc')

# the val doesn't get unescaped at all
self.assertEqual(val, 'x\\=yz\\')

def testUnescapesNonsepCharsToo(self):
key, val = helpers.SplitNameValuePairAtSeparator('na\\xffme=value', '=')

# this behaviour is surprising
self.assertEqual(key, 'naxffme')
self.assertEqual(val, 'value')

if __name__ == "__main__":
unittest.main()

0 comments on commit 5223103

Please sign in to comment.