Explorar el Código

Add variable support and better printing

Ryan C. Thompson hace 7 años
padre
commit
5c2e74226b
Se han modificado 1 ficheros con 211 adiciones y 63 borrados
  1. 211 63
      roll.py

+ 211 - 63
roll.py

@@ -8,7 +8,7 @@ import readline
 import operator
 from numbers import Number
 from random import randint
-from pyparsing import Regex, oneOf, Optional, Group, Combine, Literal, CaselessLiteral, ZeroOrMore, StringStart, StringEnd, opAssoc, infixNotation, ParseException, Empty, pyparsing_common, ParseResults
+from pyparsing import Regex, oneOf, Optional, Group, Combine, Literal, CaselessLiteral, ZeroOrMore, StringStart, StringEnd, opAssoc, infixNotation, ParseException, Empty, pyparsing_common, ParseResults, White, Suppress
 logFormatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -77,7 +77,7 @@ def ImplicitToken(x):
 
 # https://stackoverflow.com/a/46583691/125921
 
-var_name = pyparsing_common.identifier.copy()
+var_name = pyparsing_common.identifier.copy().setResultsName('varname')
 real_num = pyparsing_common.fnumber.copy()
 positive_int = pyparsing_common.integer.copy().setParseAction(lambda toks: [ IntegerValidator(min_val=1)(toks[0]) ])
 
@@ -90,7 +90,7 @@ pos_int_implicit_one = (positive_int | ImplicitToken(1))
 
 comparator_type = oneOf('<= < >= > ≤ ≥ =')
 
-reroll_type = Combine(oneOf('R r') | ( oneOf('! !!') + Optional('p')))
+reroll_type = Combine(oneOf('R r') ^ ( oneOf('! !!') + Optional('p')))
 reroll_spec = Group(
     reroll_type.setResultsName('type') +
     Optional(
@@ -117,19 +117,20 @@ roll_spec = Group(
     (positive_int | ImplicitToken(1)).setResultsName('dice_count') +
     CaselessLiteral('d') +
     (positive_int | oneOf('% F')).setResultsName('die_type') +
-    Optional(reroll_spec | drop_spec) +
+    Optional(reroll_spec ^ drop_spec) +
     Optional(count_spec)
 ).setResultsName('roll')
 
-expr_parser = (StringStart() + infixNotation(
+expr_parser = infixNotation(
     baseExpr=(roll_spec | positive_int | real_num | var_name),
     opList=[
-        (oneOf('**').setResultsName('operator', True), 2, opAssoc.RIGHT),
+        (oneOf('** ^').setResultsName('operator', True), 2, opAssoc.RIGHT),
         (oneOf('* / × ÷').setResultsName('operator', True), 2, opAssoc.LEFT),
         (oneOf('+ -').setResultsName('operator', True), 2, opAssoc.LEFT),
     ]
-) + StringEnd())
+).setResultsName('expr')
 
+assignment_parser = var_name + Literal('=').setResultsName('assignment') + expr_parser
 
 def roll_die(sides=6):
     '''Roll a single die.
@@ -173,8 +174,8 @@ class DieRolled(int):
 def validate_dice_roll_list(instance, attribute, value):
     for x in value:
         # Not using positive_int here because 0 is a valid roll for
-        # penetrating dice
-        pyparsing_common.integer.parseString(str(x))
+        # penetrating dice, and -1 and 0 are valid for fate dice
+        pyparsing_common.signed_integer.parseString(str(x))
 def format_dice_roll_list(rolls, always_list=False):
     if len(rolls) == 0:
         raise ValueError('Need at least one die rolled')
@@ -268,22 +269,16 @@ class RerollSpec(object):
     type = attr.ib(convert=str, validator=validate_by_parser(reroll_type))
     operator = attr.ib(default=None)
     value = attr.ib(default=None)
-    comparator = attr.ib(default=None, repr=False)
 
     def __attrs_post_init__(self):
-        if self.comparator is None:
-            self.comparator = Comparator(self.operator, self.value)
-        else:
-            if self.operator is not None or self.value is not None:
-                raise ValueError('Do not provide opeartor or value if providing a pre-build comparator')
-            self.operator = self.comparator.operator
-            self.value = self.comparator.value
+        if (self.operator is None) != (self.value is None):
+            raise ValueError('Operator and value must be provided together')
 
     def __str__(self):
-        return '{typ}{cmp}'.format(typ=self.type, cmp=self.comparator)
-
-    def compare(self, x):
-        return self.comparator.compare(x)
+        result = self.type
+        if self.operator is not None:
+            result += self.operator + self.value
+        return result
 
     def roll_die(self, sides):
         '''Roll a single die, following specified re-rolling rules.
@@ -293,23 +288,33 @@ class RerollSpec(object):
 
         '''
         if sides == 'F':
-            raise ValueError("Re-rolling/exploding is incompatible with Fate")
+            raise ValueError("Re-rolling/exploding is incompatible with Fate dice")
+        sides = int(sides)
+
+        if self.value is None:
+            if self.type in ('R', 'r'):
+                cmpr = Comparator('=', 1)
+            else:
+                cmpr = Comparator('=', sides)
+        else:
+            cmpr = Comparator(self.operator, self.value)
+
         if self.type == 'r':
             # Single reroll
             roll = roll_die(sides)
-            if self.compare(roll):
+            if cmpr.compare(roll):
                 roll = DieRolled(roll_die(sides), '{}' + self.type)
             return [ roll ]
         elif self.type == 'R':
             # Indefinite reroll
             roll = roll_die(sides)
-            while self.compare(roll):
+            while cmpr.compare(roll):
                 roll = DieRolled(roll_die(sides), '{}' + self.type)
             return [ roll ]
         elif self.type in ['!', '!!', '!p', '!!p']:
             # Explode/penetrate/compound
             all_rolls = [ roll_die(sides) ]
-            while self.compare(all_rolls[-1]):
+            while cmpr.compare(all_rolls[-1]):
                 all_rolls.append(roll_die(sides))
             # If we never re-rolled, no need to do anything special
             if len(all_rolls) == 1:
@@ -333,7 +338,10 @@ class DropSpec(object):
     count = attr.ib(default=1, convert=int, validator=validate_by_parser(positive_int))
 
     def __str__(self):
-        return self.type + str(self.count)
+        if self.count > 1:
+            return self.type + str(self.count)
+        else:
+            return self.type
 
     def drop_rolls(self, rolls):
         '''Drop the appripriate rolls from a list of rolls.
@@ -395,12 +403,12 @@ class DiceRoller(object):
 
     def __str__(self):
         return '{count}d{type}{reroll}{drop}{success}{fail}'.format(
-            count = self.dice_count,
+            count = self.dice_count if self.dice_count > 1 else '',
             type = self.die_type,
             reroll = self.reroll_spec or '',
             drop = self.drop_spec or '',
             success = self.success_comparator or '',
-            fail = 'f' + str(self.success_comparator) if self.success_comparator else '',
+            fail = ('f' + str(self.failure_comparator)) if self.failure_comparator else '',
         )
 
     def roll(self):
@@ -436,7 +444,7 @@ class DiceRoller(object):
 
 def make_dice_roller(expr):
     if isinstance(expr, str):
-        expr = roll_spec.parseString(expr, True)[0]
+        expr = roll_spec.parseString(expr, True)['expr']
     assert expr.getName() == 'roll'
     expr = expr.asDict()
 
@@ -454,13 +462,6 @@ def make_dice_roller(expr):
     rrdict = None
     if 'reroll' in expr:
         rrdict = expr['reroll']
-        if 'value' not in rrdict:
-            rrdict['operator'] = '='
-            # Default value is 1 for rerollers, dtype for exploders
-            if rrdict['type'] in ('R', 'r'):
-                rrdict['value'] = 1
-            else:
-                rrdict['value'] = dtype
         constructor_args['reroll_spec'] = RerollSpec(**rrdict)
 
     if 'drop' in expr:
@@ -530,23 +531,41 @@ op_dict = {
     '/': operator.truediv,
     '÷': operator.truediv,
     '**': operator.pow,
+    '^': operator.pow,
 }
 
-def eval_expr(expr, env={}, print_rolls=True):
+def normalize_expr(expr):
     if isinstance(expr, str):
-        expr = expr_parser.parseString(expr)[0]
+        return expr_parser.parseString(expr)['expr']
+    try:
+        if 'expr' in expr:
+            return expr['expr']
+    except TypeError:
+        pass
+    return expr
+
+def _eval_expr_internal(expr, env={}, print_rolls=True, recursed_vars=set()):
     if isinstance(expr, Number):
         # Numeric literal
         return expr
     elif isinstance(expr, str):
         # variable name
-        raise NotImplementedError("Variables are not implemented yet")
-    if 'operator' in expr:
+        if expr in recursed_vars:
+            raise ValueError('Recursive variable definition detected for {!r}'.format(expr))
+        elif expr in env:
+            var_value = env[expr]
+            parsed = normalize_expr(var_value)
+            return _eval_expr_internal(parsed, env, print_rolls,
+                                       recursed_vars = recursed_vars.union([expr]))
+        else:
+            raise ValueError('Expression referenced undefined variable {!r}'.format(expr))
+    elif 'operator' in expr:
         # Compound expression
         operands = expr[::2]
         operators = expr[1::2]
         assert len(operands) == len(operators) + 1
-        values = [ eval_expr(x) for x in operands ]
+        values = [ _eval_expr_internal(x, env, print_rolls, recursed_vars)
+                   for x in operands ]
         result = values[0]
         for (op, nextval) in zip(operators, values[1:]):
             opfun = op_dict[op]
@@ -560,34 +579,163 @@ def eval_expr(expr, env={}, print_rolls=True):
             print(result)
         return int(result)
 
+def eval_expr(expr, env={}, print_rolls=True):
+    expr = normalize_expr(expr)
+    return _eval_expr_internal(expr, env, print_rolls)
+
+def _expr_as_str_internal(expr, env={}, recursed_vars = set()):
+    if isinstance(expr, Number):
+        # Numeric literal
+        return str(expr)
+    elif isinstance(expr, str):
+        # variable name
+        if expr in recursed_vars:
+            raise ValueError('Recursive variable definition detected for {!r}'.format(expr))
+        elif expr in env:
+            var_value = env[expr]
+            parsed = normalize_expr(var_value)
+            return _expr_as_str_internal(parsed, env, recursed_vars = recursed_vars.union([expr]))
+        else:
+            raise ValueError('Expression referenced undefined variable {!r}'.format(expr))
+    elif 'operator' in expr:
+        # Compound expression
+        operands = expr[::2]
+        operators = expr[1::2]
+        assert len(operands) == len(operators) + 1
+        values = [ _expr_as_str_internal(x, env, recursed_vars)
+                   for x in operands ]
+        result = str(values[0])
+        for (op, nextval) in zip(operators, values[1:]):
+            result += ' {} {}'.format(op, nextval)
+        return '(' + result + ')'
+    else:
+        # roll specification
+        return str(make_dice_roller(expr))
+
+def expr_as_str(expr, env={}):
+    expr = normalize_expr(expr)
+    expr = _expr_as_str_internal(expr, env)
+    if expr.startswith('(') and expr.endswith(')'):
+        expr = expr[1:-1]
+    return expr
+
 def read_roll(handle=sys.stdin):
     return input("Enter roll> ")
 
+special_command_parser = (
+    oneOf('h help ?').setResultsName('help') |
+    oneOf('q quit exit').setResultsName('quit') |
+    oneOf('v vars').setResultsName('vars') |
+    (oneOf('d del delete').setResultsName('delete').leaveWhitespace() + Suppress(White()) + var_name)
+)
+
+def var_name_allowed(vname):
+    '''Disallow variable names like 'help' and 'quit'.'''
+    try:
+        special_command_parser.parseString(vname, True)
+        return False
+    except ParseException:
+        return True
+
+line_parser = (special_command_parser ^ (assignment_parser | expr_parser))
+
+def print_interactive_help():
+    print('\n' + '''
+To make a roll, type in the roll in dice notation, e.g. '4d4 + 4'. All
+dice notation forms listed in
+https://www.critdice.com/roll-advanced-dice and
+http://rpg.greenimp.co.uk/dice-roller/ should be supported.
+Expressions can include addition, subtraction, multiplication,
+division, and exponentiation.
+
+To assign a variable, use 'VAR = VALUE'. For example 'health_potion =
+4d4+4'. Subsequent roll expressions can refer to this variable, whose
+value will be substituted in to the expression.
+
+If a variable's value includes any dice rolls, those dice will be
+rolled (and produce a different value) every each time the variable is
+used.
+
+To delete a variable, type 'del VAR'.
+
+To show the values of all currently assigned variables, type 'vars'.
+
+To show this help text, type 'help'.
+
+To quit, type 'quit'.
+    '''.strip() + '\n', file=sys.stdout)
+
+def print_vars(env):
+    if len(env):
+        logger.info('Currently defined variables:')
+        for k in sorted(env.keys()):
+            print('{} = {!r}'.format(k, env[k]), file=sys.stderr)
+    else:
+        logger.info('No vars are currently defined.')
+
 if __name__ == '__main__':
-    expr = " ".join(sys.argv[1:])
-    if re.search("\\S", expr):
+    expr_string = " ".join(sys.argv[1:])
+    if re.search("\\S", expr_string):
         try:
-            result = roll(expr)
-            logger.info("Total roll: %s", result)
+            # Note: using expr_parser instead of line_parser, because
+            # on the command line only roll expressions are valid.
+            expr = expr_parser.parseString(expr_string, True)
+            result = eval_expr(expr)
+            logger.info("Total roll for {expr!r}: {result}".format(
+                expr=expr_as_str(expr),
+                result=result
+            ))
         except Exception as exc:
             logger.error("Error while rolling: %s", repr(exc))
+            raise exc
             sys.exit(1)
     else:
-        try:
-            while True:
-                try:
-                    expr = read_roll()
-                    if expr in ('exit', 'quit', 'q'):
-                        break
-                    if re.search("\\S", expr):
-                        try:
-                            result = eval_expr(expr)
-                            logger.info("Total roll: %s", result)
-                            print('', file=sys.stderr)
-                        except Exception as exc:
-                            logger.error("Error while rolling: %s", repr(exc))
-                except KeyboardInterrupt:
-                    print('')
-        except EOFError:
-            # Print a newline before exiting
-            print('')
+        env = {}
+        while True:
+            try:
+                expr_string = read_roll()
+                if not re.search("\\S", expr_string):
+                    continue
+                parsed = line_parser.parseString(expr_string, True)
+                if 'help' in parsed:
+                    print_interactive_help()
+                elif 'quit' in parsed:
+                    logger.info('Quitting.')
+                    break
+                elif 'vars' in parsed:
+                    print_vars(env)
+                elif 'delete' in parsed:
+                    vname = parsed['varname']
+                    if vname in env:
+                        logger.info('Deleting saved value for {var!r}.'.format(var=vname))
+                        del env[vname]
+                    else:
+                        logger.error('Variable {var!r} is not defined.'.format(var=vname))
+                elif re.search("\\S", expr_string):
+                    if 'assignment' in parsed:
+                        # We have an assignment operation
+                        vname = parsed['varname']
+                        if var_name_allowed(vname):
+                            env[vname] = expr_as_str(parsed['expr'], env)
+                            logger.info('Saving {var} as {expr!r}'.format(
+                                var=vname, expr=env[vname],
+                            ))
+                        else:
+                            logger.error('You cannot use {!r} as a variable name because it is a special command.'.format(vname))
+                    else:
+                        # Just an expression to evaluate
+                        result = eval_expr(parsed['expr'], env)
+                        logger.info('Total roll for {expr!r}: {result}'.format(
+                            expr=expr_as_str(parsed, env),
+                            result=result,
+                        ))
+                print('', file=sys.stderr)
+            except KeyboardInterrupt:
+                print('', file=sys.stderr)
+            except EOFError:
+                logger.info('Quitting.')
+                break
+            except Exception as exc:
+                logger.error('Error while evaluating {expr!r}: {ex!r}'.format(
+                    expr=expr_string, ex=exc,
+                ))