gpt4 book ai didi

c++ - C++ 中 Armadillo 矩阵维度的动态参数化

转载 作者:行者123 更新时间:2023-12-05 03:45:22 24 4
gpt4 key购买 nike

标题总结了更准确地动态检索传递给 Armadillo 矩阵的 MATLAB 数组维数的目标。

我想将 mY() 和 mD() 的第二个和第三个参数更改为下面的参数。

// mat(ptr_aux_mem, n_rows, n_cols, copy_aux_mem = true, strict = false)
arma::mat mY(&dY[0], 2, 168, false);
arma::mat mD(&dD[0], 2, 168, false);

这肯定是一个常见的用例,但对于从 MATLAB 馈送的数组维数可以是任意的 (n > 2) 的一般情况,我仍然找不到实现它的好方法。

对于矩阵(二维)的情况,我可能会破解我的方法,但我觉得这不够优雅(可能效率也不高)。

恕我直言,要走的路必须是:

matlab::data::TypedArray<double>getDimensions()检索 matlab::data::ArrayDimensions 的成员函数这基本上是一个 std::vector<size_t> .

索引由 getDimensions() 检索到的 vector 的第一个和第二个元素可以检索行数和列数,如下所示。

unsigned int mYrows = matrixY.getDimensions()[0];
unsigned int mYcols = matrixY.getDimensions()[1];

但是,根据我当前的设置,我无法调用 getDimensions()通过 foo() 中的指针/引用sub.cpp 的功能。如果可行,我既不想创建额外的临时对象,也不想将其他参数传递给 foo() .怎么可能呢?

直觉一直告诉我,也一定有一个优雅的解决方案。也许使用多重间接寻址?

我非常感谢知识渊博的 SO 成员提供的任何帮助、提示或建设性意见。提前谢谢你。

设置:

两个C++源文件和一个头文件:

main.cpp

  • 包含 MATLAB 和 C++ 之间的通用 IO 接口(interface)
  • 将两个 double 数组和两个 double const double 送入 C++
  • 它通过调用 foo()
  • 做一些基于 Armadillo 的循环(这部分不那么重要因此省略)
  • 返回 outp,这是一个“普通” double 标量
  • 没有花哨或复杂的东西。

子.cpp

  • 这仅适用于 foo() 循环部分。

sub.hpp

  • 只是一个简单的头文件。
// main.cpp
// MATLAB API Header Files
#include "mex.hpp"
#include "mexAdapter.hpp"

// Custom header
#include "sub.hpp"

// Overloading the function call operator, thus class acts as a functor
class MexFunction : public matlab::mex::Function {
public:
void operator()(matlab::mex::ArgumentList outputs,
matlab::mex::ArgumentList inputs){

matlab::data::ArrayFactory factory;
// Validate arguments
checkArguments(outputs, inputs);

matlab::data::TypedArray<double> matrixY = std::move(inputs[0]);
matlab::data::TypedArray<double> matrixD = std::move(inputs[1]);
const double csT = inputs[2][0];
const double csKy = inputs[3][0];

buffer_ptr_t<double> mY = matrixY.release();
buffer_ptr_t<double> mD = matrixD.release();

double* darrY = mY.get();
double* darrD = mD.get();

// data type of outp is "just" a plain double, NOT a double array
double outp = foo(darrY, darrD, csT, csKy);

outputs[0] = factory.createScalar(outp);

void checkArguments(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs){
// Create pointer to MATLAB engine
std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
// Create array factory, allows us to create MATLAB arrays in C++
matlab::data::ArrayFactory factory;
// Check input size and types
if (inputs[0].getType() != ArrayType::DOUBLE ||
inputs[0].getType() == ArrayType::COMPLEX_DOUBLE)
{
// Throw error directly into MATLAB if type does not match
matlabPtr->feval(u"error", 0,
std::vector<Array>({ factory.createScalar("Input must be double array.") }));
}
// Check output size
if (outputs.size() > 1) {
matlabPtr->feval(u"error", 0,
std::vector<Array>({ factory.createScalar("Only one output is returned.") }));
}
}
};

// sub.cpp

#include "sub.hpp"
#include "armadillo"

double foo(double* dY, double* dD, const double T, const double Ky) {

double sum = 0;

// Conversion of input parameters to Armadillo types
// mat(ptr_aux_mem, n_rows, n_cols, copy_aux_mem = true, strict = false)
arma::mat mY(&dY[0], 2, 168, false);
arma::mat mD(&dD[0], 2, 168, false);

// Armadillo calculations

for(int t=0; t<int(T); t++){

// some armadillo based calculation
// each for cycle increments sum by its return value
}

return sum;
}

// sub.hpp

#ifndef SUB_H_INCLUDED
#define SUB_H_INCLUDED

double foo(double* dY, double* dD, const double T, const double Ky);

#endif // SUB_H_INCLUDED

最佳答案

一种方法是使用函数将其转换为 arma 矩阵

template<class T>
arma::Mat<T> getMat( matlab::data::TypedArray<T> A)
{
matlab::data::TypedIterator<T> it = A.begin();
matlab::data::ArrayDimensions nDim = A.getDimensions();
return arma::Mat<T>(it.operator->(), nDim[0], nDim[1]);
}

并调用它

 arma::mat Y = getMat<double>(inputs[0]);
arma::mat D = getMat<double>(inputs[1]);
...
double outp = foo(Y,D, csT, csKy);

并将foo()更改为

double foo( arma::mat& dY, arma::mat& dD, const double T, const double Ky) 

关于c++ - C++ 中 Armadillo 矩阵维度的动态参数化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65868123/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com