Solving the ‘Kakuro’ Sudoku-like puzzle with a Theorem Prover

The Kakuro puzzle is similar to Sudoku. The goal is to fill the blank cells of a grid with the digits 1-9 so that each horizontal or vertical ‘block’ (a run of adjacent blank cells) has no repeated digits and adds up to a specified total.
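To make the block rule concrete, here is a small stand-alone sketch in plain Python (not part of the solver). It checks whether a list of digits forms a valid block, and lists every set of distinct digits that could fill a block of a given length and total. Both function names are hypothetical helpers made up for this illustration.

```python
from itertools import combinations


def is_valid_block(digits, total):
    """True if every digit is 1-9, no digit repeats, and the digits sum to total."""
    return (all(1 <= d <= 9 for d in digits)
            and len(set(digits)) == len(digits)
            and sum(digits) == total)


def possible_sets(length, total):
    """All sets of `length` distinct digits 1-9 that add up to `total`."""
    return [c for c in combinations(range(1, 10), length) if sum(c) == total]


print(is_valid_block([1, 2, 3, 5], 11))  # True
print(is_valid_block([4, 4, 3], 11))     # False: the 4 repeats
print(possible_sets(2, 4))               # [(1, 3)]
print(possible_sets(3, 23))              # [(6, 8, 9)]
```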

Rather than solve this with a normal backtracking search, I will use the Z3 SMT solver from Microsoft. SMT (Satisfiability Modulo Theories) is a generalisation of the Boolean Satisfiability Problem: given a formula, find an assignment to its inputs that makes the output ‘true’.

In traditional Boolean Satisfiability the inputs can only be ‘true’ or ‘false’. SMT extends the language to include other mathematical objects, such as integers and arithmetic on them.
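As a rough illustration of what ‘satisfiability’ means (this is naive brute force, not how Z3 works internally), the sketch below tries every true/false assignment for a small formula in conjunctive normal form. The encoding, where each clause is a list of (variable, wanted value) literals, is an assumption made just for this example.

```python
from itertools import product

# Formula: (a or b) and (not a or c) and (not b or not c)
# Each clause is a list of (variable, wanted_value) literals; a clause is
# true if at least one of its variables has the wanted value.
clauses = [
    [("a", True), ("b", True)],
    [("a", False), ("c", True)],
    [("b", False), ("c", False)],
]
variables = ["a", "b", "c"]


def satisfying_assignments(clauses, variables):
    """Yield every assignment (dict of variable -> bool) that makes all clauses true."""
    for values in product([False, True], repeat=len(variables)):
        env = dict(zip(variables, values))
        if all(any(env[v] == want for v, want in clause) for clause in clauses):
            yield env


for env in satisfying_assignments(clauses, variables):
    print(env)
```

With n variables this loop tries 2**n assignments, which is why practical SAT and SMT solvers rely on much cleverer search.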

import z3
import re

For example, we can encode the question ‘find 3 distinct positive integers that add up to 7’ as a problem and get Z3 to solve it:

solver = z3.Solver()
# There are 3 integer variables...
x = z3.Int('x')
y = z3.Int('y')
z = z3.Int('z')

# ...that are positive...
solver.add(0 < x)
solver.add(0 < y)
solver.add(0 < z)

# ... and distinct ...
solver.add(z3.Distinct([x,y,z]))

# ... and they add up to 7.
solver.add(7 == x + y + z)

# solver.check() does the work to find an assignment of Integers
# to the variables x,y,z that make all the constraints true
if solver.check() == z3.sat:
    print(solver.model())
else:
    print("No solution")
[z = 2, y = 1, x = 4]
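Because the three numbers are positive and sum to 7, none of them can exceed 5, so the same question can be cross-checked by exhaustive search in plain Python. This is only a sanity check of the model Z3 printed, and it relies on that hand-derived bound; Z3 itself needs no such bound.

```python
from itertools import product

# Positive integers summing to 7 are each at most 5 (the other two are >= 1),
# so searching 1..5 for each variable covers every possible solution.
solutions = [
    (x, y, z)
    for x, y, z in product(range(1, 6), repeat=3)
    if len({x, y, z}) == 3 and x + y + z == 7
]
print(len(solutions))                             # 6
print({tuple(sorted(s)) for s in solutions})      # {(1, 2, 4)}
print((4, 1, 2) in solutions)                     # True: Z3's x = 4, y = 1, z = 2
```

Every solution is a permutation of 1, 2 and 4; Z3 returns one satisfying model rather than all of them.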

The idea here is to split the problem into two parts:

  • Express the problem in a form that can be given to a solver
  • Sit back and let the solver do the work

The nice thing here is that there has been a lot of research into making SMT solvers fast, and we can build on that rather than trying to reimplement a subset of those tricks ourselves.

SMT is a good example of what I would call a ‘keyhole API’: behind a small API surface lie many years of complex research.

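SMT is at least as hard as Boolean satisfiability, which is NP-complete (and for some theories, such as nonlinear integer arithmetic, it is even undecidable), but there are lots and lots of tricks that make it practical to solve large problems.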

Input format

Cells can be of two types:

  • Blank (and need filling in), or
  • Non-blank, in which case they may contain a total for the column below them and/or the row to the right of them.

In this representation ‘_’ means blank, and n\m means that n is the total for the column below and m is the total for the row to the right; either number may be left out.
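The cell notation can be checked with a tiny stand-alone parser. `parse_token` is a hypothetical helper for this illustration only; it uses the same regular expression as `re_cell` below, but calls `fullmatch` so that tokens with stray characters are rejected.

```python
import re

# n\m: optional column (down) total, a backslash, optional row (across) total.
CLUE = re.compile(r"(\d*)\\(\d*)")


def parse_token(token):
    """Return 'blank' for '_', else (column_total, row_total) with None for missing totals."""
    if token == "_":
        return "blank"
    m = CLUE.fullmatch(token)
    if m is None:
        raise ValueError("cannot parse %r" % token)
    down, across = m.groups()
    return (int(down) if down else None, int(across) if across else None)


for tok in ["_", "7\\16", "11\\", "\\29", "\\"]:
    print(tok, parse_token(tok))
```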

d_med = \
r"""
 \   11\  10\  4\  9\    \   17\  34\    \  30\  11\
 \11   _    _   _   _   7\16   _    _    \11  _    _
 \29   _    _   _   _    _     _    _  18\4   _    _
 \4    _    _ 33\3  _    _   16\29  _    _    _    _
 \16   _    _   _ 38\    \21    _   _    _    _    _
 \     \  27\12 _   _   9\14    _   _    _    _    \
 \   11\13  _   _   _    _      \14 _    _  29\  30\
 \33   _    _   _   _    _     8\  6\15  _    _    _
 \20   _    _   _   _   3\5     _   _   4\14  _    _
 \11   _    _   \34 _    _      _   _    _    _    _
 \13   _    _   \4  _    _      \19 _    _    _    _
"""
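# A clue cell: optional column total, a backslash, optional row total.
# A raw string avoids Python's "invalid escape sequence" warning for \d.
re_cell = re.compile(r'(\d*)\\(\d*)')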


class Cell(object):
    def __init__(self, s, row, col):
        """s is a string representation of the cell, e.g. '_' """
        self.row = row
        self.col = col
        s = s.strip()
        m = re_cell.match(s)
        if m:
            down = m.group(1)
            if down != '':
                self.column_total = int(down)
            else:
                self.column_total = None
            across = m.group(2)
            if across != '':
                self.row_total = int(across)
            else:
                self.row_total = None
            self.z3var = None
        elif s == '_':
            self.z3var = z3.Int('r%dc%d' % (row, col))
            self.row_total = None
            self.column_total = None
        else:
            raise Exception("Can't parse %s at row %d col %d" %(s,row, col))

    def is_blank(self):
        return self.z3var is not None

    def z3_variable(self):
        assert(self.is_blank())
        return self.z3var


    def display(self, model):
        if model is not None and self.z3var is not None:
            return '  %s  ' % model[self.z3var]
        else:
            return str(self)

    def __repr__(self):
        if self.is_blank():
            return str(self.z3var)
        return self.__str__()

    @staticmethod
    def format_total(x):
        if x:
            return str(x)
        else:
            return ''

    def __str__(self):
        if self.is_blank():
            return '  -  '
        else:
            return '%2s\\%-2s' % (Cell.format_total(self.column_total), Cell.format_total(self.row_total))

    def insert_constraints(self, solver):
        if self.is_blank():
            solver.add(0 < self.z3var, self.z3var < 10)


def parse(s, rows, cols):
    """Parse a string representation of a board into a grid of Cells"""
    board = []
    lines = [l for l in s.split('\n') if l.strip() != '']  # Remove empty lines
    re_cells = re.compile(r'\s+')
    if len(lines) != rows:
        raise Exception("Wrong number of rows: %r" % lines)
    for i, l in enumerate(lines):
        row = []
        cells = [c for c in re_cells.split(l) if c != '']
        if len(cells) != cols:
            raise Exception("Row is wrong length: %r" % cells)
        for j, c in enumerate(cells):
            row.append(Cell(c, row=i, col=j))
        board.append(row)
    return board

def dump(board, model=None):
    for l in board:
        print("".join([x.display(model) for x in l]))
# A quick test of the parsing routine
d = \
r"""
 \3 _  _  1\
 \6 _  _  _
"""
board = parse(d, rows=2, cols=4)
print("Parsed board:")
print(repr(board))
print("Example output of 'dump'")
dump(board)
Parsed board:
[[  \3 , r0c1, r0c2,  1\  ], [  \6 , r1c1, r1c2, r1c3]]
Example output of 'dump'
  \3   -    -   1\
  \6   -    -    -

The next step is to add the rules of the game: each block must contain distinct numbers that add up to a specific total. The hard part is walking over the board while keeping track of which block we are currently in:

class AccumulateBlock(object):
    def __init__(self, solver):
        self.solver = solver
        self.block_total = None
        self.z3vars = []

    def define_sum(self, block_total):
        if self.block_total is not None:
            self.send_to_solver()
        self.block_total = block_total
        self.z3vars = []

    def add_cell(self, z3var):
        if self.block_total is not None:
            self.z3vars.append(z3var)

    def send_to_solver(self):
        if self.block_total is not None:
            assert(len(self.z3vars) > 0)
            print("   %d == sum(%r)" % (self.block_total, self.z3vars))
            self.solver.add(z3.Distinct(self.z3vars))
            self.solver.add(self.block_total == z3.Sum(self.z3vars))
            self.block_total = None
            self.z3vars = []


def insert_constraints(board, solver):
    accumulator = AccumulateBlock(solver)
    print("Adding row constraints")
    for row in board:
        for cell in row:
            if cell.row_total:
                accumulator.define_sum(cell.row_total)
            elif cell.is_blank():
                accumulator.add_cell(cell.z3var)
        accumulator.send_to_solver()
    rows = len(board)
    columns = len(board[0])
    print("Adding column constraints")
    for col in range(columns):
        for row in range(rows):
            cell = board[row][col]
            if cell.column_total:
                accumulator.define_sum(cell.column_total)
            elif cell.is_blank():
                accumulator.add_cell(cell.z3var)
        accumulator.send_to_solver()
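For comparison, here is a sketch of the same walk written as a pure function that returns the blocks instead of sending them to a solver. It assumes a simplified, hypothetical grid format (not the `Cell` class above) where `'_'` is a blank cell and a clue is a `(column_total, row_total)` tuple. It also closes the current block at any non-blank cell, which makes it slightly stricter than `AccumulateBlock` on malformed input; for well-formed puzzles the two should find the same blocks.

```python
def find_blocks(grid):
    """Return (total, [(row, col), ...]) pairs: all row blocks, then all column blocks."""
    blocks = []

    def scan(line, total_index):
        # line: list of ((row, col), cell); total_index: 1 = row total, 0 = column total
        total, cells = None, []
        for pos, cell in line:
            if cell == "_":
                if total is not None:
                    cells.append(pos)
            else:
                # Any non-blank cell ends the current block...
                if total is not None and cells:
                    blocks.append((total, cells))
                # ...and may start a new one if it carries the relevant total.
                total, cells = cell[total_index], []
        if total is not None and cells:
            blocks.append((total, cells))

    rows, cols = len(grid), len(grid[0])
    for r in range(rows):
        scan([((r, c), grid[r][c]) for c in range(cols)], 1)
    for c in range(cols):
        scan([((r, c), grid[r][c]) for r in range(rows)], 0)
    return blocks


# The 2x4 test board:  \3 _ _ 1\   /   \6 _ _ _
grid = [
    [(None, 3), "_", "_", (1, None)],
    [(None, 6), "_", "_", "_"],
]
for total, cells in find_blocks(grid):
    print(total, cells)
```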

Finally, solving the puzzle requires adding the constraints on each cell (each must be between 1 and 9) and the constraints on the board as a whole to a Z3 solver, and asking it to find a solution:

def solve(board):
    s = z3.Solver()
    for row in board:
        for cell in row:
            cell.insert_constraints(s)
    insert_constraints(board, s)
    if s.check() == z3.sat:
        print("Found solution")
        dump(board, s.model())
    else:
        print("unsat")

d = \
r"""
 \3 _  _  1\
 \6 _  _  _
"""
board = parse(d, rows=2, cols=4)
print("Parsed input")
dump(board)

solve(board)
Parsed input
  \3   -    -   1\
  \6   -    -    -
Adding row constraints
   3 == sum([r0c1, r0c2])
   6 == sum([r1c1, r1c2, r1c3])
Adding column constraints
   1 == sum([r1c3])
Found solution
  \3   2    1   1\
  \6   3    2    1
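Z3 returns one satisfying model, not necessarily the only one. For a board this small that can be cross-checked by brute force: the sketch below (plain Python, with the five blank cells and three blocks hard-coded from the constraints printed above) tries every assignment of 1-9 and keeps the ones that satisfy all the blocks.

```python
from itertools import product

# Blocks of the 2x4 test board, copied from the printed constraints.
blocks = [
    (3, ["r0c1", "r0c2"]),
    (6, ["r1c1", "r1c2", "r1c3"]),
    (1, ["r1c3"]),
]
names = ["r0c1", "r0c2", "r1c1", "r1c2", "r1c3"]


def satisfies(env):
    """True if every block has distinct digits with the required sum."""
    for total, cells in blocks:
        digits = [env[c] for c in cells]
        if len(set(digits)) != len(digits) or sum(digits) != total:
            return False
    return True


# 9**5 = 59049 candidate assignments, small enough to enumerate directly.
solutions = [
    env
    for env in (dict(zip(names, values)) for values in product(range(1, 10), repeat=5))
    if satisfies(env)
]
print(len(solutions))  # 4
z3_answer = {"r0c1": 2, "r0c2": 1, "r1c1": 3, "r1c2": 2, "r1c3": 1}
print(z3_answer in solutions)  # True
```

So the toy board has four solutions, and which one Z3 reports is up to the solver.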

Now solve the original, larger puzzle:

board = parse(d_med, rows=11, cols=11)
dump(board)

solve(board)
  \  11\  10\   4\   9\    \  17\  34\    \  30\  11\
  \11  -    -    -    -   7\16  -    -    \11  -    -
  \29  -    -    -    -    -    -    -  18\4   -    -
  \4   -    -  33\3   -    -  16\29  -    -    -    -
  \16  -    -    -  38\    \21  -    -    -    -    -
  \    \  27\12  -    -   9\14  -    -    -    -    \
  \  11\13  -    -    -    -    \14  -    -  29\  30\
  \33  -    -    -    -    -   8\   6\15  -    -    -
  \20  -    -    -    -   3\5   -    -   4\14  -    -
  \11  -    -    \34  -    -    -    -    -    -    -
  \13  -    -    \4   -    -    \19  -    -    -    -
Adding row constraints
   11 == sum([r1c1, r1c2, r1c3, r1c4])
   16 == sum([r1c6, r1c7])
   11 == sum([r1c9, r1c10])
   29 == sum([r2c1, r2c2, r2c3, r2c4, r2c5, r2c6, r2c7])
   4 == sum([r2c9, r2c10])
   4 == sum([r3c1, r3c2])
   3 == sum([r3c4, r3c5])
   29 == sum([r3c7, r3c8, r3c9, r3c10])
   16 == sum([r4c1, r4c2, r4c3])
   21 == sum([r4c6, r4c7, r4c8, r4c9, r4c10])
   12 == sum([r5c3, r5c4])
   14 == sum([r5c6, r5c7, r5c8, r5c9])
   13 == sum([r6c2, r6c3, r6c4, r6c5])
   14 == sum([r6c7, r6c8])
   33 == sum([r7c1, r7c2, r7c3, r7c4, r7c5])
   15 == sum([r7c8, r7c9, r7c10])
   20 == sum([r8c1, r8c2, r8c3, r8c4])
   5 == sum([r8c6, r8c7])
   14 == sum([r8c9, r8c10])
   11 == sum([r9c1, r9c2])
   34 == sum([r9c4, r9c5, r9c6, r9c7, r9c8, r9c9, r9c10])
   13 == sum([r10c1, r10c2])
   4 == sum([r10c4, r10c5])
   19 == sum([r10c7, r10c8, r10c9, r10c10])
Adding column constraints
   11 == sum([r1c1, r2c1, r3c1, r4c1])
   11 == sum([r7c1, r8c1, r9c1, r10c1])
   10 == sum([r1c2, r2c2, r3c2, r4c2])
   27 == sum([r6c2, r7c2, r8c2, r9c2, r10c2])
   4 == sum([r1c3, r2c3])
   33 == sum([r4c3, r5c3, r6c3, r7c3, r8c3])
   9 == sum([r1c4, r2c4, r3c4])
   38 == sum([r5c4, r6c4, r7c4, r8c4, r9c4, r10c4])
   7 == sum([r2c5, r3c5])
   9 == sum([r6c5, r7c5])
   3 == sum([r9c5, r10c5])
   17 == sum([r1c6, r2c6])
   16 == sum([r4c6, r5c6])
   8 == sum([r8c6, r9c6])
   34 == sum([r1c7, r2c7, r3c7, r4c7, r5c7, r6c7])
   6 == sum([r8c7, r9c7, r10c7])
   18 == sum([r3c8, r4c8, r5c8, r6c8, r7c8])
   4 == sum([r9c8, r10c8])
   30 == sum([r1c9, r2c9, r3c9, r4c9, r5c9])
   29 == sum([r7c9, r8c9, r9c9, r10c9])
   11 == sum([r1c10, r2c10, r3c10, r4c10])
   30 == sum([r7c10, r8c10, r9c10, r10c10])
Found solution
  \  11\  10\   4\   9\    \  17\  34\    \  30\  11\
  \11  1    2    3    5   7\16  9    7    \11  8    3
  \29  2    4    1    3    5    8    6  18\4   3    1
  \4   3    1  33\3   1    2  16\29  8    7    9    5
  \16  5    3    8  38\    \21  9    3    1    6    2
  \    \  27\12  5    7   9\14  7    1    2    4    \
  \  11\13  1    4    6    2    \14  9    5  29\  30\
  \33  3    6    9    8    7   8\   6\15  3    5    7
  \20  1    3    7    9   3\5   2    3   4\14  8    6
  \11  2    9    \34  5    2    6    1    3    9    8
  \13  5    8    \4   3    1    \19  2    1    7    9

Simple!