roll.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878
  1. #!/usr/bin/env python
  2. import attr
  3. import logging
  4. import re
  5. import sys
  6. import operator
  7. import traceback
  8. import functools
  9. from numbers import Number
  10. from random import SystemRandom
  11. from copy import copy
  12. from arpeggio import ParserPython, RegExMatch, Optional, ZeroOrMore, OneOrMore, OrderedChoice, Sequence, Combine, Not, EOF, PTNodeVisitor, visit_parse_tree, ParseTreeNode, SemanticActionResults
  13. from typing import Union, List, Any, Tuple, Dict, Callable, Set, TextIO
  14. from typing import Optional as OptionalType
  15. try:
  16. import colorama
  17. colorama.init()
  18. from colors import color
  19. except ImportError:
  20. # Fall back to no color
  21. def color(s: str, *args, **kwargs):
  22. '''Fake color function that does nothing.
  23. Used when the colors module cannot be imported.'''
  24. return s
  25. EXPR_COLOR = "green"
  26. RESULT_COLOR = "red"
  27. DETAIL_COLOR = "yellow"
  28. logFormatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S')
  29. logger = logging.getLogger(__name__)
  30. logger.setLevel(logging.INFO)
  31. logger.handlers = []
  32. logger.addHandler(logging.StreamHandler())
  33. for handler in logger.handlers:
  34. handler.setFormatter(logFormatter)
  35. try:
  36. # If imported, input() automatically uses it
  37. import readline
  38. except ImportError:
  39. logger.warning("Could not import readline: Advanced line editing unavailable")
  40. sysrand = SystemRandom()
  41. randint = sysrand.randint
  42. # Implementing the syntax described here: https://www.critdice.com/roll-advanced-dice
  43. # https://stackoverflow.com/a/23956778/125921
  44. # Whitespace parsing
  45. def Whitespace(): return RegExMatch(r'\s+')
  46. def OpWS(): return Optional(Whitespace)
  47. # Number parsing
  48. def Digits(): return RegExMatch('[0-9]+')
  49. def NonzeroDigits():
  50. '''Digits with at least one nonzero number.'''
  51. return RegExMatch('0*[1-9][0-9]*')
  52. def Sign(): return ['+', '-']
  53. def Integer(): return Optional(Sign), Digits
  54. def PositiveInteger(): return Optional('+'), Digits
  55. def FloatingPoint():
  56. return (
  57. Optional(Sign),
  58. [
  59. # e.g. '1.', '1.0'
  60. (Digits, '.', Optional(Digits)),
  61. # e.g. '.1'
  62. ('.', Digits),
  63. ]
  64. )
  65. def Scientific():
  66. return ([FloatingPoint, Integer], RegExMatch('[eE]'), Integer)
  67. def Number(): return Combine([Scientific, FloatingPoint, Integer])
  68. def ReservedWord():
  69. '''Matches identifiers that aren't allowed as variable names.'''
  70. command_word_parsers = []
  71. for cmd_type in Command():
  72. cmd_parser = cmd_type()
  73. if isinstance(cmd_parser, tuple):
  74. command_word_parsers.append(cmd_parser[0])
  75. else:
  76. command_word_parsers.append(cmd_parser)
  77. return([
  78. # Starts with a roll expression
  79. RollExpr,
  80. # Matches a command word exactly
  81. (command_word_parsers, [RegExMatch('[^A-Za-z0-9_]'), EOF]),
  82. ])
  83. # Valid variable name parser (should disallow names like 'help', 'quit', or 'd4r')
  84. def Identifier(): return (
  85. Not(ReservedWord),
  86. RegExMatch(r'[A-Za-z_][A-Za-z0-9_]*')
  87. )
  88. def MyNum(): return (
  89. Not('0'),
  90. RegExMatch('[0-9]+'),
  91. )
  92. # Roll expression parsing
  93. def PercentileFace(): return '%'
  94. def DieFace(): return [NonzeroDigits, PercentileFace, RegExMatch(r'F(\.[12])?')]
  95. def BasicRollExpr():
  96. return (
  97. Optional(NonzeroDigits),
  98. RegExMatch('[dD]'),
  99. DieFace,
  100. )
  101. def DropSpec(): return 'kh kl K k X x -H -L'.split(' '), Optional(NonzeroDigits)
  102. def CompareOp(): return '<= < >= > ≤ ≥ ='.split(' ')
  103. def Comparison(): return CompareOp, Integer
  104. def RerollType(): return Combine(['r', 'R', ('!', Optional('!'), Optional('p'))])
  105. def RerollSpec():
  106. return (
  107. RerollType,
  108. Optional(
  109. Optional(CompareOp),
  110. Integer,
  111. ),
  112. )
  113. def CountSpec():
  114. return (
  115. Comparison,
  116. Optional('f', Comparison),
  117. )
  118. def RollExpr():
  119. return (
  120. BasicRollExpr,
  121. Optional([DropSpec, RerollSpec]),
  122. Optional(CountSpec),
  123. )
  124. # Arithmetic expression parsing
  125. def PrimaryTerm(): return [RollExpr, Number, Identifier]
  126. def TermOrGroup(): return [PrimaryTerm, ParenExpr]
  127. def Exponent(): return ['**', '^'], OpWS, TermOrGroup
  128. def ExponentExpr(): return TermOrGroup, ZeroOrMore(OpWS, Exponent)
  129. def Mul(): return ['*', '×'], OpWS, ExponentExpr
  130. def Div(): return ['/', '÷'], OpWS, ExponentExpr
  131. def ProductExpr(): return ExponentExpr, ZeroOrMore(OpWS, [Mul, Div])
  132. def Add(): return '+', OpWS, ProductExpr
  133. def Sub(): return '-', OpWS, ProductExpr
  134. def SumExpr(): return ProductExpr, ZeroOrMore(OpWS, [Add, Sub])
  135. def ParenExpr(): return Optional(Sign), '(', OpWS, SumExpr, OpWS, ')'
  136. def Expression():
  137. # Wrapped in a tuple to force a separate entry in the parse tree
  138. return (SumExpr,)
  139. # For parsing vars/expressions without evaluating them. The Combine()
  140. # hides the child nodes from a visitor.
  141. def UnevaluatedExpression(): return Combine(Expression)
  142. def UnevaluatedVariable(): return Combine(Identifier)
  143. # Variable assignment
  144. def VarAssignment(): return (
  145. UnevaluatedVariable,
  146. OpWS, '=', OpWS,
  147. UnevaluatedExpression
  148. )
  149. # Commands
  150. def DeleteCommand(): return (
  151. Combine(['delete', 'del', 'd']),
  152. Whitespace,
  153. UnevaluatedVariable,
  154. )
  155. def HelpCommand(): return Combine(['help', 'h', '?'], Not(Identifier))
  156. def QuitCommand(): return Combine(['quit', 'exit', 'q'], Not(Identifier))
  157. def ListVarsCommand(): return Combine(['variables', 'vars', 'v'], Not(Identifier))
  158. def Command(): return [ ListVarsCommand, DeleteCommand, HelpCommand, QuitCommand, ]
  159. def InputParser(): return Optional([Command, VarAssignment, Expression, Whitespace])
  160. # Allow whitespace padding at start or end of inputs
  161. def WSPaddedExpression(): return OpWS, Expression, OpWS
  162. def WSPaddedInputParser(): return OpWS, InputParser, OpWS
  163. def FullParserPython(language_def, *args, **kwargs):
  164. '''Like ParserPython, but auto-adds EOF to the end of the parser.'''
  165. def TempFullParser(): return (language_def, EOF)
  166. return ParserPython(TempFullParser, *args, **kwargs)
  167. expr_parser = FullParserPython(WSPaddedExpression, skipws = False, memoization = True)
  168. input_parser = FullParserPython(WSPaddedInputParser, skipws = False, memoization = True)
  169. def test_parse(rule, text):
  170. if isinstance(text, str):
  171. return FullParserPython(rule, skipws=False, memoization = True).parse(text)
  172. else:
  173. return [ test_parse(rule, x) for x in text ]
  174. def eval_infix(terms: List[float],
  175. operators: List[Callable[[float,float],float]],
  176. associativity: str = 'l') -> float:
  177. '''Evaluate an infix expression with N terms and N-1 operators.'''
  178. assert associativity in ['l', 'r']
  179. assert len(terms) == len(operators) + 1, 'Need one more term than operator'
  180. if len(terms) == 1:
  181. return terms[0]
  182. elif associativity == 'l':
  183. value = terms[0]
  184. for op, term in zip(operators, terms[1:]):
  185. value = op(value, term)
  186. return value
  187. elif associativity == 'r':
  188. value = terms[-1]
  189. for op, term in zip(reversed(operators), reversed(terms[:-1])):
  190. value = op(term, value)
  191. return value
  192. else:
  193. raise ValueError(f'Invalid associativity: {associativity!r}')
  194. class UndefinedVariableError(KeyError):
  195. pass
  196. def print_vars(env: Dict[str,str]) -> None:
  197. if len(env):
  198. print('Currently defined variables:')
  199. for k in sorted(env.keys()):
  200. print('{var} = {value}'.format(
  201. var = color(k, RESULT_COLOR),
  202. value = color(repr(env[k]), EXPR_COLOR)))
  203. else:
  204. print('No variables are currently defined.')
  205. def print_interactive_help() -> None:
  206. print('\n' + '''
  207. To make a roll, type in the roll in dice notation, e.g. '4d4 + 4'.
  208. Nearly all dice notation forms listed in the following references should be supported:
  209. - http://rpg.greenimp.co.uk/dice-roller/
  210. - https://www.critdice.com/roll-advanced-dice
  211. Expressions can include numeric constants, addition, subtraction,
  212. multiplication, division, and exponentiation.
  213. To assign a variable, use 'VAR = VALUE'. For example 'health_potion =
  214. 2d4+2'. Subsequent roll expressions (and other variables) can refer to
  215. this variable, whose value will be substituted in to the expression.
  216. If a variable's value includes any dice rolls, those dice will be
  217. rolled (and produce a different value) every time the variable is
  218. used.
  219. Special commands:
  220. - To show the values of all currently assigned variables, type 'vars'.
  221. - To delete a previously defined variable, type 'del VAR'.
  222. - To show this help text, type 'help'.
  223. - To quit, type 'quit'.
  224. '''.strip() + '\n', file=sys.stdout)
  225. DieFaceType = Union[int, str]
  226. def roll_die(face: DieFaceType = 6) -> int:
  227. '''Roll a single die.
  228. Supports any valid integer number of sides as well as 'F', 'F.1', and
  229. 'F.2' for a Face die, which can return -1, 0, or 1.
  230. '''
  231. if face in ('F', 'F.2'):
  232. # Fate die = 1d3-2
  233. return roll_die(3) - 2
  234. elif face == 'F.1':
  235. d6 = roll_die(6)
  236. if d6 == 1:
  237. return -1
  238. elif d6 == 6:
  239. return 1
  240. else:
  241. return 0
  242. else:
  243. face = int(face)
  244. if face < 2:
  245. raise ValueError(f"Can't roll a {face}-sided die")
  246. return randint(1, face)
  247. def roll_die_with_rerolls(face: int, reroll_condition: Callable, reroll_limit = None) -> List[int]:
  248. '''Roll a single die, and maybe reroll it several times.
  249. After each roll, 'reroll_condition' is called on the result, and
  250. if it returns True, the die is rolled again. All rolls are
  251. collected in a list, and the list is returned as soon as the
  252. condition returns False.
  253. If reroll_limit is provided, it limits the maximum number of
  254. rerolls. Note that the total number of rolls can be one more than
  255. the reroll limit, since the first roll is not considered a reroll.
  256. '''
  257. all_rolls = [roll_die(face)]
  258. while reroll_condition(all_rolls[-1]):
  259. if reroll_limit is not None and len(all_rolls) > reroll_limit:
  260. break
  261. all_rolls.append(roll_die(face))
  262. return all_rolls
  263. class DieRolled(int):
  264. '''Subclass of int that allows an alternate string representation.
  265. This is meant for recording the result of rolling a die. The
  266. formatter argument should include '{}' anywhere that the integer
  267. value should be substituted into the string representation.
  268. (However, it can also override the string representation entirely
  269. by not including '{}'.) The string representation has no effect on
  270. the numeric value. It can be used to indicate a die roll that has
  271. been re-rolled or exploded, or to indicate a critical hit/miss.
  272. '''
  273. formatter: str
  274. def __new__(cls: type, value: int, formatter: str = '{}') -> 'DieRolled':
  275. # https://github.com/python/typeshed/issues/2686
  276. newval = super(DieRolled, cls).__new__(cls, value) # type: ignore
  277. newval.formatter = formatter
  278. return newval
  279. def __str__(self) -> str:
  280. return self.formatter.format(super().__str__())
  281. def __repr__(self) -> str:
  282. if self.formatter != '{}':
  283. return f'DieRolled(value={int(self)!r}, formatter={self.formatter!r})'
  284. else:
  285. return f'DieRolled({int(self)!r})'
  286. def format_dice_roll_list(rolls: List[int], always_list: bool = False) -> str:
  287. if len(rolls) == 0:
  288. raise ValueError('Need at least one die rolled')
  289. elif len(rolls) == 1 and not always_list:
  290. return color(str(rolls[0]), DETAIL_COLOR)
  291. else:
  292. return '[' + color(" ".join(map(str, rolls)), DETAIL_COLOR) + ']'
  293. def int_or_none(x: OptionalType[Any]) -> OptionalType[int]:
  294. if x is None:
  295. return None
  296. else:
  297. return int(x)
  298. def str_or_none(x: OptionalType[Any]) -> OptionalType[str]:
  299. if x is None:
  300. return None
  301. else:
  302. return str(x)
  303. @attr.s
  304. class DiceRolled(object):
  305. '''Class representing the result of rolling one or more similar dice.'''
  306. dice_results: List[int] = attr.ib()
  307. @dice_results.validator
  308. def validate_dice_results(self, attribute, value):
  309. if len(value) == 0:
  310. raise ValueError('Need at least one non-dropped roll')
  311. dropped_results: OptionalType[List[int]] = attr.ib(default = None)
  312. roll_text: OptionalType[str] = attr.ib(
  313. default = None, converter = str_or_none)
  314. success_count: OptionalType[int] = attr.ib(
  315. default = None, converter = int_or_none)
  316. def total(self) -> int:
  317. if self.success_count is not None:
  318. return int(self.success_count)
  319. else:
  320. return sum(self.dice_results)
  321. def __str__(self) -> str:
  322. results = format_dice_roll_list(self.dice_results)
  323. if self.roll_text:
  324. prefix = '{text} rolled'.format(text=color(self.roll_text, EXPR_COLOR))
  325. else:
  326. prefix = 'Rolled'
  327. if self.dropped_results:
  328. drop = ' (dropped {dropped})'.format(dropped = format_dice_roll_list(self.dropped_results))
  329. else:
  330. drop = ''
  331. if self.success_count is not None:
  332. tot = ', Total successes: ' + color(str(self.total()), DETAIL_COLOR)
  333. elif len(self.dice_results) > 1:
  334. tot = ', Total: ' + color(str(self.total()), DETAIL_COLOR)
  335. else:
  336. tot = ''
  337. return f'{prefix}: {results}{drop}{tot}'
  338. def __int__(self) -> int:
  339. return self.total()
  340. def __float__(self) -> float:
  341. return float(self.total())
  342. cmp_dict = {
  343. '<=': operator.le,
  344. '<': operator.lt,
  345. '>=': operator.ge,
  346. '>': operator.gt,
  347. '≤': operator.le,
  348. '≥': operator.ge,
  349. '=': operator.eq,
  350. }
  351. @attr.s
  352. class Comparator(object):
  353. cmp_dict = {
  354. '<=': operator.le,
  355. '<': operator.lt,
  356. '>=': operator.ge,
  357. '>': operator.gt,
  358. '≤': operator.le,
  359. '≥': operator.ge,
  360. '=': operator.eq,
  361. }
  362. operator: str = attr.ib(converter = str)
  363. @operator.validator
  364. def validate_operator(self, attribute, value):
  365. if not value in self.cmp_dict:
  366. raise ValueError(f'Unknown comparison operator: {value!r}')
  367. value: int = attr.ib(converter = int)
  368. def __str__(self) -> str:
  369. return '{op}{val}'.format(op=self.operator, val=self.value)
  370. def compare(self, x: float) -> bool:
  371. '''Return True if x satisfies the comparator.
  372. In other words, x is placed on the left-hand side of the
  373. comparison, the Comparator's value is placed on the right hand
  374. side, and the truth value of the resulting test is returned.
  375. '''
  376. return self.cmp_dict[self.operator](x, self.value)
  377. def __call__(self, x: float) -> bool:
  378. '''Calls Comparator.compare() on x.
  379. This allows the Comparator to be used as a callable.'''
  380. return self.compare(x)
  381. def validate_comparators(face: DieFaceType, *comparators: Comparator):
  382. '''Validate one or more comparators for a die face type.
  383. This will test every possible die value, making sure that each
  384. test succeeds on at least one value and fails on at least one
  385. value, and it will make sure that at most one test succeeds on any
  386. given value.
  387. '''
  388. if isinstance(face, str):
  389. assert face.startswith('F')
  390. values = range(-1, 2)
  391. else:
  392. values = range(1, face+1)
  393. # Validate individual comparators
  394. for comp in comparators:
  395. results = [ comp(val) for val in values ]
  396. if all(results):
  397. raise ValueError(f"Test {str(comp)!r} can never fail for d{face}")
  398. if not any(results):
  399. raise ValueError(f"Test {str(comp)!r} can never succeed for d{face}")
  400. # Check for comparator overlap
  401. for val in values:
  402. passing_comps = [ comp for comp in comparators if comp(val) ]
  403. if len(passing_comps) > 1:
  404. raise ValueError(f"Can't have overlapping tests: {str(passing_comps[0])!r} and {str(passing_comps[1])!r}")
  405. def roll_dice(roll_desc: Dict) -> DiceRolled:
  406. '''Roll dice based on a roll description.
  407. See InputHandler.visit_RollExpr(), which generates roll
  408. descriptions. This function assumes the roll description is
  409. already validated.
  410. Returns a tuple of two lists. The first list is the kept rolls,
  411. and the second list is the dropped rolls.
  412. '''
  413. die_face: DieFaceType = roll_desc['die_face']
  414. dice_count: int = roll_desc['dice_count']
  415. kept_rolls: List[int] = []
  416. dropped_rolls: OptionalType[List[int]] = None
  417. success_count: Optional[int] = None
  418. if 'reroll_type' in roll_desc:
  419. die_face = int(die_face) # No fate dice
  420. reroll_type: str = roll_desc['reroll_type']
  421. reroll_limit = 1 if reroll_type == 'r' else None
  422. reroll_desc: Dict = roll_desc['reroll_desc']
  423. reroll_comparator = Comparator(
  424. operator = reroll_desc['comparator'],
  425. value = reroll_desc['target'],
  426. )
  427. validate_comparators(die_face, reroll_comparator)
  428. for i in range(dice_count):
  429. current_rolls = roll_die_with_rerolls(die_face, reroll_comparator, reroll_limit)
  430. if len(current_rolls) == 1:
  431. # If no rerolls happened, then just add the single
  432. # roll as is.
  433. kept_rolls.append(current_rolls[0])
  434. elif reroll_type in ['r', 'R']:
  435. # Keep only the last roll, and mark it as rerolled
  436. kept_rolls.append(DieRolled(current_rolls[-1], '{}' + reroll_type))
  437. elif reroll_type in ['!', '!!', '!p', '!!p']:
  438. if reroll_type.endswith('p'):
  439. # For penetration, subtract 1 from all rolls
  440. # except the first
  441. for i in range(1, len(current_rolls)):
  442. current_rolls[i] -= 1
  443. if reroll_type.startswith('!!'):
  444. # For compounding, return the sum, marked as a
  445. # compounded roll.
  446. kept_rolls.append(DieRolled(sum(current_rolls),
  447. '{}' + reroll_type))
  448. else:
  449. # For exploding, add each individual roll to the
  450. # list. Mark each roll except the last as
  451. # rerolled.
  452. for i in range(0, len(current_rolls) - 1):
  453. current_rolls[i] = DieRolled(current_rolls[i], '{}' + reroll_type)
  454. kept_rolls.extend(current_rolls)
  455. else:
  456. raise ValueError(f'Unknown reroll type: {reroll_type}')
  457. else:
  458. # Roll the requested number of dice
  459. all_rolls = [ roll_die(die_face) for i in range(dice_count) ]
  460. if 'drop_type' in roll_desc:
  461. keep_count: int = roll_desc['keep_count']
  462. keep_high: bool = roll_desc['keep_high']
  463. # We just need to keep the highest/lowest N rolls. The
  464. # extra complexity here is just to preserve the original
  465. # order of those rolls.
  466. rolls_to_keep = sorted(all_rolls, reverse = keep_high)[:keep_count]
  467. kept_rolls = []
  468. for roll in rolls_to_keep:
  469. kept_rolls.append(all_rolls.pop(all_rolls.index(roll)))
  470. # Remaining rolls are dropped
  471. dropped_rolls = all_rolls
  472. else:
  473. kept_rolls = all_rolls
  474. if 'count_success' in roll_desc:
  475. die_face = int(die_face) # No fate dice
  476. success_desc = roll_desc['count_success']
  477. success_test = Comparator(
  478. operator = success_desc['comparator'],
  479. value = success_desc['target'],
  480. )
  481. # Sanity check: make sure the success test can be met
  482. if not any(map(success_test, range(1, die_face +1))):
  483. raise ValueError(f"Test {str(success_test)!r} can never succeed for d{die_face}")
  484. validate_comparators(die_face, success_test)
  485. success_count = sum(success_test(x) for x in kept_rolls)
  486. if 'count_failure' in roll_desc:
  487. failure_desc = roll_desc['count_failure']
  488. failure_test = Comparator(
  489. operator = failure_desc['comparator'],
  490. value = failure_desc['target'],
  491. )
  492. validate_comparators(die_face, success_test, failure_test)
  493. success_count -= sum(failure_test(x) for x in kept_rolls)
  494. else:
  495. # TODO: Label crits and critfails here
  496. pass
  497. return DiceRolled(
  498. dice_results = kept_rolls,
  499. dropped_results = dropped_rolls,
  500. success_count = success_count,
  501. roll_text = roll_desc['roll_text'],
  502. )
  503. class ExpressionStringifier(PTNodeVisitor):
  504. def __init__(self, **kwargs):
  505. self.env: Dict[str, str] = kwargs.pop('env', {})
  506. self.recursed_vars: Set[str] = kwargs.pop('recursed_vars', set())
  507. self.expr_parser = kwargs.pop('expr_parser', expr_parser)
  508. super().__init__(**kwargs)
  509. def visit__default__(self, node, children):
  510. if children:
  511. return ''.join(children)
  512. else:
  513. return node.value
  514. def visit_Identifier(self, node, children):
  515. '''Interpolate variable.'''
  516. var_name = node.value
  517. if var_name in self.recursed_vars:
  518. raise ValueError(f'Recursive variable definition detected for {var_name!r}')
  519. try:
  520. var_expression = self.env[var_name]
  521. except KeyError as ex:
  522. raise UndefinedVariableError(*ex.args)
  523. recursive_visitor = copy(self)
  524. recursive_visitor.recursed_vars = self.recursed_vars.union([var_name])
  525. return self.expr_parser.parse(var_expression).visit(recursive_visitor)
  526. class QuitRequested(BaseException):
  527. pass
  528. class InputHandler(PTNodeVisitor):
  529. def __init__(self, **kwargs):
  530. self.expr_stringifier = ExpressionStringifier(**kwargs)
  531. self.env: Dict[str, str] = kwargs.pop('env', {})
  532. self.recursed_vars: Set[str] = kwargs.pop('recursed_vars', set())
  533. self.expr_parser = kwargs.pop('expr_parser', expr_parser)
  534. self.print_results = kwargs.pop('print_results', True)
  535. self.print_rolls = kwargs.pop('print_rolls', True)
  536. super().__init__(**kwargs)
  537. def visit_Whitespace(self, node, children):
  538. '''Remove whitespace nodes'''
  539. return None
  540. def visit_Number(self, node, children):
  541. '''Return the numeric value.
  542. Uses int if possible, otherwise float.'''
  543. try:
  544. return int(node.value)
  545. except ValueError:
  546. return float(node.value)
  547. def visit_NonzeroDigits(self, node, children):
  548. return int(node.flat_str())
  549. def visit_Integer(self, node, children):
  550. return int(node.flat_str())
  551. def visit_PercentileFace(self, node, children):
  552. return 100
  553. def visit_BasicRollExpr(self, node, children):
  554. die_face = children[-1]
  555. if isinstance(die_face, int) and die_face < 2:
  556. raise ValueError(f"Invalid roll: Can't roll a {die_face}-sided die")
  557. return {
  558. 'dice_count': children[0] if len(children) == 3 else 1,
  559. 'die_face': die_face,
  560. }
  561. def visit_DropSpec(self, node, children):
  562. return {
  563. 'drop_type': children[0],
  564. 'drop_or_keep_count': children[1] if len(children) > 1 else 1,
  565. }
  566. def visit_RerollSpec(self, node, children):
  567. if len(children) == 1:
  568. return {
  569. 'reroll_type': children[0],
  570. # The default reroll condition depends on other parts
  571. # of the roll expression, so it will be "filled in"
  572. # later.
  573. }
  574. elif len(children) == 2:
  575. return {
  576. 'reroll_type': children[0],
  577. 'reroll_desc': {
  578. 'comparator': '=',
  579. 'target': children[1],
  580. },
  581. }
  582. elif len(children) == 3:
  583. return {
  584. 'reroll_type': children[0],
  585. 'reroll_desc': {
  586. 'comparator': children[1],
  587. 'target': children[2],
  588. },
  589. }
  590. else:
  591. raise ValueError("Invalid reroll specification")
  592. def visit_Comparison(self, node, children):
  593. return {
  594. 'comparator': children[0],
  595. 'target': children[1],
  596. }
  597. def visit_CountSpec(self, node, children):
  598. result = { 'count_success': children[0], }
  599. if len(children) > 1:
  600. result['count_failure'] = children[1]
  601. return result
  602. def visit_RollExpr(self, node, children):
  603. # Collect all child dicts into one
  604. roll_desc = {
  605. 'roll_text': node.flat_str(),
  606. }
  607. for child in children:
  608. roll_desc.update(child)
  609. logger.debug(f'Initial roll description: {roll_desc!r}')
  610. # Perform some validation that can only be done once the
  611. # entire roll description is collected.
  612. if not isinstance(roll_desc['die_face'], int):
  613. if 'reroll_type' in roll_desc:
  614. raise ValueError('Can only reroll/explode numeric dice, not Fate dice')
  615. if 'count_success' in roll_desc:
  616. raise ValueError('Can only count successes on numeric dice, not Fate dice')
  617. # Fill in implicit reroll type
  618. if 'reroll_type' in roll_desc and not 'reroll_desc' in roll_desc:
  619. rrtype = roll_desc['reroll_type']
  620. if rrtype in ['r', 'R']:
  621. roll_desc['reroll_desc'] = {
  622. 'comparator': '=',
  623. 'target': 1,
  624. }
  625. else:
  626. roll_desc['reroll_desc'] = {
  627. 'comparator': '=',
  628. 'target': roll_desc['die_face'],
  629. }
  630. # Validate drop spec and determine exactly how many dice to
  631. # drop/keep
  632. if 'drop_type' in roll_desc:
  633. dtype = roll_desc['drop_type']
  634. keeping = dtype in ['K', 'k', 'kl', 'kh']
  635. if keeping:
  636. roll_desc['keep_count'] = roll_desc['drop_or_keep_count']
  637. else:
  638. roll_desc['keep_count'] = roll_desc['dice_count'] - roll_desc['drop_or_keep_count']
  639. if roll_desc['keep_count'] < 1:
  640. drop_count = roll_desc['dice_count'] - roll_desc['keep_count']
  641. raise ValueError(f"Can't drop {drop_count} dice out of {roll_desc['dice_count']}")
  642. if roll_desc['keep_count'] >= roll_desc['dice_count']:
  643. raise ValueError(f"Can't keep {roll_desc['keep_count']} dice out of {roll_desc['dice_count']}")
  644. # Keeping high rolls is the same as dropping low rolls
  645. roll_desc['keep_high'] = dtype in ['K', 'kh', 'x', '-L']
  646. # Validate count spec
  647. elif 'count_failure' in roll_desc and not 'count_success' in roll_desc:
  648. # The parser shouldn't allow this, but just in case
  649. raise ValueError("Can't have a failure condition without a success condition")
  650. logger.debug(f'Final roll description: {roll_desc!r}')
  651. result = roll_dice(roll_desc)
  652. if self.print_rolls:
  653. print(str(result))
  654. return int(result)
  655. def visit_Identifier(self, node, children):
  656. '''Interpolate variable.'''
  657. var_name = node.value
  658. if var_name in self.recursed_vars:
  659. raise ValueError(f'Recursive variable definition detected for {var_name!r}')
  660. try:
  661. var_expression = self.env[var_name]
  662. except KeyError as ex:
  663. raise UndefinedVariableError(*ex.args)
  664. recursive_visitor = copy(self)
  665. recursive_visitor.recursed_vars = self.recursed_vars.union([var_name])
  666. # Don't print the results of evaluating variables
  667. recursive_visitor.print_results = False
  668. if self.debug:
  669. self.dprint(f'Evaluating variable {var_name} with expression {var_expression!r}')
  670. return self.expr_parser.parse(var_expression).visit(recursive_visitor)
  671. def visit_Expression(self, node, children):
  672. if self.print_results:
  673. expr_full_text = node.visit(self.expr_stringifier)
  674. print('Result: {result} (rolled {expr})'.format(
  675. expr=color(expr_full_text, EXPR_COLOR),
  676. result=color(f'{children[0]:g}', RESULT_COLOR),
  677. ))
  678. return children[0]
  679. # Each of these returns a tuple of (operator, value)
  680. def visit_Add(self, node, children):
  681. return (operator.add, children[-1])
  682. def visit_Sub(self, node, children):
  683. return (operator.sub, children[-1])
  684. def visit_Mul(self, node, children):
  685. return (operator.mul, children[-1])
  686. def visit_Div(self, node, children):
  687. return (operator.truediv, children[-1])
  688. def visit_Exponent(self, node, children):
  689. return (operator.pow, children[-1])
  690. # Each of these receives a first child that is a number and the
  691. # remaining children are tuples of (operator, number)
  692. def visit_SumExpr(self, node, children):
  693. values = [children[0]]
  694. ops = []
  695. for (op, val) in children[1:]:
  696. values.append(val)
  697. ops.append(op)
  698. if self.debug:
  699. self.dprint(f'Sum: values: {values!r}; ops: {ops!r}')
  700. return eval_infix(values, ops, 'l')
  701. def visit_ProductExpr(self, node, children):
  702. values = [children[0]]
  703. ops = []
  704. for (op, val) in children[1:]:
  705. values.append(val)
  706. ops.append(op)
  707. if self.debug:
  708. self.dprint(f'Product: values: {values!r}; ops: {ops!r}')
  709. return eval_infix(values, ops, 'l')
  710. def visit_ExponentExpr(self, node, children):
  711. values = [children[0]]
  712. ops = []
  713. for (op, val) in children[1:]:
  714. values.append(val)
  715. ops.append(op)
  716. if self.debug:
  717. self.dprint(f'Exponent: values: {values!r}; ops: {ops!r}')
  718. return eval_infix(values, ops, 'l')
  719. def visit_Sign(self, node, children):
  720. if node.value == '-':
  721. return -1
  722. else:
  723. return 1
  724. def visit_ParenExpr(self, node, children):
  725. assert len(children) > 0
  726. # Multiply the sign (if present) and the value inside the
  727. # parens
  728. return functools.reduce(operator.mul, children)
  729. def visit_VarAssignment(self, node, children):
  730. logger.debug(f'Doing variable assignment: {node.flat_str()}')
  731. var_name, var_value = children
  732. print('Saving "{var}" as "{expr}"'.format(
  733. var=color(var_name, RESULT_COLOR),
  734. expr=color(var_value, EXPR_COLOR),
  735. ))
  736. self.env[var_name] = var_value
  737. def visit_ListVarsCommand(self, node, children):
  738. print_vars(self.env)
  739. def visit_DeleteCommand(self, node, children):
  740. var_name = children[-1]
  741. print('Deleting saved value for "{var}".'.format(
  742. var=color(var_name, RESULT_COLOR)))
  743. try:
  744. self.env.pop(var_name)
  745. except KeyError as ex:
  746. raise UndefinedVariableError(*ex.args)
  747. def visit_HelpCommand(self, node, children):
  748. print_interactive_help()
  749. def visit_QuitCommand(self, node, children):
  750. raise QuitRequested()
  751. # def handle_input(expr: str, **kwargs) -> float:
  752. # return input_parser.parse(expr).visit(InputHandler(**kwargs))
  753. # handle_input('help')
  754. # handle_input('2+2 * 2 ** 2')
  755. # env = {}
  756. # handle_input('y = 2 + 2', env = env)
  757. # handle_input('x = y + 2', env = env)
  758. # handle_input('2 + x', env = env)
  759. # handle_input('del x', env = env)
  760. # handle_input('vars', env = env)
  761. # handle_input('2 + x', env = env)
  762. # handle_input('d4 = 5', env = env)
  763. def read_input(handle: TextIO = sys.stdin) -> str:
  764. if handle == sys.stdin:
  765. return input("Enter roll> ")
  766. else:
  767. return handle.readline()[:-1]
  768. if __name__ == '__main__':
  769. expr_string = " ".join(sys.argv[1:])
  770. if re.search("\\S", expr_string):
  771. try:
  772. expr_parser.parse(expr_string).visit(InputHandler())
  773. except Exception as exc:
  774. logger.error("Error while rolling: %s", repr(exc))
  775. raise exc
  776. sys.exit(1)
  777. else:
  778. env: Dict[str, str] = {}
  779. handler = InputHandler(env = env)
  780. while True:
  781. try:
  782. input_string = read_input()
  783. input_parser.parse(input_string).visit(handler)
  784. except KeyboardInterrupt:
  785. print('')
  786. except (EOFError, QuitRequested):
  787. print('')
  788. logger.info('Quitting.')
  789. break
  790. except Exception as exc:
  791. logger.error('Error while evaluating {expr!r}:\n{tb}'.format(
  792. expr=expr_string,
  793. tb=traceback.format_exc(),
  794. ))