Skip to content

Commit

Permalink
use stringPict.root to draw n-roots
Browse files Browse the repository at this point in the history
adjust comments on root
  • Loading branch information
mmatera committed Nov 13, 2024
1 parent ccd8fd9 commit 4dec809
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 69 deletions.
38 changes: 8 additions & 30 deletions sympy/printing/pretty/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
pprint_try_use_unicode = pretty_try_use_unicode


# TODO: use high-level methods from stringPict
# instead of stack/next methods to simplify
# this code.

class PrettyPrinter(Printer):
"""Printer, which converts an expression into 2D ASCII-art figure."""
printmethod = "_pretty"
Expand Down Expand Up @@ -2044,39 +2048,13 @@ def _print_nth_root(self, base, root):
or (base.is_Integer and base.is_nonnegative))):
return bpretty.left(nth_root[2])

# Construct root sign, start with the \/ shape
_zZ = xobj('/', 1)
rootsign = xobj('\\', 1) + _zZ
# Constructing the number to put on root
rpretty = self._print(root)
# roots look bad if they are not a single line
if rpretty.height() != 1:
return self._print(base)**self._print(1/root)
# If power is half, no number should appear on top of root sign
exp = '' if root == 2 else str(rpretty).ljust(2)
if len(exp) > 2:
rootsign = ' '*(len(exp) - 2) + rootsign
# Stack the exponent
rootsign = stringPict(exp + '\n' + rootsign)
rootsign.baseline = 0
# Diagonal: length is one less than height of base
linelength = bpretty.height() - 1
diagonal = stringPict('\n'.join(
' '*(linelength - i - 1) + _zZ + ' '*i
for i in range(linelength)
))
# Put baseline just below lowest line: next to exp
diagonal.baseline = linelength - 1
# Make the root symbol
rootsign = rootsign.right(diagonal)
# Det the baseline to match contents to fix the height
# but if the height of bpretty is one, the rootsign must be one higher
rootsign.baseline = max(1, bpretty.baseline)
#build result
s = prettyForm(hobj('_', 2 + bpretty.width()))
s = bpretty.above(s)
s = s.left(rootsign)
return s

if root == 2:
return bpretty.root()
return bpretty.root(rpretty)

def _print_Pow(self, power):
from sympy.simplify.simplify import fraction
Expand Down
83 changes: 44 additions & 39 deletions sympy/printing/pretty/stringpict.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,34 +337,34 @@ def root(self, n=None):
>>> from sympy.printing.pretty.stringpict import stringPict, prettyForm
>>> print(stringPict("x+3").root().right(" + a"))
___
\\/x+3 + a
_____
\\/ x+3 + a
>>> print(stringPict("x+3").root(stringPict("3")).right(" + a"))
3 ___
\\/x+3 + a
3 _____
\\/ x+3 + a
>>> print((prettyForm("x")**stringPict("a")).root().right(" + a"))
__
/ a
\\/ x + a
____
/ a
\\/ x + a
>>> print((prettyForm("x")**stringPict("a")).root(stringPict("3")).right(" + a"))
__
3 / a
\\/ x + a
____
3 / a
\\/ x + a
>>> print((prettyForm("x+3")/prettyForm("y")).root().right(" + a"))
___
/x+3
/ --- + a
_____
/ x+3
/ --- + a
\\/ y
>>> print((prettyForm("x+3")/prettyForm("y")).root(stringPict("3")).right(" + a"))
___
/x+3
3 / --- + a
\\/ y
_____
/ x+3
3 / --- + a
\\/ y
For indices with more than one line, use the Pow form:
Expand All @@ -377,39 +377,44 @@ def root(self, n=None):
|---| + a
\\ y /
"""
# TODO: use it in root drawing in PrettyPrinter.
#
# put line over expression
# Decide if using a square root symbol or
# an base - exponent form:
if n is not None:
if isinstance(n, str):
n = n.ljust(2)
n = stringPict(n)
elif n.width()<2:
n = stringPict(str(n).ljust(2))
if n.height() > 1:
exponent = n.parens().left(stringPict("1 / "))
exponent = n.parens().left(stringPict("1 / "), align="c")
return self ** exponent

result = self.above('_' * self.width())
# put line over expression
result = self.above(hobj('_', 2 + self.width()))
#construct right half of root symbol
height = self.height()

root_sign = prettyForm(xobj('\\', 1))
_zZ = xobj('/', 1)
root_sign = prettyForm(xobj('\\', 1)+ _zZ)
if n is not None:
root_sign = root_sign.above(n, align="r")
if height>1:
slash = '\n'.join(' ' * (height - i - 2) + _zZ + ' ' * i for i in range(height-1))
# TODO: To improve the use of the space, consider
# using a vertical line instead '/', like
# -
# |x
# 20|-
# \|2
#
# # remove the `.ljust.(2)` in `n` and
# # replace the previous line by
# _zZ = xobj('|', 1)
# slash = "\n".join(height*[_zZ])
#
# but this requires to change many tests.
slash = stringPict(" ").above(slash, align='l')
root_sign = root_sign.right(slash, align="b")

_zZ = xobj('/', 1)
slash = '\n'.join(' ' * (height - i - 1) + _zZ + ' ' * i for i in range(height))
# TODO: To improve the use of the space, consider
# using a vertical line instead '/', like
# -
# |x
# 20|-
# \|2
#
# _zZ = xobj('|', 1)
# slash = "\n".join(height*[_zZ])
#
# but this requires to change many tests.
slash = stringPict(slash, self.baseline)
root_sign = root_sign.right(slash, align="b")
return result.left(root_sign, align="b")

def render(self, *args, **kwargs):
Expand Down

0 comments on commit 4dec809

Please sign in to comment.