Solving the 'Kakuro' Sudoku-like puzzle with a Theorem Prover
The Kakuro puzzle is similar to Sudoku. The goal is to insert the numbers 1-9 in a grid so that each horizontal or vertical 'block' has no repeating digits and a specified sum.
Rather than solve this with a normal backtracking search, I will instead use the Z3 SMT solver from Microsoft. SMT (Satisfiability Modulo Theories) is a generalisation of the Boolean Satisfiability Problem: it involves taking a formula and finding how to set the inputs to make the output 'true'.
In traditional Boolean Satisfiability the inputs can only be 'true' or 'false'. SMT extends the language to include other mathematical objects, such as integers.
! pip install z3-solver
import z3
import re
For example, we can encode the question 'find 3 distinct positive numbers that add up to 7' as a problem and get Z3 to solve it:
# Toy problem: find three distinct positive integers whose sum is 7.
solver = z3.Solver()

# The three unknowns, declared in one go.
x, y, z = z3.Ints('x y z')

# Constraints are asserted in the order: positivity, distinctness, total.
solver.add(x > 0, y > 0, z > 0)
solver.add(z3.Distinct(x, y, z))
solver.add(x + y + z == 7)

# check() searches for integer values of x, y, z satisfying every
# constraint; model() is only consulted when such values exist.
print(solver.model() if solver.check() == z3.sat else "No solution")
The idea here is to split the problem into two parts:
- Express the problem in a form that can be given to a solver
- Sit back and let the solver do the work
The nice thing here is that there has been lots of research into making fast SMT solvers, and we can build on that rather than trying to reimplement some subset of those tricks ourselves.
SMT is a good example of what I would call a 'keyhole API': behind a small API surface lie many, many years of complex research.
SMT is in general NP-hard (Boolean satisfiability alone is already NP-complete), but there are lots and lots of tricks that make it practical to solve large problems.
Input format
Cells can be of two types:
- Blank (and need filling in), or
- Non-blank, in which case they may contain a total for the column below them or the row to the right of them.
In this representation `_` means blank and `n\m` holds the totals for the column (`n`) and the row (`m`).
# Medium-sized example puzzle (11 rows x 11 columns) in the board text format:
# '_' is a blank cell to be filled in, and 'n\m' is a clue cell whose optional
# n is a column total and whose optional m is a row total. It is a raw string,
# so the backslashes are kept literally.
d_med = \
r"""
\ 11\ 10\ 4\ 9\ \ 17\ 34\ \ 30\ 11\
\11 _ _ _ _ 7\16 _ _ \11 _ _
\29 _ _ _ _ _ _ _ 18\4 _ _
\4 _ _ 33\3 _ _ 16\29 _ _ _ _
\16 _ _ _ 38\ \21 _ _ _ _ _
\ \ 27\12 _ _ 9\14 _ _ _ _ \
\ 11\13 _ _ _ _ \14 _ _ 29\ 30\
\33 _ _ _ _ _ 8\ 6\15 _ _ _
\20 _ _ _ _ 3\5 _ _ 4\14 _ _
\11 _ _ \34 _ _ _ _ _ _ _
\13 _ _ \4 _ _ \19 _ _ _ _
"""
# A clue cell is written 'down\across'; either number may be omitted.
# Raw string so that \d and \\ reach the regex engine unchanged.
re_cell = re.compile(r'(\d*)\\(\d*)')


class Cell(object):
    r"""One square of the Kakuro grid.

    A cell is either blank (it must be filled in, and owns a Z3 integer
    variable) or non-blank (it may carry a column total and/or a row total).

    Attributes:
        row, col: position passed to the constructor.
        column_total: int before the backslash in 'n\m', or None.
        row_total: int after the backslash in 'n\m', or None.
        z3var: Z3 Int named 'r<row>c<col>' for a blank cell, else None.
    """

    def __init__(self, s, row, col):
        r"""Build a cell from its text form.

        Args:
            s: string representation of the cell, e.g. '_' for a blank cell
                or '3\4', '3\', '\4', '\' for a non-blank one. Surrounding
                whitespace is ignored.
            row: row index of the cell.
            col: column index of the cell.

        Raises:
            Exception: if s is neither '_' nor a well-formed 'n\m' cell.
        """
        self.row = row
        self.col = col
        s = s.strip()
        # fullmatch rather than match: text with trailing junk such as
        # '3\4x' must fall through to the parse error below instead of
        # being silently accepted as a clue cell.
        m = re_cell.fullmatch(s)
        if m:
            down, across = m.groups()
            self.column_total = int(down) if down != '' else None
            self.row_total = int(across) if across != '' else None
            self.z3var = None
        elif s == '_':
            self.z3var = z3.Int('r%dc%d' % (row, col))
            self.row_total = None
            self.column_total = None
        else:
            raise Exception("Can't parse %s at row %d col %d" % (s, row, col))

    def is_blank(self):
        """Return True if this cell has a Z3 variable (i.e. must be filled in)."""
        return self.z3var is not None

    def z3_variable(self):
        """Return the Z3 variable of a blank cell; asserts the cell is blank."""
        assert(self.is_blank())
        return self.z3var

    def display(self, model):
        """Return the text for this cell when printing a board.

        With a model, a blank cell shows the value the model assigns to its
        variable; otherwise the cell's str() form is used.
        """
        if model is not None and self.z3var is not None:
            return ' %s ' % model[self.z3var]
        else:
            return str(self)

    def __repr__(self):
        """Blank cells show their Z3 variable; others show their str() form."""
        if self.is_blank():
            return str(self.z3var)
        return self.__str__()

    @staticmethod
    def format_total(x):
        """Return a total as text, or '' when there is no total (None)."""
        # Compare against None explicitly so that a total of 0 is printed
        # rather than being mistaken for a missing total.
        if x is not None:
            return str(x)
        else:
            return ''

    def __str__(self):
        r"""Return ' - ' for a blank cell, else the totals as 'n\m'."""
        if self.is_blank():
            return ' - '
        else:
            return '%2s\\%-2s' % (Cell.format_total(self.column_total), Cell.format_total(self.row_total))

    def insert_constraints(self, solver):
        """Constrain a blank cell's variable to 1..9; no-op for other cells."""
        if self.is_blank():
            solver.add(0 < self.z3var, self.z3var < 10)
def parse(s, rows, cols):
    """Parse a string representation of a board into a grid of Cells.

    Args:
        s: board text, one grid row per line, cells separated by whitespace.
            Empty and whitespace-only lines are ignored.
        rows: expected number of grid rows.
        cols: expected number of cells in every row.

    Returns:
        A list of rows, each a list of Cell objects built with their
        row/column indices.

    Raises:
        Exception: if the number of rows, or the length of any row, does
            not match the expected size.
    """
    # Drop lines with no content. Testing the stripped line (rather than
    # comparing with '') also discards lines holding only spaces, tabs or
    # the '\r' left over from Windows line endings.
    lines = [line for line in s.split('\n') if line.strip() != '']
    if len(lines) != rows:
        raise Exception("Wrong number of rows: %r" % lines)

    board = []
    for i, line in enumerate(lines):
        # str.split() with no argument splits on runs of whitespace and
        # never yields empty strings.
        cells = line.split()
        if len(cells) != cols:
            raise Exception("Row is wrong length: %r" % cells)
        board.append([Cell(c, row=i, col=j) for j, c in enumerate(cells)])
    return board
def dump(board, model=None):
    """Print the board, one line of text per grid row.

    Each cell contributes whatever its display(model) method returns; the
    pieces of a row are concatenated without separators.
    """
    for cells in board:
        rendered = "".join(cell.display(model) for cell in cells)
        print(rendered)
# Smoke-test the parser on a tiny board of 2 rows by 4 columns.
d = r"""
\3 _ _ 1\
\6 _ _ _
"""
board = parse(d, rows=2, cols=4)

# Show the repr of the parsed grid, then the same grid as rendered by dump().
print("Parsed board:\n%r" % (board,))
print("Example output of 'dump'")
dump(board)
The next step is to add the constraints of the game: Each block must contain distinct numbers and sum to a specific total. The hard part here is walking over the board and remembering where we are:
class AccumulateBlock(object):
    """Collect the cells of one block and hand its constraints to the solver.

    Usage while walking a row or column: call define_sum() when a total is
    met, add_cell() for each blank cell that follows, and send_to_solver()
    when the block ends. Cells added while no block is open are ignored.
    """

    def __init__(self, solver):
        """Remember the solver; start with no block open."""
        self.solver = solver
        self.block_total = None  # total of the block being collected, or None
        self.z3vars = []         # variables collected for the open block

    def define_sum(self, block_total):
        """Open a new block with the given total, flushing any open block."""
        # send_to_solver() is a no-op when no block is open.
        self.send_to_solver()
        self.block_total = block_total
        self.z3vars = []

    def add_cell(self, z3var):
        """Add a variable to the open block; ignored if no block is open."""
        if self.block_total is not None:
            self.z3vars.append(z3var)

    def send_to_solver(self):
        """Emit 'distinct' and 'sum == total' for the open block, then close it.

        Does nothing when no block is open.

        Raises:
            AssertionError: if a block was opened but received no cells.
        """
        if self.block_total is None:
            return
        # Raised explicitly rather than via an assert statement so that the
        # check on the puzzle input survives running Python with -O.
        if len(self.z3vars) == 0:
            raise AssertionError(
                "Block with total %r has no cells" % (self.block_total,))
        print(" %d == sum(%r)" % (self.block_total, self.z3vars))
        self.solver.add(z3.Distinct(self.z3vars))
        self.solver.add(self.block_total == z3.Sum(self.z3vars))
        self.block_total = None
        self.z3vars = []
def _constrain_line(accumulator, cells, total_attr):
    """Feed one row or column of cells to the accumulator.

    total_attr names the Cell attribute holding the total for the direction
    being walked ('row_total' or 'column_total').
    """
    for cell in cells:
        if cell.is_blank():
            accumulator.add_cell(cell.z3var)
            continue
        # Every non-blank cell ends the block in progress, even when it has
        # no total for this direction; otherwise the blanks after it would
        # be merged into the previous block's sum/distinct constraints.
        accumulator.send_to_solver()
        total = getattr(cell, total_attr)
        # Explicit None test so that a total of 0 is not silently skipped.
        if total is not None:
            accumulator.define_sum(total)
    # The edge of the board ends the last block of the line.
    accumulator.send_to_solver()


def insert_constraints(board, solver):
    """Add the block constraints of the whole board to the solver.

    Walks every row left to right and every column top to bottom; each run
    of blank cells following a total becomes one block whose cells must be
    distinct and sum to that total. Blank cells not preceded by a total in
    the direction being walked get no constraint from that direction.

    Args:
        board: rectangular grid (list of rows) of Cell objects.
        solver: Z3 solver receiving the constraints.
    """
    accumulator = AccumulateBlock(solver)

    print("Adding row constraints")
    for row in board:
        _constrain_line(accumulator, row, 'row_total')

    print("Adding column constraints")
    # zip(*board) yields the columns of a rectangular grid, and yields
    # nothing for an empty board.
    for column in zip(*board):
        _constrain_line(accumulator, column, 'column_total')
Finally, solving the puzzle requires adding the constraints on each cell (they must be between 1 and 9) and the constraints on the board as a whole to a Z3 solver, and asking it to find a solution:
def solve(board):
    """Solve a parsed board with Z3 and print the outcome.

    Adds each cell's own constraints, then the block constraints of the
    board, and asks the solver for a model. Prints the filled-in board when
    one is found and "unsat" otherwise.
    """
    solver = z3.Solver()

    # Per-cell constraints first, then the per-block ones.
    for cell in (c for line in board for c in line):
        cell.insert_constraints(solver)
    insert_constraints(board, solver)

    satisfiable = solver.check() == z3.sat
    if not satisfiable:
        print("unsat")
        return
    print ("Found solution")
    dump(board, solver.model())
# First the tiny board of 2 rows by 4 columns: show it, then solve it.
d = r"""
\3 _ _ 1\
\6 _ _ _
"""
board = parse(d, rows=2, cols=4)
print("Parsed input")
dump(board)
solve(board)

# Then the medium 11x11 puzzle; only the solution is printed.
board = parse(d_med, rows=11, cols=11)
solve(board)
Simple!