From 31d107009366a70940ff0a9a79e5b1c80226443d Mon Sep 17 00:00:00 2001 From: Thierry Streiff Date: Tue, 21 Jul 2020 19:34:38 +0200 Subject: [PATCH] Fix issue #111 on switch/case/default statements. Ensure that the switch expression is of integer type and apply integer promotion on it. Ensure default statements is not used twice for a given switch. Ensure case values (or ranges for the GCC extension) are integer constants, and that no case value or range is in conflict with the preceding cases of the same switch. --- ppci/lang/c/parser.py | 9 +++--- ppci/lang/c/semantics.py | 60 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/ppci/lang/c/parser.py b/ppci/lang/c/parser.py index de229a15..bc894a34 100644 --- a/ppci/lang/c/parser.py +++ b/ppci/lang/c/parser.py @@ -876,9 +876,10 @@ def parse_switch_statement(self): self.consume("(") expression = self.parse_expression() self.consume(")") - self.semantics.on_switch_enter(expression) + # expression type can be changed by integer promotion + prom_expression = self.semantics.on_switch_enter(expression) statement = self.parse_statement() - return self.semantics.on_switch_exit(expression, statement, location) + return self.semantics.on_switch_exit(prom_expression, statement, location) def parse_case_statement(self): """ Parse a case. @@ -891,11 +892,11 @@ def parse_case_statement(self): 'case 5 ... 10:' """ location = self.consume("case").loc - value = self.parse_expression() + value = self.parse_constant_expression() if self.peek == "...": # gnu extension! self.consume("...") - value2 = self.parse_expression() + value2 = self.parse_constant_expression() value = (value, value2) self.consume(":") diff --git a/ppci/lang/c/semantics.py b/ppci/lang/c/semantics.py index f45aa453..53a9b0c5 100644 --- a/ppci/lang/c/semantics.py +++ b/ppci/lang/c/semantics.py @@ -25,6 +25,37 @@ from .printer import expr_to_str, type_to_str +class CSwitch: + """ Switch instruction context """ + def __init__(self, typ): + self.typ = typ + self.default_seen = False + self.values = [] + self.ranges = [] + + def add_value(self, val): + """ Add a case value, returns True if OK, False otherwise """ + for low, high in self.ranges: + if val >= low and val <= high: + return False + for value in self.values: + if val == value: + return False + self.values.append(val) + return True + + def add_range(self, val1, val2): + """ Add a case range, returns True if OK, False otherwise """ + for low, high in self.ranges: + if (val1 >= low and val1 <= high) or (val2 >= low and val2 <= high): + return False + for value in self.values: + if value >= val1 and value <= val2: + return False + self.ranges.append((val1, val2)) + return True + + class CSemantics: """ This class handles the C semantics """ @@ -540,7 +571,10 @@ def on_if(self, condition, then_statement, no, location): return statements.If(condition, then_statement, no, location) def on_switch_enter(self, expression): - self.switch_stack.append(expression.typ) + self.ensure_integer(expression) + prom_expression = self.promote(expression) + self.switch_stack.append(CSwitch(prom_expression.typ)) + return prom_expression def on_switch_exit(self, expression, statement, location): """ Handle switch statement """ @@ -553,19 +587,35 @@ def on_case(self, value, statement, location): self.error("Case statement outside of a switch!", location) if isinstance(value, tuple): + # case value1 ... value2: value1, value2 = value - value1 = self.coerce(value1, self.switch_stack[-1]) - value2 = self.coerce(value2, self.switch_stack[-1]) + self.ensure_integer(value1) + self.ensure_integer(value2) + value1 = self.promote(value1) + value2 = self.promote(value2) + v1 = self.context.eval_expr(value1) + v2 = self.context.eval_expr(value2) + if v1 > v2: + self.error("Inconsistent case range", location) + if not self.switch_stack[-1].add_range(v1, v2): + self.error("Case range conflicts with previous cases", location) return statements.RangeCase(value1, value2, statement, location) else: - value = self.coerce(value, self.switch_stack[-1]) + # case value: + self.ensure_integer(value) + value = self.promote(value) + v = self.context.eval_expr(value) + if not self.switch_stack[-1].add_value(v): + self.error("Case value conflicts with previous cases", location) return statements.Case(value, statement, location) def on_default(self, statement, location): """ Handle a default label """ if not self.switch_stack: self.error("Default statement outside of a switch!", location) - + if self.switch_stack[-1].default_seen: + self.error("Duplicate default statement in current switch", location) + self.switch_stack[-1].default_seen = True return statements.Default(statement, location) def on_while(self, condition, body, location):