Files
lab/ds/25-1/1e/main.cu
2025-12-15 18:53:54 +03:00

152 lines
3.4 KiB
Plaintext

#include <stdio.h>
#include <stdint.h>
// Multi-limb add/subtract primitives. They are only declared here; their
// definitions are not part of this file as shown (presumably a separately
// compiled unit), so linking them needs separate compilation / relocatable
// device code (e.g. nvcc -rdc=true).
//
// Common calling convention (from the signatures): the result is written
// through out_c and both operands are passed by value, so out_c may point at
// one of the caller's own operand variables, e.g. add_u32(&a, a, b).
//
// 128-bit addition. NOTE(review): presumably (in_a + in_b) mod 2^128 with .x
// as the most significant limb, mirroring add_u32; not exercised by `test`,
// confirm against the definition.
extern "C" __device__ void add_u16(
ulonglong2 *out_c,
ulonglong2 in_a,
ulonglong2 in_b
);
// 128-bit subtraction. NOTE(review): presumably (in_a - in_b) mod 2^128,
// mirroring sub_u32; not exercised by `test`, confirm against the definition.
extern "C" __device__ void sub_u16(
ulonglong2 *out_c,
ulonglong2 in_a,
ulonglong2 in_b
);
// 256-bit addition. The checks in `test` expect wrap-around modulo 2^256 with
// .x as the most significant limb and .w as the least significant:
// (2^256 - 1) + 1 == 0.
extern "C" __device__ void add_u32(
ulonglong4 *out_c,
ulonglong4 in_a,
ulonglong4 in_b
);
// 256-bit subtraction. The checks in `test` expect wrap-around modulo 2^256:
// 0 - 1 == 2^256 - 1.
extern "C" __device__ void sub_u32(
ulonglong4 *out_c,
ulonglong4 in_a,
ulonglong4 in_b
);
// Multiplies two 128-bit unsigned integers and stores the product modulo 2^128.
//
// Limb order: .x is the most significant 64-bit limb, .y the least.
// Writing a = ax*2^64 + ay and b = bx*2^64 + by:
//
//   a*b mod 2^128 = ay*by + ((ax*by + ay*bx) mod 2^64) * 2^64
//
// ay*by is a full 128-bit product, so its upper 64 bits must be carried into
// the high limb. ax*bx only contributes to bits >= 128 and is not needed,
// which is why a Karatsuba split would only add a multiplication here.
//
// out_c: receives the truncated product. The operands are taken by value, so
//        out_c may point at the caller's in_a/in_b variable.
__device__ void mul_u16(
ulonglong2 *out_c,
ulonglong2 in_a,
ulonglong2 in_b
) {
    // Full 128-bit low product: high half via the hardware mul-hi intrinsic.
    uint64_t low_lo = in_a.y * in_b.y;
    uint64_t low_hi = __umul64hi(in_a.y, in_b.y);
    // Cross terms land entirely in the high limb; overflow past 2^64 is
    // beyond 2^128 and is intentionally discarded by unsigned wrap-around.
    uint64_t cross = in_a.x * in_b.y + in_a.y * in_b.x;
    out_c->x = cross + low_hi;
    out_c->y = low_lo;
}
// Multiplies two 256-bit unsigned integers and stores the product modulo 2^256.
//
// Limb order: .x is the most significant 64-bit limb, then .y, .z, and .w as
// the least significant.
//
// Implementation: schoolbook multiplication over 64-bit limbs, truncated to
// four result limbs. Each 64x64 partial product is formed in full (low half
// via ordinary multiplication, high half via __umul64hi) and carries are
// propagated explicitly, so no bits below 2^256 are lost.
//
// out_c: receives the truncated product. The operands are taken by value, so
//        out_c may point at the caller's in_a/in_b variable.
__device__ void mul_u32(
ulonglong4 *out_c,
ulonglong4 in_a,
ulonglong4 in_b
) {
    // Little-endian limb views: index 0 is the least significant limb.
    const uint64_t a[4] = {in_a.w, in_a.z, in_a.y, in_a.x};
    const uint64_t b[4] = {in_b.w, in_b.z, in_b.y, in_b.x};
    uint64_t r[4] = {0, 0, 0, 0};

#pragma unroll
    for (int i = 0; i < 4; ++i) {
        uint64_t carry = 0;
        // Partial products with i + j >= 4 only affect bits >= 256; skip them.
#pragma unroll
        for (int j = 0; i + j < 4; ++j) {
            uint64_t lo = a[i] * b[j];
            uint64_t hi = __umul64hi(a[i], b[j]);
            // r + lo + carry, folding both possible carry-outs into hi.
            // a*b + r + carry <= (2^64-1)^2 + 2*(2^64-1) = 2^128 - 1, so hi
            // itself can never overflow.
            uint64_t sum = r[i + j] + lo;
            hi += (sum < lo);
            sum += carry;
            hi += (sum < carry);
            r[i + j] = sum;
            carry = hi;
        }
        // The carry out of limb 3 is bit 256 and above: dropped (mod 2^256).
    }

    out_c->x = r[3];
    out_c->y = r[2];
    out_c->z = r[1];
    out_c->w = r[0];
}
// Equality test for two 128-bit values: true iff the .x limbs match and the
// .y limbs match.
__device__ bool equ_u16(ulonglong2 a, ulonglong2 b) {
    const bool x_differs = a.x != b.x;
    const bool y_differs = a.y != b.y;
    return !(x_differs || y_differs);
}
// Equality test for two 256-bit values: true iff all four limbs (.x, .y, .z,
// .w) match pairwise. XOR yields zero exactly for equal limbs, so the OR of
// the four XORs is zero exactly when every limb matches.
__device__ bool equ_u32(ulonglong4 a, ulonglong4 b) {
    const unsigned long long diff =
        (a.x ^ b.x) | (a.y ^ b.y) | (a.z ^ b.z) | (a.w ^ b.w);
    return diff == 0;
}
// Prints a 128-bit value as "0x", then the .x limb and the .y limb as
// zero-padded 16-digit hex separated by '.', then a newline.
__device__ void print_u16(ulonglong2 a) {
    const unsigned long long first = a.x;
    const unsigned long long second = a.y;
    printf("0x%016llx.%016llx\n", first, second);
}
// Prints a 256-bit value as "0x" followed by its limbs in field order
// (.x, .y, .z, .w), each as zero-padded 16-digit hex separated by '.',
// then a newline. Emitted with a single printf call.
__device__ void print_u32(ulonglong4 a) {
    const unsigned long long l0 = a.x, l1 = a.y, l2 = a.z, l3 = a.w;
    printf("0x%016llx.%016llx.%016llx.%016llx\n", l0, l1, l2, l3);
}
// All-ones 64-bit limb (2^64 - 1). The numeric suffix in these names appears
// to count bytes: U8 = 8 bytes, U16 = 16 bytes (two limbs), U32 = 32 bytes
// (four limbs).
#define U8_MAX 0xffffffffffffffff
// Brace initializers for the all-ones ulonglong2 (2^128 - 1) and ulonglong4
// (2^256 - 1) values.
// NOTE(review): names beginning with '_' followed by an uppercase letter are
// reserved in C and C++; consider renaming _U16_MAX/_U32_MAX.
#define _U16_MAX { U8_MAX, U8_MAX }
#define _U32_MAX { U8_MAX, U8_MAX, U8_MAX, U8_MAX }
// Compares a computed 128-bit value with its expected value. On mismatch,
// prints the check name with both values and returns false.
static __device__ bool check_u16(const char *name, ulonglong2 got, ulonglong2 want) {
    if (equ_u16(got, want)) {
        return true;
    }
    printf("%s failed, got:  ", name);
    print_u16(got);
    printf("%s failed, want: ", name);
    print_u16(want);
    return false;
}

// Compares a computed 256-bit value with its expected value. On mismatch,
// prints the check name with both values and returns false.
static __device__ bool check_u32(const char *name, ulonglong4 got, ulonglong4 want) {
    if (equ_u32(got, want)) {
        return true;
    }
    printf("%s failed, got:  ", name);
    print_u32(got);
    printf("%s failed, want: ", name);
    print_u32(want);
    return false;
}

// Self-test kernel for the multi-limb arithmetic helpers.
//
// The expected values use .x as the most significant 64-bit limb.
//
// passed: device pointer to one bool. It is cleared first, so a kernel that
//         dies part-way leaves false behind, and it is set to true only if
//         every check matches. All checks run, and each mismatch is printed,
//         so one failure does not hide later ones.
// Launch: intended for a single thread (<<<1, 1>>>); the checks are not
//         split across threads and every thread would write *passed.
__global__ void test(bool *passed) {
    *passed = false;
    bool ok = true;
    {
        // (2^256 - 1) + 1 wraps around to 0.
        ulonglong4 a = _U32_MAX;
        ulonglong4 b = {0, 0, 0, 1};
        ulonglong4 want = {0, 0, 0, 0};
        add_u32(&a, a, b);
        ok = check_u32("add_u32", a, want) && ok;
    }
    {
        // 0 - 1 wraps around to 2^256 - 1.
        ulonglong4 a = {0, 0, 0, 0};
        ulonglong4 b = {0, 0, 0, 1};
        ulonglong4 want = _U32_MAX;
        sub_u32(&a, a, b);
        ok = check_u32("sub_u32", a, want) && ok;
    }
    {
        // (2^128 - 1) * (2^64 - 1) mod 2^128 = 2^128 - 2^64 + 1.
        ulonglong2 a = _U16_MAX;
        ulonglong2 b = {0, U8_MAX};
        ulonglong2 want = {U8_MAX, 1};
        mul_u16(&a, a, b);
        ok = check_u16("mul_u16", a, want) && ok;
    }
    {
        // (2^256 - 1) * (2^128 - 1) mod 2^256 = 2^256 - 2^128 + 1.
        ulonglong4 a = _U32_MAX;
        ulonglong4 b = {0, 0, U8_MAX, U8_MAX};
        ulonglong4 want = {U8_MAX, U8_MAX, 0, 1};
        mul_u32(&a, a, b);
        ok = check_u32("mul_u32", a, want) && ok;
    }
    *passed = ok;
}
// Reports a failed CUDA runtime call on stderr. Returns true if err is an
// error, false on cudaSuccess.
static bool cuda_failed(cudaError_t err, const char *what) {
    if (err == cudaSuccess) {
        return false;
    }
    fprintf(stderr, "%s: %s\n", what, cudaGetErrorString(err));
    return true;
}

// Runs the `test` kernel on one thread and converts its result into the
// process exit status: 0 when every check passed, 1 otherwise. Any CUDA
// runtime failure counts as "not passed", so an unreadable result is never
// mistaken for success.
int main() {
    bool test_passed = false;
    bool *d_test_passed = nullptr;
    if (cuda_failed(cudaMalloc(&d_test_passed, sizeof(bool)), "cudaMalloc")) {
        printf("test not passed\n");
        return 1;
    }

    test<<<1, 1>>>(d_test_passed);
    // Launch-configuration errors surface via cudaGetLastError(); faults during
    // execution surface at the next synchronizing call.
    bool failed =
        cuda_failed(cudaGetLastError(), "test launch") ||
        cuda_failed(cudaDeviceSynchronize(), "test execution") ||
        cuda_failed(cudaMemcpy(&test_passed, d_test_passed, sizeof(bool),
                               cudaMemcpyDeviceToHost),
                    "cudaMemcpy");
    cudaFree(d_test_passed);

    if (failed || !test_passed) {
        printf("test not passed\n");
        return 1;
    }
    return 0;
}