gpt4 book ai didi

java - 为什么这个 JNI 程序不将浮点值复制回 Java 端?

转载 作者:行者123 更新时间:2023-12-01 09:16:57 33 4
gpt4 key购买 nike

我有这个代码:

#if defined(NOT_STANDALONE)
// JNI entry point: flattens the Java 2D input arrays into contiguous native
// buffers, runs ComputeTraces on them, then attempts to hand the results back
// to the Java output arrays.
//
// Parameters (as used by the code in this function):
//   jprestackTraces  - array of float[]; element `in` is read for
//                      samplesPerTrace samples.
//   nTracesIn        - number of input traces.
//   jsampleShifts, jstartIndices, jnSamples
//                    - arrays of int[]; element `out` is read for nTracesIn
//                      entries.
//   jstackTracesOut, jpowerTracesOut
//                    - arrays of float[]; element `out` is the intended
//                      destination of samplesPerTrace output samples.
//   nTracesOut       - number of output traces.
//   samplesPerTrace  - samples per trace, for both input and output.
//
// NOTE(review): the allocation sizes multiply two jint values before the
// result is widened by sizeof(...), so very large dimensions could overflow
// in int arithmetic - confirm the expected ranges.
JNIEXPORT void JNICALL sumTraces
(JNIEnv* env, jclass caller, jobjectArray jprestackTraces, jint nTracesIn, jobjectArray jsampleShifts,
jobjectArray jstartIndices, jobjectArray jnSamples, jobjectArray jstackTracesOut,
jobjectArray jpowerTracesOut, jint nTracesOut, jint samplesPerTrace) {

// Written by every GetPrimitiveArrayCritical call below; never read.
jboolean isCopy;

// Flat native copies of the inputs: [in][s] for the traces, [out][in] for the
// three per-pair integer tables.
float* prestackTraces1D = (float*)malloc(nTracesIn * samplesPerTrace * sizeof(float));
if (prestackTraces1D == NULL) Fatal("Could not malloc prestackTraces1D");
int* sampleShifts1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
if (sampleShifts1D == NULL) Fatal("Could not malloc sampleShifts1D");
int* startIndices1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
if (startIndices1D == NULL) Fatal("Could not malloc startIndices1D");
int* nSamples1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int));
if (nSamples1D == NULL) Fatal("Could not malloc nSamples1D");

// Copy every input trace row into its slot of the flat buffer.
for (int in = 0; in < nTracesIn; in++) {

jfloatArray j_prestack = (jfloatArray)env->GetObjectArrayElement(jprestackTraces, in);
float* prestackTracesJava = (float*)env->GetPrimitiveArrayCritical(j_prestack, &isCopy);

for (int s = 0; s < samplesPerTrace; s++) {
int readIndex = s + (in * samplesPerTrace);
prestackTraces1D[readIndex] = prestackTracesJava[s];
}

// JNI_ABORT: the row was only read, so nothing needs to be written back.
env->ReleasePrimitiveArrayCritical(j_prestack, prestackTracesJava, JNI_ABORT);
}

// Copy the three integer tables, one output row at a time.
for (int out = 0; out < nTracesOut; out++) {

// NOTE(review): GetObjectArrayElement is called while earlier critical
// arrays of this iteration are still held - confirm this is acceptable for
// the target JVM.
jintArray j_shift = (jintArray)env->GetObjectArrayElement(jsampleShifts, out);
int* sampleShiftsJava = (int*)env->GetPrimitiveArrayCritical(j_shift, &isCopy);
jintArray j_start = (jintArray)env->GetObjectArrayElement(jstartIndices, out);
int* startIndicesJava = (int*)env->GetPrimitiveArrayCritical(j_start, &isCopy);
jintArray j_nSamps = (jintArray)env->GetObjectArrayElement(jnSamples, out);
int* nSamplesJava = (int*)env->GetPrimitiveArrayCritical(j_nSamps, &isCopy);

for (int in = 0; in < nTracesIn; in++) {
int readIndex = in + (out * nTracesIn);
sampleShifts1D[readIndex] = sampleShiftsJava[in];
startIndices1D[readIndex] = startIndicesJava[in];
nSamples1D[readIndex] = nSamplesJava[in];
}

// Released in reverse order of acquisition; read-only, hence JNI_ABORT.
env->ReleasePrimitiveArrayCritical(j_nSamps, nSamplesJava, JNI_ABORT);
env->ReleasePrimitiveArrayCritical(j_start, startIndicesJava, JNI_ABORT);
env->ReleasePrimitiveArrayCritical(j_shift, sampleShiftsJava, JNI_ABORT);
}

// Flat native output buffers, laid out [out][s].
float* stackTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float));
if (stackTracesOut1D == NULL) Fatal("Could not malloc stackTracesOut1D");
float* powerTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float));
if (powerTracesOut1D == NULL) Fatal("Could not malloc powerTracesOut1D");

// Run the OpenCL program
// NOTE(review): the meaning of the trailing 0, 0, 1000 arguments is not
// visible here - check the ComputeTraces declaration.
ComputeTraces(prestackTraces1D, stackTracesOut1D, powerTracesOut1D,
startIndices1D, nSamples1D, sampleShifts1D,
samplesPerTrace, nTracesIn, nTracesOut,
0, 0, 1000);

// Free the arrays that we can
free(nSamples1D);
free(startIndices1D);
free(sampleShifts1D);
free(prestackTraces1D);

// Copy back the output for Java
for (int out = 0; out < nTracesOut; out++) {
jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);

// Temporary per-trace buffers holding this trace's slice of the flat output.
float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
for (int s = 0; s < samplesPerTrace; s++) {
int readIndex = s + (out * samplesPerTrace);
stackOutCopyBack[s] = stackTracesOut1D[readIndex];
powerOutCopyBack[s] = powerTracesOut1D[readIndex];
}

// Debug dump of the values about to be handed back.
for (int s = 0; s < samplesPerTrace; s++) {
printf("%d %f/%f\n", s, stackOutCopyBack[s], powerOutCopyBack[s]);
}

// NOTE(review): stackOutCopyBack / powerOutCopyBack come from malloc, not
// from GetPrimitiveArrayCritical on j_stackOut / j_powerOut.
// ReleasePrimitiveArrayCritical is documented as the counterpart of a prior
// Get on the same array, so these calls have no pinned or copied Java buffer
// to write back into.
env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);
env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);

free(stackOutCopyBack);
free(powerOutCopyBack);
}

// Free the output arrays
free(powerTracesOut1D);
free(stackTracesOut1D);
}

ComputeTraces(...) 方法用值填充 stackTracesOut1D 和 powerTracesOut1D 数组。我知道这些值是正确的,因为 for 循环内接近末尾的 printf 语句,我将它与我想要的值进行比较,它们匹配。然而,当我检查 Java 端时,所有值都被清零。为什么这段 JNI 代码没有将数据复制回来?

请注意,正如您在代码中看到的,我必须把通过参数传入的二维数组展平为一维数组,才能传入函数。因此,在把数据复制回去之前,我先把较大的一维数组中对应的一段复制到一个较小的数组中,再把这个较小的数组作为参数传给 ReleasePrimitiveArrayCritical,但这些值并没有被复制回 Java 端。

编辑:为了说明得更清楚,我指的是倒数第 10 行左右的这一行:env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0); 其中模式参数我传的是 0。

最佳答案

所以问题很简单,我忘记在输出数组上使用 GetPrimitiveArrayCritical(...) 。所以:

  // Copy-back loop as originally written: the buffers passed to
  // ReleasePrimitiveArrayCritical are malloc'ed temporaries, not pointers
  // obtained from GetPrimitiveArrayCritical on the Java arrays.
  for (int out = 0; out < nTracesOut; out++) {
jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);

float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));
float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float));

for (int s = 0; s < samplesPerTrace; s++) {
int readIndex = s + (out * samplesPerTrace);
stackOutCopyBack[s] = stackTracesOut1D[readIndex];
powerOutCopyBack[s] = powerTracesOut1D[readIndex];
}

env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);
env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);

free(stackOutCopyBack);
free(powerOutCopyBack);
}

变成:

  // Copy each output trace from the flat native result buffers into its Java
  // float[] row. The pointers written through come from
  // GetPrimitiveArrayCritical, so releasing them with mode 0 commits the data
  // to the Java arrays; they must not be passed to free().
  for (int out = 0; out < nTracesOut; out++) {
    jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out);
    jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out);

    // First sample of this trace inside the flat [out][s] buffers.
    const float* stackSrc = stackTracesOut1D + (out * samplesPerTrace);
    const float* powerSrc = powerTracesOut1D + (out * samplesPerTrace);

    // Each array is pinned, filled and released on its own so that no other
    // JNI call is made while a critical region is open.
    float* stackOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_stackOut, &isCopy);
    // Fatal is assumed not to return, as with the malloc checks in sumTraces.
    if (stackOutCopyBack == NULL) Fatal("Could not access stack output array");
    for (int s = 0; s < samplesPerTrace; s++) {
      stackOutCopyBack[s] = stackSrc[s];
    }
    // Mode 0: copy back (if the JVM handed out a copy) and release the buffer.
    env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);

    float* powerOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_powerOut, &isCopy);
    if (powerOutCopyBack == NULL) Fatal("Could not access power output array");
    for (int s = 0; s < samplesPerTrace; s++) {
      powerOutCopyBack[s] = powerSrc[s];
    }
    env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0);

    // Drop the per-iteration local references so a large nTracesOut cannot
    // exhaust the JNI local reference table.
    env->DeleteLocalRef(j_stackOut);
    env->DeleteLocalRef(j_powerOut);
  }

删除 free 调用同样重要:这两个指针现在指向由 JVM 管理的内存(数组本身或其副本),由 ReleasePrimitiveArrayCritical 负责释放;如果再对它们调用 free,就会造成重复释放,或者释放并非由 malloc 分配的内存。

关于java - 为什么这个 JNI 程序不将浮点值复制回 Java 端?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40478017/

33 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com