From 9c8d3a7b5f73e3155e3479fccdb4bf4b3bc3cd40 Mon Sep 17 00:00:00 2001
From: SEK1RO
Date: Mon, 15 Dec 2025 18:53:54 +0300
Subject: [PATCH] ds: 1e: karatsuba prototype

---
Review notes (not part of the commit message):
- op.ptx: sub_u16 was a verbatim copy of add_u16 (add.cc/addc); it now
  uses sub.cc/subc so the borrow propagates the same way as in sub_u32.
- main.cu: mul_u16 dropped the carry of in_a.y * in_b.y into the high
  limb; __umul64hi() now supplies it ({MAX,MAX} * {0,MAX} gave {1,1}
  instead of the expected {MAX,1}).
- main.cu: mul_u32 still omits the high 128 bits of ay * by, so the
  mul_32 check in test() is expected to keep failing until a widening
  128x128 -> 256 multiply exists.
- test.py: the mul16 reference now multiplies by U8_MAX, matching the
  operands used by test().
- Host code in main() does not check any CUDA return codes; left as is
  to keep the hunk sizes unchanged.
- This file had lost its line breaks, indentation and all <...> text
  (header names, author e-mail). The layout below is reconstructed from
  the hunk counts; <stdio.h>/<stdint.h> are inferred from the use of
  printf/uint64_t. Confirm both, plus context-line whitespace, before
  applying. The blob ids on the "index" lines predate the fixes above.

 ds/25-1/1e/main.cu | 153 +++++++++++++++++++++++++++++++++++++++------
 ds/25-1/1e/op.ptx  |  81 ++++++++++++++++++++++--
 ds/25-1/1e/test.py |  17 +++++
 3 files changed, 227 insertions(+), 24 deletions(-)
 create mode 100644 ds/25-1/1e/test.py

diff --git a/ds/25-1/1e/main.cu b/ds/25-1/1e/main.cu
index c39ed0d..378901a 100644
--- a/ds/25-1/1e/main.cu
+++ b/ds/25-1/1e/main.cu
@@ -1,4 +1,17 @@
 #include <stdio.h>
+#include <stdint.h>
+
+extern "C" __device__ void add_u16(
+    ulonglong2 *out_c,
+    ulonglong2 in_a,
+    ulonglong2 in_b
+);
+
+extern "C" __device__ void sub_u16(
+    ulonglong2 *out_c,
+    ulonglong2 in_a,
+    ulonglong2 in_b
+);
 
 extern "C" __device__ void add_u32(
     ulonglong4 *out_c,
@@ -6,30 +19,134 @@ extern "C" __device__ void add_u32(
     ulonglong4 in_b
 );
 
-__constant__ char ok[] = "ok";
-__constant__ char not_ok[] = "not ok";
+extern "C" __device__ void sub_u32(
+    ulonglong4 *out_c,
+    ulonglong4 in_a,
+    ulonglong4 in_b
+);
 
-__global__ void kernel(char *buf) {
-    ulonglong4 a = {0, 1, 2, 3};
-    ulonglong4 b = {1, 1, 1, 1};
-    ulonglong4 c = {1, 2, 3, 4};
+__device__ void mul_u16(
+    ulonglong2 *out_c,
+    ulonglong2 in_a,
+    ulonglong2 in_b
+) { // *out_c = in_a * in_b mod 2^128; .x is the high limb, .y the low limb
+    uint64_t ax_ay = in_a.x + in_a.y;
+    uint64_t bx_by = in_b.x + in_b.y;
+    uint64_t axbx = in_a.x * in_b.x;
+    uint64_t ayby = in_a.y * in_b.y;
+    out_c->x = ax_ay * bx_by - axbx - ayby + __umul64hi(in_a.y, in_b.y); // cross terms + carry out of the low limb
+    out_c->y = ayby;
+}
 
-    add_u32(&c, a, b);
+__device__ void mul_u32(
+    ulonglong4 *out_c,
+    ulonglong4 in_a,
+    ulonglong4 in_b
+) { // NOTE(review): the high 128 bits of ay * by are not yet added to the upper half
+    auto ax = (ulonglong2 *)&in_a.x; // upper 128 bits of in_a
+    auto ay = (ulonglong2 *)&in_a.z; // lower 128 bits of in_a
+    auto bx = (ulonglong2 *)&in_b.x;
+    auto by = (ulonglong2 *)&in_b.z;
+    ulonglong2 ax_ay, bx_by, paren, axbx, ayby;
+    add_u16(&ax_ay, *ax, *ay);
+    add_u16(&bx_by, *bx, *by);
+    mul_u16(&paren, ax_ay, bx_by);
+    mul_u16(&axbx, *ax, *bx);
+    mul_u16(&ayby, *ay, *by);
+    sub_u16(&paren, paren, axbx);
+    sub_u16(&paren, paren, ayby);
+    out_c->x = paren.x;
+    out_c->y = paren.y;
+    out_c->z = ayby.x;
+    out_c->w = ayby.y;
+}
 
-    memcpy(buf, ok, sizeof(ok));
+__device__ bool equ_u16(ulonglong2 a, ulonglong2 b) {
+    return a.x == b.x && a.y == b.y;
+}
+
+__device__ bool equ_u32(ulonglong4 a, ulonglong4 b) {
+    return a.x == b.x &&
+           a.y == b.y &&
+           a.z == b.z &&
+           a.w == b.w;
+}
+
+__device__ void print_u16(ulonglong2 a) {
+    printf("0x%016llx.%016llx\n", a.x, a.y);
+}
+
+__device__ void print_u32(ulonglong4 a) {
+    printf("0x%016llx.%016llx.%016llx.%016llx\n", a.x, a.y, a.z, a.w);
+}
+
+#define U8_MAX 0xFFFFFFFFFFFFFFFF
+#define _U16_MAX {U8_MAX, U8_MAX}
+#define _U32_MAX {U8_MAX, U8_MAX, U8_MAX, U8_MAX}
+
+__global__ void test(bool *passed) { // single-thread self-test; *passed is set only if every case matches
+    *passed = false;
+    {
+        ulonglong4 a = _U32_MAX;
+        ulonglong4 b = {0, 0, 0, 1};
+        ulonglong4 c = {0, 0, 0, 0};
+        add_u32(&a, a, b);
+        if (!equ_u32(a, c)) {
+            printf("add_u32 ");
+            print_u32(a);
+            return;
+        }
+    }
+    {
+        ulonglong4 a = {0, 0, 0, 0};
+        ulonglong4 b = {0, 0, 0, 1};
+        ulonglong4 c = _U32_MAX;
+        sub_u32(&a, a, b);
+        if (!equ_u32(a, c)) {
+            printf("sub_32 ");
+            print_u32(a);
+            return;
+        }
+    }
+    {
+        ulonglong2 a = _U16_MAX;
+        ulonglong2 b = {0, U8_MAX};
+        ulonglong2 c = {U8_MAX, 1};
+        mul_u16(&a, a, b);
+        if (!equ_u16(a, c)) {
+            printf("mul_16 ");
+            print_u16(a);
+            return;
+        }
+    }
+    {
+        ulonglong4 a = _U32_MAX;
+        ulonglong4 b = {0, 0, U8_MAX, U8_MAX};
+        ulonglong4 c = {U8_MAX, U8_MAX, 0, 1};
+        mul_u32(&a, a, b);
+        if (!equ_u32(a, c)) {
+            printf("mul_32 ");
+            print_u32(a);
+            return;
+        }
+    }
+    *passed = true;
 }
 
 int main() {
-    char h_buf[32];
-    char *d_buf;
-    cudaMalloc(&d_buf, 32);
-
-    kernel<<<1, 1>>>(d_buf);
-
+    bool test_passed, *d_test_passed;
+    cudaMalloc(&d_test_passed, sizeof(bool));
+
+    test<<<1, 1>>>(d_test_passed);
     cudaDeviceSynchronize();
-    cudaMemcpy(h_buf, d_buf, 32, cudaMemcpyDeviceToHost);
-
-    printf("%s\n", h_buf);
-    cudaFree(d_buf);
+
+    cudaMemcpy(&test_passed, d_test_passed, sizeof(bool), cudaMemcpyDeviceToHost);
+    cudaFree(d_test_passed);
+
+    if (!test_passed) {
+        printf("test not passed\n");
+        return 1;
+    }
+
     return 0;
 }
\ No newline at end of file
diff --git a/ds/25-1/1e/op.ptx b/ds/25-1/1e/op.ptx
index e6077d5..9b608f3 100644
--- a/ds/25-1/1e/op.ptx
+++ b/ds/25-1/1e/op.ptx
@@ -2,6 +2,49 @@
 .target sm_75
 .address_size 64
 
+.visible .func add_u16(
+    .param .b64 out_c,
+    .param .align 16 .b8 in_a[16],
+    .param .align 16 .b8 in_b[16]
+) {
+    .reg .u64 %ra<2>, %rb<2>;
+    .reg .b64 %rdc;
+
+    ld.param.b64 %rdc, [out_c];
+
+    ld.param.v2.u64 {%ra1, %ra0}, [in_a]; // first element -> %ra1 (high), second -> %ra0 (low)
+    ld.param.v2.u64 {%rb1, %rb0}, [in_b];
+
+    add.cc.u64 %ra0, %ra0, %rb0; // low limb, writes the carry flag
+    addc.u64 %ra1, %ra1, %rb1; // high limb plus carry-in
+
+    st.v2.u64 [%rdc], {%ra1, %ra0};
+
+    ret;
+}
+
+.visible .func sub_u16(
+    .param .b64 out_c,
+    .param .align 16 .b8 in_a[16],
+    .param .align 16 .b8 in_b[16]
+) {
+    .reg .u64 %ra<2>, %rb<2>;
+    .reg .b64 %rdc;
+
+    ld.param.b64 %rdc, [out_c];
+
+    ld.param.v2.u64 {%ra1, %ra0}, [in_a]; // first element -> %ra1 (high), second -> %ra0 (low)
+    ld.param.v2.u64 {%rb1, %rb0}, [in_b];
+
+    sub.cc.u64 %ra0, %ra0, %rb0; // low limb, writes the borrow flag
+    subc.u64 %ra1, %ra1, %rb1; // high limb minus borrow-in
+
+    st.v2.u64 [%rdc], {%ra1, %ra0};
+
+    ret;
+}
+
+
 .visible .func add_u32(
     .param .b64 out_c,
     .param .align 16 .b8 in_a[32],
@@ -12,18 +55,44 @@
 
     ld.param.b64 %rdc, [out_c];
 
-    ld.param.v2.u64 {%ra0, %ra1}, [in_a];
-    ld.param.v2.u64 {%ra2, %ra3}, [in_a + 16];
-    ld.param.v2.u64 {%rb0, %rb1}, [in_b];
-    ld.param.v2.u64 {%rb2, %rb3}, [in_b + 16];
+    ld.param.v2.u64 {%ra3, %ra2}, [in_a];
+    ld.param.v2.u64 {%ra1, %ra0}, [in_a + 16];
+    ld.param.v2.u64 {%rb3, %rb2}, [in_b];
+    ld.param.v2.u64 {%rb1, %rb0}, [in_b + 16];
 
     add.cc.u64 %ra0, %ra0, %rb0;
     addc.cc.u64 %ra1, %ra1, %rb1;
     addc.cc.u64 %ra2, %ra2, %rb2;
     addc.u64 %ra3, %ra3, %rb3;
 
-    st.v2.u64 [%rdc], {%ra0, %ra1};
-    st.v2.u64 [%rdc + 16], {%ra2, %ra3};
+    st.v2.u64 [%rdc], {%ra3, %ra2};
+    st.v2.u64 [%rdc + 16], {%ra1, %ra0};
+
+    ret;
+}
+
+.visible .func sub_u32(
+    .param .b64 out_c,
+    .param .align 16 .b8 in_a[32],
+    .param .align 16 .b8 in_b[32]
+) {
+    .reg .u64 %ra<4>, %rb<4>;
+    .reg .b64 %rdc;
+
+    ld.param.b64 %rdc, [out_c];
+
+    ld.param.v2.u64 {%ra3, %ra2}, [in_a];
+    ld.param.v2.u64 {%ra1, %ra0}, [in_a + 16];
+    ld.param.v2.u64 {%rb3, %rb2}, [in_b];
+    ld.param.v2.u64 {%rb1, %rb0}, [in_b + 16];
+
+    sub.cc.u64 %ra0, %ra0, %rb0;
+    subc.cc.u64 %ra1, %ra1, %rb1;
+    subc.cc.u64 %ra2, %ra2, %rb2;
+    subc.u64 %ra3, %ra3, %rb3;
+
+    st.v2.u64 [%rdc], {%ra3, %ra2};
+    st.v2.u64 [%rdc + 16], {%ra1, %ra0};
 
     ret;
 }
\ No newline at end of file
diff --git a/ds/25-1/1e/test.py b/ds/25-1/1e/test.py
new file mode 100644
index 0000000..e7f1e14
--- /dev/null
+++ b/ds/25-1/1e/test.py
@@ -0,0 +1,17 @@
+U8_MAX = 0xFFFFFFFFFFFFFFFF
+U16_MAX = U8_MAX << 64 | U8_MAX
+U32_MAX = U16_MAX << 128 | U16_MAX
+
+def dothex(num):
+    strhex = hex(num)[2:]
+    dothex = strhex[-16:]
+    strhex = strhex[:-16]
+
+    while len(strhex) > 0:
+        dothex = strhex[-16:] + '.' + dothex
+        strhex = strhex[:-16]
+
+    return '0x' + dothex
+
+print('mul16', dothex(U16_MAX * U8_MAX % (U16_MAX + 1)))
+print('mul32', dothex(U32_MAX * U16_MAX % (U32_MAX + 1)))
\ No newline at end of file