ds: 1e: karatsuba prototype
This commit is contained in:
@ -1,4 +1,17 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
extern "C" __device__ void add_u16(
|
||||||
|
ulonglong2 *out_c,
|
||||||
|
ulonglong2 in_a,
|
||||||
|
ulonglong2 in_b
|
||||||
|
);
|
||||||
|
|
||||||
|
extern "C" __device__ void sub_u16(
|
||||||
|
ulonglong2 *out_c,
|
||||||
|
ulonglong2 in_a,
|
||||||
|
ulonglong2 in_b
|
||||||
|
);
|
||||||
|
|
||||||
extern "C" __device__ void add_u32(
|
extern "C" __device__ void add_u32(
|
||||||
ulonglong4 *out_c,
|
ulonglong4 *out_c,
|
||||||
@ -6,30 +19,134 @@ extern "C" __device__ void add_u32(
|
|||||||
ulonglong4 in_b
|
ulonglong4 in_b
|
||||||
);
|
);
|
||||||
|
|
||||||
__constant__ char ok[] = "ok";
|
extern "C" __device__ void sub_u32(
|
||||||
__constant__ char not_ok[] = "not ok";
|
ulonglong4 *out_c,
|
||||||
|
ulonglong4 in_a,
|
||||||
|
ulonglong4 in_b
|
||||||
|
);
|
||||||
|
|
||||||
__global__ void kernel(char *buf) {
|
__device__ void mul_u16(
|
||||||
ulonglong4 a = {0, 1, 2, 3};
|
ulonglong2 *out_c,
|
||||||
ulonglong4 b = {1, 1, 1, 1};
|
ulonglong2 in_a,
|
||||||
ulonglong4 c = {1, 2, 3, 4};
|
ulonglong2 in_b
|
||||||
|
) {
|
||||||
|
uint64_t ax_ay = in_a.x + in_a.y;
|
||||||
|
uint64_t bx_by = in_b.x + in_b.y;
|
||||||
|
uint64_t axbx = in_a.x * in_b.x;
|
||||||
|
uint64_t ayby = in_a.y * in_b.y;
|
||||||
|
out_c->x = ax_ay * bx_by - axbx - ayby;
|
||||||
|
out_c->y = ayby;
|
||||||
|
}
|
||||||
|
|
||||||
add_u32(&c, a, b);
|
__device__ void mul_u32(
|
||||||
|
ulonglong4 *out_c,
|
||||||
|
ulonglong4 in_a,
|
||||||
|
ulonglong4 in_b
|
||||||
|
) {
|
||||||
|
auto ax = (ulonglong2 *)&in_a.x;
|
||||||
|
auto ay = (ulonglong2 *)&in_a.z;
|
||||||
|
auto bx = (ulonglong2 *)&in_b.x;
|
||||||
|
auto by = (ulonglong2 *)&in_b.z;
|
||||||
|
ulonglong2 ax_ay, bx_by, paren, axbx, ayby;
|
||||||
|
add_u16(&ax_ay, *ax, *ay);
|
||||||
|
add_u16(&bx_by, *bx, *by);
|
||||||
|
mul_u16(&paren, ax_ay, bx_by);
|
||||||
|
mul_u16(&axbx, *ax, *bx);
|
||||||
|
mul_u16(&ayby, *ay, *by);
|
||||||
|
sub_u16(&paren, paren, axbx);
|
||||||
|
sub_u16(&paren, paren, ayby);
|
||||||
|
out_c->x = paren.x;
|
||||||
|
out_c->y = paren.y;
|
||||||
|
out_c->z = ayby.x;
|
||||||
|
out_c->w = ayby.y;
|
||||||
|
}
|
||||||
|
|
||||||
memcpy(buf, ok, sizeof(ok));
|
__device__ bool equ_u16(ulonglong2 a, ulonglong2 b) {
|
||||||
|
return a.x == b.x && a.y == b.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ bool equ_u32(ulonglong4 a, ulonglong4 b) {
|
||||||
|
return a.x == b.x &&
|
||||||
|
a.y == b.y &&
|
||||||
|
a.z == b.z &&
|
||||||
|
a.w == b.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ void print_u16(ulonglong2 a) {
|
||||||
|
printf("0x%016llx.%016llx\n", a.x, a.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ void print_u32(ulonglong4 a) {
|
||||||
|
printf("0x%016llx.%016llx.%016llx.%016llx\n", a.x, a.y, a.z, a.w);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define U8_MAX 0xFFFFFFFFFFFFFFFF
|
||||||
|
#define _U16_MAX {U8_MAX, U8_MAX}
|
||||||
|
#define _U32_MAX {U8_MAX, U8_MAX, U8_MAX, U8_MAX}
|
||||||
|
|
||||||
|
__global__ void test(bool *passed) {
|
||||||
|
*passed = false;
|
||||||
|
{
|
||||||
|
ulonglong4 a = _U32_MAX;
|
||||||
|
ulonglong4 b = {0, 0, 0, 1};
|
||||||
|
ulonglong4 c = {0, 0, 0, 0};
|
||||||
|
add_u32(&a, a, b);
|
||||||
|
if (!equ_u32(a, c)) {
|
||||||
|
printf("add_u32 ");
|
||||||
|
print_u32(a);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ulonglong4 a = {0, 0, 0, 0};
|
||||||
|
ulonglong4 b = {0, 0, 0, 1};
|
||||||
|
ulonglong4 c = _U32_MAX;
|
||||||
|
sub_u32(&a, a, b);
|
||||||
|
if (!equ_u32(a, c)) {
|
||||||
|
printf("sub_32 ");
|
||||||
|
print_u32(a);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ulonglong2 a = _U16_MAX;
|
||||||
|
ulonglong2 b = {0, U8_MAX};
|
||||||
|
ulonglong2 c = {U8_MAX, 1};
|
||||||
|
mul_u16(&a, a, b);
|
||||||
|
if (!equ_u16(a, c)) {
|
||||||
|
printf("mul_16 ");
|
||||||
|
print_u16(a);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ulonglong4 a = _U32_MAX;
|
||||||
|
ulonglong4 b = {0, 0, U8_MAX, U8_MAX};
|
||||||
|
ulonglong4 c = {U8_MAX, U8_MAX, 0, 1};
|
||||||
|
mul_u32(&a, a, b);
|
||||||
|
if (!equ_u32(a, c)) {
|
||||||
|
printf("mul_32 ");
|
||||||
|
print_u32(a);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*passed = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
char h_buf[32];
|
bool test_passed, *d_test_passed;
|
||||||
char *d_buf;
|
cudaMalloc(&d_test_passed, sizeof(bool));
|
||||||
cudaMalloc(&d_buf, 32);
|
|
||||||
|
test<<<1, 1>>>(d_test_passed);
|
||||||
kernel<<<1, 1>>>(d_buf);
|
|
||||||
|
|
||||||
cudaDeviceSynchronize();
|
cudaDeviceSynchronize();
|
||||||
cudaMemcpy(h_buf, d_buf, 32, cudaMemcpyDeviceToHost);
|
|
||||||
|
cudaMemcpy(&test_passed, d_test_passed, sizeof(bool), cudaMemcpyDeviceToHost);
|
||||||
printf("%s\n", h_buf);
|
cudaFree(d_test_passed);
|
||||||
cudaFree(d_buf);
|
|
||||||
|
if (!test_passed) {
|
||||||
|
printf("test not passed\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -2,6 +2,49 @@
|
|||||||
.target sm_75
|
.target sm_75
|
||||||
.address_size 64
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .func add_u16(
|
||||||
|
.param .b64 out_c,
|
||||||
|
.param .align 16 .b8 in_a[16],
|
||||||
|
.param .align 16 .b8 in_b[16]
|
||||||
|
) {
|
||||||
|
.reg .u64 %ra<2>, %rb<2>;
|
||||||
|
.reg .b64 %rdc;
|
||||||
|
|
||||||
|
ld.param.b64 %rdc, [out_c];
|
||||||
|
|
||||||
|
ld.param.v2.u64 {%ra1, %ra0}, [in_a];
|
||||||
|
ld.param.v2.u64 {%rb1, %rb0}, [in_b];
|
||||||
|
|
||||||
|
add.cc.u64 %ra0, %ra0, %rb0;
|
||||||
|
addc.u64 %ra1, %ra1, %rb1;
|
||||||
|
|
||||||
|
st.v2.u64 [%rdc], {%ra1, %ra0};
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
.visible .func sub_u16(
|
||||||
|
.param .b64 out_c,
|
||||||
|
.param .align 16 .b8 in_a[16],
|
||||||
|
.param .align 16 .b8 in_b[16]
|
||||||
|
) {
|
||||||
|
.reg .u64 %ra<2>, %rb<2>;
|
||||||
|
.reg .b64 %rdc;
|
||||||
|
|
||||||
|
ld.param.b64 %rdc, [out_c];
|
||||||
|
|
||||||
|
ld.param.v2.u64 {%ra1, %ra0}, [in_a];
|
||||||
|
ld.param.v2.u64 {%rb1, %rb0}, [in_b];
|
||||||
|
|
||||||
|
add.cc.u64 %ra0, %ra0, %rb0;
|
||||||
|
addc.u64 %ra1, %ra1, %rb1;
|
||||||
|
|
||||||
|
st.v2.u64 [%rdc], {%ra1, %ra0};
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
.visible .func add_u32(
|
.visible .func add_u32(
|
||||||
.param .b64 out_c,
|
.param .b64 out_c,
|
||||||
.param .align 16 .b8 in_a[32],
|
.param .align 16 .b8 in_a[32],
|
||||||
@ -12,18 +55,44 @@
|
|||||||
|
|
||||||
ld.param.b64 %rdc, [out_c];
|
ld.param.b64 %rdc, [out_c];
|
||||||
|
|
||||||
ld.param.v2.u64 {%ra0, %ra1}, [in_a];
|
ld.param.v2.u64 {%ra3, %ra2}, [in_a];
|
||||||
ld.param.v2.u64 {%ra2, %ra3}, [in_a + 16];
|
ld.param.v2.u64 {%ra1, %ra0}, [in_a + 16];
|
||||||
ld.param.v2.u64 {%rb0, %rb1}, [in_b];
|
ld.param.v2.u64 {%rb3, %rb2}, [in_b];
|
||||||
ld.param.v2.u64 {%rb2, %rb3}, [in_b + 16];
|
ld.param.v2.u64 {%rb1, %rb0}, [in_b + 16];
|
||||||
|
|
||||||
add.cc.u64 %ra0, %ra0, %rb0;
|
add.cc.u64 %ra0, %ra0, %rb0;
|
||||||
addc.cc.u64 %ra1, %ra1, %rb1;
|
addc.cc.u64 %ra1, %ra1, %rb1;
|
||||||
addc.cc.u64 %ra2, %ra2, %rb2;
|
addc.cc.u64 %ra2, %ra2, %rb2;
|
||||||
addc.u64 %ra3, %ra3, %rb3;
|
addc.u64 %ra3, %ra3, %rb3;
|
||||||
|
|
||||||
st.v2.u64 [%rdc], {%ra0, %ra1};
|
st.v2.u64 [%rdc], {%ra3, %ra2};
|
||||||
st.v2.u64 [%rdc + 16], {%ra2, %ra3};
|
st.v2.u64 [%rdc + 16], {%ra1, %ra0};
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
.visible .func sub_u32(
|
||||||
|
.param .b64 out_c,
|
||||||
|
.param .align 16 .b8 in_a[32],
|
||||||
|
.param .align 16 .b8 in_b[32]
|
||||||
|
) {
|
||||||
|
.reg .u64 %ra<4>, %rb<4>;
|
||||||
|
.reg .b64 %rdc;
|
||||||
|
|
||||||
|
ld.param.b64 %rdc, [out_c];
|
||||||
|
|
||||||
|
ld.param.v2.u64 {%ra3, %ra2}, [in_a];
|
||||||
|
ld.param.v2.u64 {%ra1, %ra0}, [in_a + 16];
|
||||||
|
ld.param.v2.u64 {%rb3, %rb2}, [in_b];
|
||||||
|
ld.param.v2.u64 {%rb1, %rb0}, [in_b + 16];
|
||||||
|
|
||||||
|
sub.cc.u64 %ra0, %ra0, %rb0;
|
||||||
|
subc.cc.u64 %ra1, %ra1, %rb1;
|
||||||
|
subc.cc.u64 %ra2, %ra2, %rb2;
|
||||||
|
subc.u64 %ra3, %ra3, %rb3;
|
||||||
|
|
||||||
|
st.v2.u64 [%rdc], {%ra3, %ra2};
|
||||||
|
st.v2.u64 [%rdc + 16], {%ra1, %ra0};
|
||||||
|
|
||||||
ret;
|
ret;
|
||||||
}
|
}
|
||||||
17
ds/25-1/1e/test.py
Normal file
17
ds/25-1/1e/test.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
U8_MAX = 0xFFFFFFFFFFFFFFFF
|
||||||
|
U16_MAX = U8_MAX << 64 | U8_MAX
|
||||||
|
U32_MAX = U16_MAX << 128 | U16_MAX
|
||||||
|
|
||||||
|
def dothex(num):
|
||||||
|
strhex = hex(num)[2:]
|
||||||
|
dothex = strhex[-16:]
|
||||||
|
strhex = strhex[:-16]
|
||||||
|
|
||||||
|
while len(strhex) > 0:
|
||||||
|
dothex = strhex[-16:] + '.' + dothex
|
||||||
|
strhex = strhex[:-16]
|
||||||
|
|
||||||
|
return '0x' + dothex
|
||||||
|
|
||||||
|
print('mul16', dothex(U16_MAX * 2 % (U16_MAX + 1)))
|
||||||
|
print('mul32', dothex(U32_MAX * U16_MAX % (U32_MAX + 1)))
|
||||||
Reference in New Issue
Block a user