# %% Imports
from random import randint, seed

import numpy as np
from numba import cuda
from numba import vectorize

# %% 256-bit constants and host-side conversion helpers
# Naming used throughout: "u8" is one 8-byte (64-bit) limb; "u32" is a 32-byte
# (256-bit) unsigned integer stored as four u8 limbs, least significant first.
U8_MAX = 0xFFFFFFFFFFFFFFFF
U32_MAX = (1 << 256) - 1


def bignum_to_u32(num):
    """Split a Python int into four little-endian uint64 limbs.

    Only the low 256 bits of ``num`` are kept; higher bits are masked off.

    Returns:
        A new ``np.uint64`` array of length 4 with ``res[0]`` holding bits 0..63.
    """
    res = np.empty(4, np.uint64)
    for limb in range(4):
        res[limb] = (num >> (64 * limb)) & U8_MAX
    return res


def u32_to_bignum(arr):
    """Reassemble four little-endian 64-bit limbs into one Python int."""
    value = 0
    for limb in range(3, -1, -1):
        value = (value << 64) | int(arr[limb])
    return value


Gx = bignum_to_u32(55066263022277343669578718895168534326250603453777594175500187360389116729240)
Gy = bignum_to_u32(32670510020758816978083085130507043184471273380659243275938904335757337482424)
p = bignum_to_u32(2**256 - 2**32 - 2**9 - 2**8 - 2**7 - 2**6 - 2**4 - 1)


# %% 256-bit arithmetic (CUDA device functions)
@cuda.jit('void(u8[:], u8[:], u8[:])', device=True)
def add_u32(a, b, out):
    """out = (a + b) mod 2**256.

    Every limb of ``a`` and ``b`` is read before the matching limb of ``out``
    is written, so ``out`` may alias ``a`` and/or ``b``.
    """
    carry = np.uint64(0)
    for i in range(4):
        ai = a[i]
        bi = b[i]
        partial = ai + bi            # wraps modulo 2**64
        total = partial + carry
        out[i] = total
        # A wrap happened iff a result is smaller than what it was built from.
        # The two conditions are mutually exclusive, so the carry stays 0 or 1.
        if partial < ai or total < partial:
            carry = np.uint64(1)
        else:
            carry = np.uint64(0)


@cuda.jit('void(u8[:], u8[:], u8[:])', device=True)
def sub_u32(a, b, out):
    """out = (a - b) mod 2**256.

    Every limb of ``a`` and ``b`` is read before the matching limb of ``out``
    is written, so ``out`` may alias ``a`` and/or ``b``.
    """
    borrow = np.uint64(0)
    for i in range(4):
        ai = a[i]
        bi = b[i]
        partial = ai - bi            # wraps modulo 2**64
        out[i] = partial - borrow
        # Borrow from the next limb if either subtraction went below zero.
        if ai < bi or partial < borrow:
            borrow = np.uint64(1)
        else:
            borrow = np.uint64(0)


@cuda.jit('void(u8[:], u8)', device=True)
def shr_u32(out, bits):
    """In-place logical right shift of a 256-bit number by ``bits``.

    Any non-negative shift amount is accepted: shifts of 256 or more clear the
    number, shifts of 64 or more move whole limbs, and a remainder of 0 is
    handled without ever shifting a limb by 64.
    """
    bits = np.uint64(bits)
    if bits >= np.uint64(256):
        for i in range(4):
            out[i] = 0
        return

    # Limb offset is kept signed so it can be added to the loop index without
    # mixing signed and unsigned 64-bit integers.
    limb_shift = np.int64(bits // np.uint64(64))
    bit_shift = bits % np.uint64(64)

    # Walk upwards: limb i only reads limbs >= i, none of which have been
    # overwritten yet, so the shift is safe in place.
    for i in range(4):
        src = i + limb_shift
        word = np.uint64(0)
        if src < 4:
            word = out[src] >> bit_shift
            if bit_shift != 0 and src + 1 < 4:
                # Bits that fall out of the next-higher limb land on top.
                word |= out[src + 1] << (np.uint64(64) - bit_shift)
        out[i] = word


@cuda.jit('void(u8[:], u8[:], u8[:])', device=True)
def mul_u32(a, b, out):
    """out = (a * b) mod 2**256 using MSB-first double-and-add.

    ``out`` is cleared first, so it must NOT alias ``a`` or ``b``.
    """
    for i in range(4):
        out[i] = 0

    for i in range(255, -1, -1):
        # Double the accumulator before consuming the next (lower) bit of b.
        add_u32(out, out, out)
        bit = (b[i // 64] >> np.uint64(i % 64)) & np.uint64(1)
        if bit == np.uint64(1):
            add_u32(a, out, out)


# %% Self-test against Python's arbitrary-precision integers
@cuda.jit('void(u8[:], u8[:], u8[:])')
def mul_u32_ker(a, b, out):
    """Single-thread kernel wrapper so mul_u32 can be launched from the host."""
    mul_u32(a, b, out)


seed(0)  # reproducible operands on every Restart & Run All
modulus = U32_MAX + 1

# uint64 wrap-around is intentional in the limb arithmetic; silence NumPy's
# scalar overflow warnings while the pure-Python versions run.
with np.errstate(over='ignore'):
    for _ in range(10):
        # Full-width operands so carries/borrows cross limb boundaries.
        a = randint(0, U32_MAX)
        b = randint(0, U32_MAX)
        bit = randint(0, 256)
        a32 = bignum_to_u32(a)
        b32 = bignum_to_u32(b)
        out32 = np.empty(4, np.uint64)

        add_u32.py_func(a32, b32, out32)
        out = (a + b) % modulus
        assert out == u32_to_bignum(out32), "failed on add"

        sub_u32.py_func(a32, b32, out32)
        out = (a - b) % modulus
        assert out == u32_to_bignum(out32), "failed on sub"

        shr_u32.py_func(out32, bit)
        out >>= bit
        assert out == u32_to_bignum(out32), (
            "failed on shr", bit, bin(out), bin(u32_to_bignum(out32)))

        # mul_u32 calls another device function, so it can only run on a GPU.
        if cuda.is_available():
            out32 = np.empty(4, np.uint64)
            mul_u32_ker[1, 1](a32, b32, out32)
            out = (a * b) % modulus
            assert out == u32_to_bignum(out32), (
                "failed on mul", a, b, u32_to_bignum(out32))

print("all u32 arithmetic checks passed")
# %% 256-bit helpers built on 64-bit limbs (CUDA device functions)
# All arrays are little-endian: index 0 holds the least significant 64 bits.
# Limb values are expected to be uint64; every constant that meets a limb is
# wrapped in np.uint64 so no operation mixes signed and unsigned 64-bit types.

@cuda.jit(device=True)
def cmp_ge256(a, b):
    """Return True when a >= b, comparing 4-limb numbers from the top limb down."""
    for i in range(3, -1, -1):
        ai = a[i]
        bi = b[i]
        if ai != bi:
            return ai > bi
    return True  # all limbs equal


@cuda.jit(device=True)
def add64_carry(x, y, carry_in):
    """Add two uint64 values plus a 0/1 carry.

    Returns:
        (sum mod 2**64, carry_out) where carry_out is uint64 0 or 1.
    """
    partial = x + y
    total = partial + carry_in
    # Each addition wrapped iff its result is smaller than its left operand;
    # at most one of the two can wrap, so the carry never exceeds 1.
    if partial < x or total < partial:
        return total, np.uint64(1)
    return total, np.uint64(0)


@cuda.jit(device=True)
def sub256(a, b, out):
    """out = (a - b) mod 2**256.

    Intended for a >= b; when a < b the result simply wraps. Limbs of ``a``
    and ``b`` are read before ``out[i]`` is written, so ``out`` may alias
    either input.
    """
    borrow = np.uint64(0)
    for i in range(4):
        ai = a[i]
        bi = b[i]
        partial = ai - bi            # wraps modulo 2**64
        out[i] = partial - borrow
        if ai < bi or partial < borrow:
            borrow = np.uint64(1)
        else:
            borrow = np.uint64(0)


@cuda.jit(device=True)
def mul64wide(x, y):
    """Full 64x64 -> 128-bit product of two uint64 values.

    Uses the 32-bit split x = x_hi*2**32 + x_lo (same for y). The middle
    column is accumulated from 32-bit pieces only, so no intermediate sum can
    exceed 64 bits.

    Returns:
        (low, high): the low and high 64-bit halves of x * y.
    """
    mask32 = np.uint64(0xFFFFFFFF)
    half = np.uint64(32)

    x_lo = x & mask32
    x_hi = x >> half
    y_lo = y & mask32
    y_hi = y >> half

    p0 = x_lo * y_lo
    p1 = x_lo * y_hi
    p2 = x_hi * y_lo
    p3 = x_hi * y_hi

    # Column at bit 32: three values below 2**32 each, so mid < 3 * 2**32.
    mid = (p0 >> half) + (p1 & mask32) + (p2 & mask32)
    low = (p0 & mask32) | ((mid & mask32) << half)
    # The true high word is below 2**64, so this sum cannot wrap.
    high = p3 + (p1 >> half) + (p2 >> half) + (mid >> half)
    return low, high


@cuda.jit(device=True)
def mul256_full(a, b, out8):
    """Schoolbook 256x256 -> 512-bit product.

    ``a`` and ``b`` are 4-limb inputs; ``out8`` must have length 8 and
    receives the little-endian product. The product is accumulated in a
    thread-local scratch array and copied out at the end.
    """
    temp = cuda.local.array(8, dtype=np.uint64)
    for i in range(8):
        temp[i] = 0

    for i in range(4):
        ai = a[i]
        for j in range(4):
            lo, hi = mul64wide(ai, b[j])
            k = i + j
            s, c = add64_carry(temp[k], lo, np.uint64(0))
            temp[k] = s
            # hi <= 2**64 - 2 for any 64x64 product, so hi + c cannot wrap.
            carry = hi + c
            t_idx = k + 1
            # Ripple the carry upwards. The running total never exceeds the
            # 512-bit final product, so the bound is purely defensive.
            while carry != 0 and t_idx < 8:
                s, c = add64_carry(temp[t_idx], carry, np.uint64(0))
                temp[t_idx] = s
                carry = c
                t_idx += 1

    for i in range(8):
        out8[i] = temp[i]


@cuda.jit(device=True)
def mul256_lo(a, b, out4):
    """out4 = (a * b) mod 2**256: the low four limbs of the full product."""
    full = cuda.local.array(8, dtype=np.uint64)
    mul256_full(a, b, full)
    for i in range(4):
        out4[i] = full[i]


# --------------------
# division: binary long division
# --------------------
@cuda.jit(device=True)
def shl1_256_inplace(x):
    """x <<= 1 in place; the bit shifted out of the top limb is discarded."""
    carry = np.uint64(0)
    for i in range(4):
        word = x[i]
        x[i] = (word << np.uint64(1)) | carry
        carry = word >> np.uint64(63)


@cuda.jit(device=True)
def shr1_256_getbit(x):
    """x >>= 1 in place; returns the bit (0 or 1) that was shifted out."""
    lsb = x[0] & np.uint64(1)
    carry = np.uint64(0)
    for i in range(3, -1, -1):
        word = x[i]
        x[i] = (word >> np.uint64(1)) | (carry << np.uint64(63))
        carry = word & np.uint64(1)
    return lsb


@cuda.jit(device=True)
def get_bit256(x, idx):
    """Return bit ``idx`` of x as 0 or 1 (idx in 0..255, 0 = least significant)."""
    return (x[idx // 64] >> np.uint64(idx % 64)) & np.uint64(1)


@cuda.jit(device=True)
def set_bit256(x, idx):
    """Set bit ``idx`` of x in place (idx in 0..255, 0 = least significant)."""
    x[idx // 64] |= np.uint64(1) << np.uint64(idx % 64)


@cuda.jit(device=True)
def copy256(src, dst):
    """Copy the four limbs of ``src`` into ``dst``."""
    for i in range(4):
        dst[i] = src[i]


@cuda.jit(device=True)
def zero256(x):
    """Clear the four limbs of ``x``."""
    for i in range(4):
        x[i] = 0


@cuda.jit(device=True)
def divmod256(dividend, divisor, q, r):
    """Restoring binary long division: q = dividend // divisor, r = dividend % divisor.

    ``q`` and ``r`` are 4-limb output arrays and are overwritten; they must
    not alias ``dividend`` or ``divisor``. Division by zero is not an error:
    it yields q = 0 and r = dividend (callers should avoid it).
    """
    divisor_is_zero = True
    for i in range(4):
        if divisor[i] != 0:
            divisor_is_zero = False
            break
    if divisor_is_zero:
        for i in range(4):
            q[i] = 0
            r[i] = dividend[i]
        return

    zero256(q)
    zero256(r)
    # Process dividend bits from the most significant (255) down to 0.
    for i in range(255, -1, -1):
        # Remember the bit about to leave r: when the divisor has bit 255 set,
        # 2*r + bit can need 257 bits and the top one would otherwise be lost.
        overflow = r[3] >> np.uint64(63)
        shl1_256_inplace(r)
        if get_bit256(dividend, i) != 0:
            r[0] |= np.uint64(1)
        # With the overflow bit set the true remainder is 2**256 + r, which is
        # certainly >= divisor; the wrapping subtraction still gives the exact
        # result because (2**256 + r) - divisor < divisor < 2**256.
        if overflow != 0 or cmp_ge256(r, divisor):
            sub256(r, divisor, r)
            set_bit256(q, i)