#!/usr/bin/python
# see also: https://eprint.iacr.org/2006/011.pdf
# see also: https://github.com/dlitz/pycrypto/blob/master/lib/Crypto/Signature/PKCS1_PSS.py

import hashlib, secrets, binascii
from sympy import mod_inverse
from hexdump import hexdump
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Crypto.Util.strxor import strxor
from Crypto.Util.py3compat import bchr

def octet_len(i):
    """Return the number of octets needed to hold the integer ``i``.

    Uses the bit length rather than the length of ``hex(i)``. The hex-string
    approach truncates whenever the value has an odd number of hex digits
    (e.g. ``0x1`` -> 0, ``0x100`` -> 1).

    Args:
        i: integer to measure (intended for non-negative values).

    Returns:
        ``ceil(bit_length(i) / 8)``; 0 for ``i == 0``.
    """
    return (i.bit_length() + 7) // 8

def MGF1(mgfSeed, maskLen, hash_name):
    """Mask generation function MGF1 (RFC 8017, appendix B.2.1).

    Args:
        mgfSeed: seed octets (bytes or bytearray) the mask is derived from.
        maskLen: desired mask length in octets.
        hash_name: any name accepted by ``hashlib.new`` (e.g. ``"sha256"``).

    Returns:
        The first ``maskLen`` bytes of
        ``Hash(mgfSeed || C(0)) || Hash(mgfSeed || C(1)) || ...``,
        where ``C(i)`` is the 4-octet big-endian encoding of ``i``.

    Raises:
        ValueError: if ``maskLen`` is negative, or larger than
            ``2**32 * hLen`` ("mask too long" in the RFC). Beyond that limit
            the counter no longer fits in 4 octets.
    """
    hLen = hashlib.new(hash_name).digest_size
    if maskLen < 0:
        raise ValueError("maskLen must be non-negative")
    if maskLen > (1 << 32) * hLen:
        raise ValueError("mask too long")

    # ceil(maskLen / hLen) hash blocks, computed with exact integer arithmetic.
    blocks = -(-maskLen // hLen)
    T = b"".join(
        hashlib.new(hash_name, mgfSeed + counter.to_bytes(4, "big")).digest()
        for counter in range(blocks)
    )
    return T[:maskLen]

def verify_rsa_pss(mhash, signature, n, e, *, sLen=32, hash_name="sha256",
                   verbose=False):
    """Verify an RSASSA-PSS signature over a precomputed message hash.

    Performs RSAVP1 followed by EMSA-PSS-VERIFY (RFC 8017 sections 8.1.2
    and 9.1.2). MGF1 uses the same hash function as the message hash.

    Args:
        mhash: the message hash, either as bytes or as a big-endian int. An
            int is left-padded to the digest size, so leading zero octets
            of the hash are preserved.
        signature: the signature as an integer.
        n: RSA modulus.
        e: RSA public exponent.
        sLen: expected salt length in octets (default 32).
        hash_name: hash used for H' and MGF1; any name accepted by
            ``hashlib.new`` (default ``"sha256"``).
        verbose: if true, print the intermediate encoding values.

    Returns:
        1 if the signature is valid ("consistent"), 0 otherwise.

    Raises:
        ValueError: if ``sLen`` is negative.
    """
    def log(label, value):
        if verbose:
            if isinstance(value, (bytes, bytearray)):
                value = value.hex()
            print(label, value)

    if sLen < 0:
        raise ValueError("sLen must be non-negative")

    # RSAVP1: the signature representative must lie in [0, n - 1].
    if not 0 <= signature < n:
        return 0

    # The encoded-message size is fixed by the modulus, not by the recovered
    # value: emBits = modBits - 1, emLen = ceil(emBits / 8). Measuring the
    # recovered integer instead breaks whenever EM starts with zero octets.
    emBits = n.bit_length() - 1
    emLen = (emBits + 7) // 8
    try:
        em = pow(signature, e, n).to_bytes(emLen, "big")  # I2OSP
    except OverflowError:
        return 0
    log("em:", em)

    # hLen is the digest size of the hash, not the numeric size of mhash.
    hLen = hashlib.new(hash_name).digest_size
    if isinstance(mhash, (bytes, bytearray)):
        mHash = bytes(mhash)
    else:
        try:
            mHash = mhash.to_bytes(hLen, "big")
        except OverflowError:  # negative, or wider than the digest
            return 0
    if len(mHash) != hLen:
        return 0
    log("mHash:", mHash)

    if emLen < hLen + sLen + 2:
        return 0
    if em[-1] != 0xBC:
        return 0

    maskedDBLen = emLen - hLen - 1
    maskedDB = em[:maskedDBLen]
    H = em[maskedDBLen:-1]
    log("maskedDB:", maskedDB)
    log("H:", H)

    # High-order bits of EM (0..7 of them) that the signer forced to zero.
    zeroBits = 8 * emLen - emBits
    if maskedDB[0] >> (8 - zeroBits):
        return 0

    dbMask = MGF1(H, maskedDBLen, hash_name)
    log("dbMask:", dbMask)
    DB = bytearray(a ^ b for a, b in zip(maskedDB, dbMask))
    # Clear the same high-order bits in DB. Otherwise the padding check
    # below rejects valid signatures whenever the mask happens to set them.
    DB[0] &= 0xFF >> zeroBits
    log("DB:", DB)

    # DB = PS (all zero octets) || 0x01 || salt
    psLen = maskedDBLen - sLen - 1
    if any(DB[:psLen]) or DB[psLen] != 0x01:
        return 0
    salt = bytes(DB[psLen + 1:])
    log("salt:", salt)

    # M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
    mPrime = b"\x00" * 8 + mHash + salt
    log("m':", mPrime)
    hPrime = hashlib.new(hash_name, mPrime).digest()
    log("H':", hPrime)

    return 1 if H == hPrime else 0

# Demo inputs for the verify_rsa_pss call at the end of this file.

# RSA public key: modulus n (a 512-hex-digit literal, i.e. 2048 bits) and
# public exponent e = 0x10001 (65537).
n = 0xa521c8aad19020ca79bd4bf95e92f3b1d54cc72f515b95611f8765568ec2dd98488d2e943b1c57215cc2d40623408587c7d7b1c6ee17b287b2bc08ac1ee33e1af133a59e3734fae9bad94335ca0863915ff94f14153871fefe2196c38963dfcca3ae69d476d33e0378b9049a559314e14ad5991e8b8fdca3ebed13a663757e34b5776dd15c615fc7e948ae27544f224c26edbb36495441decc968699afa74973341675bcc4b35b8e209731a3f9ac6372bba4b26e67cdf15c8f023f19047bf9c343d909e8728fbdb5e593d84ed6447d9377e8dc30a6c7d1dd35244f94a5cda7338f9f9efd7503068353a0a47d6e3b97258e43d6764be589596363849657a7ee99
e =  0x10001

# Message hash the signature is checked against: 64 hex digits (32 octets),
# the size of a SHA-256 digest, which is the hash verify_rsa_pss uses.
# NOTE(review): originally annotated "root_hash = mhash", so this is
# presumably a root hash computed elsewhere — TODO confirm its origin.
mhash = 0x1CE9C8528F5EB7AB4C1FF676AF9CD336547DDB64164FB8C1CEE9D154A39B36DC

# Signature to verify, as an integer; verify_rsa_pss raises it to the power
# e modulo n to recover the encoded message.
signature = 0x2F488791C3424F7EAC474BB6C7BCCFF582413F2783DC64337F4F4070D398F2ACD82FC97656C084C90A25839DC43EFE9D2F4D944CA4E86FE47ACC628C3873D981175DD6199F1FC689EF8A4C05FFD4119ED09736B70EC421A21CA88D6423CD1F9982344F524DE44179EE6D8F5F6BA9055E7A81E0CE2681002279A6331A5120766CC19F502AC1C14DF94F5D19DE0E7F4A70CB0C3B0FD26A93695B40E65D6FC304B78274A1CC4CBE6400225B2DCA7BBDFE3119497E86516B9BCD46724A5F53FA69FFCE2FC5023336E6B9D7AB5DEC9632C22E7086225FFC1AA7F4F5AA66804CA1A8B94D5CBE843C6549D7B32C9EC76220999E6BFC08D8935B1AA6B10D5E9BCD6F4708

if __name__ == "__main__":
    # Demo: verify the embedded signature against the embedded key and hash.
    # Guarded so that importing this module (e.g. to reuse verify_rsa_pss or
    # MGF1) does not run the verification or print anything.
    res = verify_rsa_pss(mhash, signature, n, e)
    print("verify result:", res)
