from z3 import *
import re
import string

def mk_mask(addrs):
    """Return the integer bit mask for a dotted pattern split into octets.

    addrs -- sequence of octet strings, e.g. ["1", "64", "*", "*"].

    Every literal octet contributes 0xFF (bits that must match) and every
    "*" contributes 0x00 (don't-care bits), most significant octet first.
    For example ["1", "64", "*", "*"] yields 0xFFFF0000. An empty sequence
    yields 0.
    """
    m = 0
    for a in addrs:
        # Shift the octets seen so far up by one byte, then fill this byte.
        m = 256 * m + (0 if a == "*" else 255)
    return m

def mk_addr(addrs):
    """Return the integer value encoded by a dotted pattern split into octets.

    addrs -- sequence of octet strings, e.g. ["1", "64", "*", "*"].

    Literal octets are parsed as decimal numbers and packed most significant
    first; "*" octets contribute 0 so they agree with the zero bits that
    mk_mask() produces for wildcards. ["1", "64", "*", "*"] yields 0x01400000.

    Raises ValueError if an octet is not a decimal integer in 0..255 (a
    larger value would otherwise silently spill into the neighbouring octet).
    """
    m = 0
    for a in addrs:
        if a == "*":
            octet = 0
        else:
            # int() replaces string.atoi(), which was removed in Python 3.
            octet = int(a)
            if not 0 <= octet <= 255:
                raise ValueError("octet out of range 0..255: %r" % (a,))
        m = 256 * m + octet
    return m

# Symbolic packet-header fields. Addresses are 32-bit bit-vectors; ports and
# the protocol number are modelled as 16-bit bit-vectors.
srcIp = BitVec('srcIp', 32)
dstIp = BitVec('dstIp', 32)
srcPort = BitVec('srcPort', 16)
dstPort = BitVec('dstPort', 16)
proto = BitVec('protocol', 16)

def parse_mask(mask):
    """Turn a field pattern into a (bit_mask, expected_value) pair.

    A string is read as a dotted pattern such as "10.*.*.*": literal octets
    must match and "*" octets are wildcards (see mk_mask / mk_addr).
    Any other value (e.g. the integer port 80) must match exactly and is
    returned as (-1, mask). Presumably -1 acts as an all-ones mask once z3
    coerces it to the bit-vector's width when and-ed with it -- TODO confirm.
    """
    if isinstance(mask, str):
        # Plain str.split instead of re.split("\.", ...): same result, and no
        # invalid-escape warning from the non-raw regex string on Python 3.
        octets = mask.split(".")
        return (mk_mask(octets), mk_addr(octets))
    return (-1, mask)

    
def matchMask(v, mask):
    """Return the z3 condition that field `v` matches the pattern `mask`.

    The pattern is decoded by parse_mask() into a bit mask and an expected
    value; the condition is that the masked bits of `v` equal that value.
    """
    bitmask, expected = parse_mask(mask)
    return (v & bitmask) == expected

def rule(srcIpMask, srcPortMask, dstIpMask, dstPortMask, protoMask):
    """Return the z3 condition that a packet matches every field pattern.

    Each argument is a pattern accepted by matchMask(): a dotted string
    with "*" wildcards, or a value that must match exactly.
    """
    field_patterns = (
        (srcIp, srcIpMask),
        (srcPort, srcPortMask),
        (dstIp, dstIpMask),
        (dstPort, dstPortMask),
        (proto, protoMask),
    )
    return And([matchMask(field, pattern) for (field, pattern) in field_patterns])

def deny(sa, sp, da, dp, p):
    """Return a (filter, is_allow) pair for a rule rejecting matching packets.

    sa, sp, da, dp, p -- source address, source port, destination address,
    destination port and protocol patterns, as accepted by rule().
    """
    condition = rule(sa, sp, da, dp, p)
    return condition, False

def allow(sa, sp, da, dp, p):
    """Return a (filter, is_allow) pair for a rule admitting matching packets.

    sa, sp, da, dp, p -- source address, source port, destination address,
    destination port and protocol patterns, as accepted by rule().
    """
    condition = rule(sa, sp, da, dp, p)
    return condition, True

# Two example policies, checked against each other below. They differ only in
# the source pattern of their first rule (1.64.*.* versus 1.128.*.*); every
# rule uses source port 80, destination port 80 and protocol 4.
policy1 = [
    allow("1.64.*.*", 80, "10.*.*.*", 80, 4),
    deny("1.*.*.*", 80, "20.*.*.*", 80, 4),
    allow("1.*.*.*", 80, "30.*.*.*", 80, 4),
]

policy2 = [
    allow("1.128.*.*", 80, "10.*.*.*", 80, 4),
    deny("1.*.*.*", 80, "20.*.*.*", 80, 4),
    allow("1.*.*.*", 80, "30.*.*.*", 80, 4),
]
	    

# First applicable
def fa(policy):
    """Return the z3 condition under which `policy` admits a packet, using
    first-applicable semantics.

    Rules are consulted in list order and the first rule whose filter matches
    decides: an allow rule admits the packet, a deny rule rejects it. A packet
    matching no rule is rejected.

    policy -- iterable of (filter, is_allow) pairs as built by allow()/deny().
    """
    # Fold from the last rule back to the first so each rule wraps the
    # decision of every later rule:  fml_i = If(filter_i, is_allow_i, fml_i+1)
    fml = False
    for (cond, is_allow) in reversed(list(policy)):
        if is_allow:
            # If(cond, True, fml)
            fml = Or(cond, fml)
        else:
            # If(cond, False, fml): a matching deny rule rejects outright, so
            # the later rules only apply when this filter does NOT match.
            fml = And(Not(cond), fml)
    return fml

# Deny overrides
def do(policy):
    """Return the z3 condition under which `policy` admits a packet, using
    deny-overrides semantics.

    A packet is admitted when at least one allow filter matches and no deny
    filter matches; rule order is irrelevant.

    policy -- list of (filter, is_allow) pairs as built by allow()/deny().
    """
    some_allow = Or([cond for (cond, is_allow) in policy if is_allow])
    no_deny = [Not(cond) for (cond, is_allow) in policy if not is_allow]
    return And([some_allow] + no_deny)

# Ask z3 for a packet on which policy1 and policy2 disagree under
# first-applicable semantics; z3's solve() prints its result.
solve (fa(policy1) != fa(policy2))

# Shared solver used by blast(), minimize_assignment(), generalize_model()
# and solve_relaxed() below.
s = Solver()

# Create bit-blasted versions of variables
# so they can be tracked.
def blast(x):
    """Return one Bool per bit of bit-vector `x`, tied to that bit in `s`.

    For each bit position i a Bool named "<x>_<i>" is created, and the global
    solver `s` is given the constraint that it is true exactly when
    Extract(i, i, x) is 1. The Bools can then be passed to s.check() as
    assumptions so that unsat cores are expressed in terms of single bits.

    Returns the Bools in bit-position order, starting at bit 0.
    """
    bits = []
    for i in range(x.sort().size()):
        bit_i = Bool("%s_%i" % (x, i))
        s.add(bit_i == (Extract(i, i, x) == BitVecVal(1, 1)))
        bits.append(bit_i)
    return bits

# Per-bit Boolean views of each header field (blast() also registers the
# linking constraints in `s`, in this field order).
srcIp_b, srcPort_b, dstIp_b, dstPort_b, proto_b = map(
    blast, (srcIp, srcPort, dstIp, dstPort, proto))


#
# Retrieve ID of x to use for comparisons.
# The IDs are unique for every term.
# This function illustrates a call to the
# C-API from Python.
#
def get_id(x):
    """Return the Z3 AST id of expression `x`, obtained via the C API."""
    context_ref = x.ctx.ref()
    ast_handle = x.as_ast()
    return Z3_get_ast_id(context_ref, ast_handle)

def get_ids(xs):
    return [get_id(x) for x in xs]

def remove_ast(x, ys):
    """Return the elements of `ys` that are not the same Z3 term as `x`.

    Terms are compared by AST id; the order of `ys` is preserved.
    """
    # The id of x is loop-invariant; compute it once instead of per element.
    x_id = get_id(x)
    return [y for y in ys if get_id(y) != x_id]

#
# Minimize the assignment that
# makes the current context unsat.
# This is a straightforward
# hill-climbing algorithm that
# removes literals from the current core.
#
# Exercise: implement and try some variants,
# such as QuickExplain.
#
def minimize_assignment(assignment):
    """Shrink `assignment` to a subset that still makes solver `s` unsat.

    assignment -- list of Boolean literals used as check() assumptions.
    Presumably s.check(assignment) is unsat on entry -- TODO confirm at
    callers; if it is not, every literal is kept.

    Each literal is tested once: it is dropped when the solver is still unsat
    without it, and the surviving set is then narrowed to the returned unsat
    core. Returns the surviving literals in their original order. Provided
    every check answers sat or unsat, no returned literal can be removed.
    """
    keyed = [(get_id(lit), lit) for lit in assignment]
    live = set(key for (key, _) in keyed)
    for (key, _) in keyed:
        if key not in live:
            continue  # already eliminated by an earlier unsat core
        rest = [lit for (k, lit) in keyed if k in live and k != key]
        if s.check(rest) == unsat:
            # The core is a subset of `rest`, so this also drops `key`.
            live = set(get_ids(s.unsat_core()))
    return [lit for (k, lit) in keyed if k in live]

#
# Generalize the current model with respect to vars:
# extract a minimal satisfying assignment of vars
# that implies the formula.
#

def mk_literal(model, v):
    """Return the literal over Bool `v` that holds in `model`.

    Gives `v` itself when the model makes it true, otherwise `Not(v)`.
    """
    # The second argument is model_completion: per the z3py docs it makes
    # eval produce a concrete value even if `model` leaves `v` unassigned.
    if is_true(model.eval(v, True)):
        return v
    return Not(v)

def print_bit(v, ids):
    """Return the pattern character for bit Bool `v`.

    "1" if the literal `v` is among `ids`, "0" if `Not(v)` is, and "*"
    (don't care) if neither is. Despite the name, nothing is printed.
    """
    if get_id(v) in ids:
        return "1"
    if get_id(Not(v)) in ids:
        return "0"
    return "*"

def print_bits(v, vars, literals):
    """Print bit-vector `v` as a pattern of its retained bit literals.

    v        -- the bit-vector; only its printed name is used
    vars     -- Bools for v's bits in bit-position order (as from blast())
    literals -- the retained literals; bits without one print as "*"

    The pattern is printed highest bit position first, after v's name.
    """
    ids = get_ids(literals)
    # Use a distinct comprehension variable: in Python 2 list-comprehension
    # variables leak into the enclosing scope, so iterating with `v` would
    # rebind the parameter and print the last bit's name instead of v's.
    pattern = [print_bit(bit, ids) for bit in vars]
    pattern.reverse()
    # Same text the Python 2 statement `print v, ": ", string.join(bits)`
    # produced, written so it also works on Python 3.
    print("%s %s %s" % (v, ": ", " ".join(pattern)))
    
def generalize_model(model, vars, formula):
    """Print `model` generalized to bit patterns that still imply `formula`.

    model   -- a model of `formula` over the bit-blasted fields
    vars    -- list of (bit_vector, bit_bools) pairs (bit_bools from blast())
    formula -- the z3 formula the model satisfies

    The model's bit literals are minimized against Not(formula): a set of
    literals that is unsat together with Not(formula) implies formula. Bits
    whose literal was dropped print as "*".
    """
    s.push()
    try:
        s.add(Not(formula))
        values = [mk_literal(model, bit) for (_, bits) in vars for bit in bits]
        new_values = minimize_assignment(values)
    finally:
        # Always restore the shared solver, even if minimization raises.
        s.pop()
    for (v, bits) in vars:
        print_bits(v, bits, new_values)

def solve_relaxed(formula):
    """Check `formula` with solver `s` and print a generalized solution.

    Prints the solver state and the check result. When sat, the model is
    passed to generalize_model(), which prints each header field as a bit
    pattern with "*" for bits the result does not depend on.
    """
    fields = [(srcIp, srcIp_b), (srcPort, srcPort_b), (dstIp, dstIp_b),
              (dstPort, dstPort_b), (proto, proto_b)]
    s.push()
    try:
        s.add(formula)
        print(s)
        r = s.check()
        # Grab the model before popping the scope that produced it.
        m = s.model() if r == sat else None
    finally:
        # Always restore the shared solver, even if check() raises.
        s.pop()
    print(r)
    if r == sat:
        generalize_model(m, fields, formula)
    

# Find a packet on which the two policies disagree (first-applicable
# semantics) and print it as per-field bit patterns, with "*" for bits the
# disagreement does not depend on.
solve_relaxed(fa(policy1) != fa(policy2))

#
# Exercise:
# Extract all policy differences (efficiently)
# 
