海思ive ann-mlp使用说明(2)

在学了在学了! 2020-10-09 18:00:00 4587

原文:https://blog.csdn.net/brightming/article/details/50895356
5 完整示例
5.1 二维数据的训练与预测
5.1.1 训练二维数据
以y=kx直线进行划分,在直线以下的为类别0,其他为类别1。
在训练的时候,指定k,产生训练数据,同时将一部分作为测试数据。

5.1.1.1 训练入口

/**

  • 以斜率=slope 来做分界,训练一个mlp模型
    */
    extern "C" void train_2_class_slope(float slope){//(int useExistModel,float from_x,float end_x,float from_y,float end_y,float x_step,float y_step){

CvANN_MLP annMlp;
int outputClassCnt=2;
bool loadModelFromFile=false;

Mat training_datas;
Mat trainClasses;
Mat oriTrainDatas;

generateFix2ClassSlopeTrainData(slope,training_datas,trainClasses);
training_datas.convertTo(training_datas,CV_32FC1);
trainClasses.convertTo(trainClasses,CV_32FC1);

cout<<"training_datas=\n"<<training_datas<<",oridata="<<oriTrainDatas<<endl;
cout<<"trainClasses=\n"<<trainClasses<<endl;

//创建mlp
Mat layers(1, 3, CV_32SC1);
layers.at(0) = training_datas.row(0).cols;
cout<<"------------------------trainAnnModel.input sample cnt:"<<training_datas.rows<<",input layer features:"<<layers.at(0,0)<<endl;
layers.at(1)=3;
layers.at(2) = outputClassCnt;//输出

cout<<"outputClassCnt="<<outputClassCnt<<endl;
annMlp.create(layers, CvANN_MLP::SIGMOID_SYM, 0.6667f, 1.7159f);

//--------训练mlp-----------//
// Set up BPNetwork‘s parameters
CvANN_MLP_TrainParams params;
params.train_method = CvANN_MLP_TrainParams::BACKPROP;
params.bp_dw_scale = 0.001;
params.bp_moment_scale = 0.0;

CvTermCriteria criteria;
criteria.max_iter = 300;
criteria.epsilon = 9.999999e-06;
criteria.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
params.term_crit = criteria;

annMlp.train(training_datas, trainClasses, Mat(), Mat(), params);
cout<<"train finished"<<endl;

char _dstPath[256];
sprintf(_dstPath,"data/my/my_simple_2_class_20160307slope%.2f.xml",slope);
string dstPath(_dstPath);//="data/my/my_simple_2_class_20160307_slope_1.xml";
annMlp.save(dstPath.c_str());
cout<<"save model finished.model file="<<dstPath<<"\n";

//预测
Mat test_datas;
Mat testClasses;
int testCount=1;//每个象限的测试图片数量
Mat oriTestData;
generate2ClassSlopeTestData(slope,test_datas,testClasses,oriTestData);//,from_x,end_x,from_y,end_y);
test_datas.convertTo(test_datas,CV_32FC1);
testClasses.convertTo(testClasses,CV_32FC1);
cout<<"test_datas=\n"<<test_datas<<",oridata="<<oriTestData<<endl;
// cout<<"testClasses=\n"<<testClasses<<endl;

int correctCount=0;
int errorCount=0;
cout<<"test_datas size="<<test_datas.rows<<endl;
int totalTestSize=test_datas.rows;
bool right=false;

// TestData_2Feature* cur=testDataHead;
int expected_idx=0;
for(int i=0;i<totalTestSize;i++){
Mat predict_result(1, outputClassCnt, CV_32FC1);
annMlp.predict(test_datas.row(i), predict_result);
Point maxLoc;
double maxVal;
minMaxLoc(predict_result, 0, &maxVal, 0, &maxLoc);

right=false;
if(test_datas.row(i).at(0,0)*slope > test_datas.row(i).at(0,1)){
expected_idx=0;
}else{
expected_idx=1;
}
if(expected_idx==maxLoc.x){
++correctCount;
right=true;
}else {
++errorCount;
}
cout<<"data:"<<test_datas.row(i)<<"("<<oriTestData.row(i)<<"),predict_result="<<predict_result<<",maxVal="<<maxVal<<",maxLoc.x="<<maxLoc.x<<",right?"<<right<<endl;

// cur=cur->next;
}

cout<<"total test data count="<<totalTestSize<<",correct count="<<correctCount<<",error count="<<errorCount<<",accurate="<<(correctCount)*1.0f/(totalTestSize)<<endl;

}

5.1.1.2 训练与测试数据产生方法

/**

  • y=x,划分,
    */
    void generateFix2ClassSlopeTrainData(float slope,Mat& mat,Mat& labels){
    vector dataVec;
    vector labVec;

float tmp1=0,tmp2=0;
printf("generateFix2ClassSlopeTrainData begin\n");

int multi=1;
float x_step=16;
float y_step=16;
int needTestSize=10;
int nowTestSize=0;

int loopcnt=0;
ostringstream os;

Int end_x=255;
Int end_y=255;
int getDataInterval=((end_x-0)/x_step (end_y-0)/y_step)/needTestSize;
printf("getDataInterval=%d,totalTrainSize=%d\n",(deltaX/x_step
deltaY/y_step));

for(int x=0;x<end_x;x+=x_step){
for(int y=0;y<end_y;y+=x_step){
++loopcnt;
dataVec.clear();
multi=-1;
tmp1=multi
(float)x;///255;
dataVec.push_back(tmp1);
tmp2=multi*(float)y;///255;
dataVec.push_back(tmp2);

// printf("tmp1=%f\n",tmp1);
// Mat tpmat=Mat(dataVec).reshape(1,1).clone();
mat.push_back(Mat(dataVec).reshape(1,1).clone());

labVec.clear();
if(tmp1*slope>tmp2){// x> 为类0
labVec.push_back(1.0f);
labVec.push_back(0.0f);
labels.push_back(Mat(labVec).reshape(1,1).clone());

if(loopcnt%getDataInterval==0){
os<<"0:";
}
}else{
labVec.push_back(0.0f);
labVec.push_back(1.0f);
labels.push_back(Mat(labVec).reshape(1,1).clone());

if(loopcnt%getDataInterval==0){
os<<"1:";
}
}
if(loopcnt%getDataInterval==0){
os<<x<<" "<<y<<endl;
}

}
}

//输出一部分作为测试文件
system("rm data/my/test2classdata_slope.list");
fstream ftxt;
string testfile="data/my/test2classdata_slope.list";
ftxt.open(testfile.c_str(), ios::out | ios::app);
if (ftxt.fail()) {
cout << "创建文件:"<<testfile<<" 失败!" <<endl;
getchar();
}
ftxt << os.str();
ftxt.close();
}

5.1.2 海思预测二维数据样本的所属类别
5.1.2.1 预测入口
/**

  • 测a试?y=kx的?分?类え?情é况?
    /
    HI_VOID SAMPLE_IVE_Ann_predict_2class_slope(float slope){
    // HI_CHAR
    pchBinFileName;
    int height,width,image_type;
    char pchBinFileName[256];
    sprintf(pchBinFileName,"./data/my/my_simple_2_class_20160307slope%.2f.bin",slope);
    // pchBinFileName = "./data/my/my_simple_2_class_20160307_slope_3.00.bin";
    height=1;
    width=2;
    image_type=IVE_IMAGE_TYPE_S32C1;

HI_S32 s32Ret;
SAMPLE_IVE_ANN_INFO_S stAnnInfo;

printf("use model bin file:%s\n",pchBinFileName);
SAMPLE_COMM_IVE_CheckIveMpiInit();

s32Ret=SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(&stAnnInfo, pchBinFileName,image_type,height,width);
if (HI_SUCCESS != s32Ret)
{
SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp__2Class_Init fail\n");
goto ANN_FAIL;
}
// predict2ClassData(&stAnnInfo,slope);
predict2ClassSlopeData(&stAnnInfo,slope);

//uninit
SAMPLE_IVE_Ann_Mlp_Uninit(&stAnnInfo);

ANN_FAIL:
SAMPLE_COMM_IVE_IveMpiExit();
}

5.1.2.2 初始化

/**

  • function : Ann mlp init
    **/
    static HI_S32 SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(SAMPLE_IVE_ANN_INFO_S pstAnnInfo, HI_CHAR pchBinFileName,int image_type,int height,int width )
    {
    SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_Init.....\n");
    HI_S32 s32Ret = HI_SUCCESS;
    HI_U32 u32Size;

    memset(pstAnnInfo, 0, sizeof(SAMPLE_IVE_ANN_INFO_S));

    /**

    • 查é找ò表括?里?的?数簓值μ范?围§是?[0,1],?精?度è是?8位?,?即′1<<8=256,?
    • 表括?示?要癮被?分?成é256段?。£
    • /
      pstAnnInfo->stTable.s32TabInLower = 0;
      pstAnnInfo->stTable.s32TabInUpper = 1;//1;
      pstAnnInfo->stTable.u8TabInPreci = 8;//12;
      pstAnnInfo->stTable.u8TabOutNorm = 2;//2
      pstAnnInfo->stTable.u16ElemNum = (pstAnnInfo->stTable.s32TabInUpper-pstAnnInfo->stTable.s32TabInLower) << pstAnnInfo->stTable.u8TabInPreci;
      u32Size = pstAnnInfo->stTable.u16ElemNum
      sizeof(HI_U16);
      // SAMPLE_PRT("stTable.s32TabInLower=%d,s32TabInUpper=%d,u8TabInPreci=%d,u8TabOutNorm=%d,u16ElemNum=%d\n",pstAnnInfo->stTable.s32TabInLower,pstAnnInfo->stTable.s32TabInUpper,pstAnnInfo->stTable.u8TabInPreci,pstAnnInfo->stTable.u8TabOutNorm,pstAnnInfo->stTable.u16ElemNum);
      s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stTable.stTable), u32Size);
      if (s32Ret != HI_SUCCESS)
      {
      SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
      goto ANN_INIT_FAIL;
      }

    s32Ret = SAMPLE_IVE_Ann_Mlp_CreateTable(&(pstAnnInfo->stTable), 0.6667f, 1.7159f);

// s32Ret = SAMPLE_IVE_Ann_Mlp_CreateTable(&(pstAnnInfo->stTable), 1.0f, 1.0f);
if (s32Ret != HI_SUCCESS)
{
SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_CreateTable fail\n");
goto ANN_INIT_FAIL;
}
SAMPLE_PRT("begin to load model:%s\n",pchBinFileName);
s32Ret = HI_MPI_IVE_ANN_MLP_LoadModel(pchBinFileName, &(pstAnnInfo->stAnnModel));
if (s32Ret != HI_SUCCESS)
{
SAMPLE_PRT("HI_MPI_IVE_ANN_MLP_LoadModel fail,Error(%#x)\n", s32Ret);
goto ANN_INIT_FAIL;
}
printf("finish load model:%s\n",pchBinFileName);

u32Size = pstAnnInfo->stAnnModel.au16LayerCount[0] * sizeof(HI_S16Q16);//输?入?层?需è要癮的?空?间?大洙?小?:阰输?入?层?的?元a素?个?数簓*每?个?元a素?的?大洙?小?
printf("allocate memory for input,size=%d\n",u32Size);
s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stSrc), u32Size);
if (s32Ret != HI_SUCCESS)
{
    SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
    goto ANN_INIT_FAIL;
}

u32Size = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1] * sizeof(HI_S16Q16);//输?出?类え?别纄信?息¢所ù需è空?间?的?大洙?小?:阰输?出?层?类え?别纄数簓*每?个?类え?别纄数簓值μ的?占?的?空?间?

// SAMPLE_PRT("annModel output class cnt=%d\n",pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1]);
printf("allocate memory for output,size=%d\n",u32Size);
s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stDst), u32Size);
if (s32Ret != HI_SUCCESS)
{
SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
goto ANN_INIT_FAIL;
}

ANN_INIT_FAIL:

// printf("s32Ret=%d,HI_SUCCESS=%d\n",s32Ret,HI_SUCCESS);
if (HI_SUCCESS != s32Ret)
{
SAMPLE_IVE_Ann_Mlp_Uninit(pstAnnInfo);
}

return s32Ret;

}

5.1.2.3 预测

/**

  • 预¤测ay=kx的?分?类え?
    /
    void predict2ClassSlopeData(SAMPLE_IVE_ANN_INFO_S
    pstAnnInfo,float slope){
    char contFile="data/my/test2classdata_slope_eq_1.list";
    printf("try to get file info:%s\n",contFile);
    TestData_2Feature
    head=get2FeatureData(contFile);
    printf("after read file:%s,head=%p\n",contFile,head);
    if(!head){
    printf("fail to read contFile:%s\n",contFile);
    return;
    }

// printStringNode(head,"1");
// printStringNode(head,"2");

HI_S32 i, k;
HI_S32 s32Ret;
HI_S32 s32ResponseCls;
HI_U16 u16LayerCount;
HI_S16Q16 ps16q16Dst;
HI_S16Q16 s16q16Response;
HI_BOOL bInstant = HI_TRUE;
HI_BOOL bFinish;
HI_BOOL bBlock = HI_TRUE;
// HI_CHAR achFileName[IVE_FILE_NAME_LEN];
FILE
pFp = HI_NULL;
IVE_HANDLE iveHandle;

int xs[3]={-5,-4,3};
int ys[3]={99,-10,10};

srand(time(NULL));

int totalCount=0;
int correctCount=0;
TestData_2Feature* cur=head;

int cnt=0;
int expected_idx=0;
while(cur!=NULL){
// printf("flag=%d,filePath=%s,filenName=%s -->\n ",cur->flag,cur->fileFullPath,cur->fileName);
ps16q16Dst = (HI_S16Q16*)pstAnnInfo->stDst.pu8VirAddr;
s16q16Response = 0;
s32ResponseCls = -1;

HI_S16Q16 stSrc=(HI_S16Q16)pstAnnInfo->stSrc.pu8VirAddr;
stSrc[0]=changeFloatToS16Q16(cur->x1);//转换为以s16q16表示的数据
stSrc[1]=changeFloatToS16Q16(cur->x2);

s32Ret = HI_MPI_IVE_ANN_MLP_Predict(&iveHandle, &(pstAnnInfo->stSrc), \
& (pstAnnInfo->stTable), &(pstAnnInfo->stAnnModel), &(pstAnnInfo->stDst), bInstant);
if (s32Ret != HI_SUCCESS)
{
SAMPLE_PRT("HI_MPI_IVE_ANN_MLP_Predict fail,Error(%#x)\n", s32Ret);
break;
}
s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock);
while (HI_ERR_IVE_QUERY_TIMEOUT == s32Ret)
{
usleep(100);
s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock);
}
if (HI_SUCCESS != s32Ret)
{
SAMPLE_PRT("HI_MPI_IVE_Query fail,Error(%#x)\n", s32Ret);
break;
}
u16LayerCount = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1];
// SAMPLE_PRT("pstAnnInfo->CstAnnModel.u8LayerNum=%d,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1]=%d\n",pstAnnInfo->stAnnModel.u8LayerNum,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1]);
SAMPLE_PRT(" \n--predict2ClassSlopeData--Begin-- x1=%f(s16q16=%d),x2=%f(s16q16=%d)\n",cur->x1,changeFloatToS16Q16(cur->x1),cur->x2,changeFloatToS16Q16(cur->x2));
++totalCount;
for (k = 0; k < u16LayerCount; k++)
{
printf(" ps16q16Dst[%d]=%d,H16Q16=%f\n", k,ps16q16Dst[k],calculateS16Q16_c(ps16q16Dst[k]));
if (s16q16Response < ps16q16Dst[k])
{
s16q16Response = ps16q16Dst[k];
s32ResponseCls = k;
}
}

if(cur->x1*slope>cur->x2){
expected_idx=0;
}else{
expected_idx=1;
}
SAMPLE_PRT(" --predict2ClassSlopeData--End-- result:%s,flag:%d,class:%d ------\n\n",(expected_idx==s32ResponseCls?"right":"wrong"),expected_idx,s32ResponseCls);

cur=cur->next;
}

freeTestData_2FeatureNode(head);
}
附上斜率为的预测结果输出:
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.220000(s16q16=14417),x2=0.100000(s16q16=6553)
ps16q16Dst[0]=46174,H16Q16=0.704559
ps16q16Dst[1]=20098,H16Q16=0.306671
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=-1.000000(s16q16=-65536),x2=-3.000000(s16q16=-196608)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=1.000000(s16q16=65536),x2=0.200000(s16q16=13107)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.200000(s16q16=13107),x2=0.700000(s16q16=45875)
ps16q16Dst[0]=16687,H16Q16=0.254623
ps16q16Dst[1]=51450,H16Q16=0.785065
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.400000(s16q16=26214),x2=0.900000(s16q16=58982)
ps16q16Dst[0]=16830,H16Q16=0.256805
ps16q16Dst[1]=51033,H16Q16=0.778702
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.690196(s16q16=45232),x2=0.062745(s16q16=4112)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=224.000000(s16q16=14680064),x2=80.000000(s16q16=5242880)
ps16q16Dst[0]=17622,H16Q16=0.268890
ps16q16Dst[1]=45294,H16Q16=0.691132
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:wrong,flag:0,class:1 ------

[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=-224.000000(s16q16=-14680064),x2=80.000000(s16q16=5242880)
ps16q16Dst[0]=17117,H16Q16=0.261185
ps16q16Dst[1]=51728,H16Q16=0.789307
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

声明:本文内容由易百纳平台入驻作者撰写,文章观点仅代表作者本人,不代表易百纳立场。如有内容侵权或者其他问题,请联系本站进行删除。
红包 点赞 收藏 评论 打赏
评论
0个
内容存在敏感词
手气红包
    易百纳技术社区暂无数据
相关专栏
置顶时间设置
结束时间
删除原因
  • 广告/SPAM
  • 恶意灌水
  • 违规内容
  • 文不对题
  • 重复发帖
打赏作者
易百纳技术社区
在学了在学了!
您的支持将鼓励我继续创作!
打赏金额:
¥1易百纳技术社区
¥5易百纳技术社区
¥10易百纳技术社区
¥50易百纳技术社区
¥100易百纳技术社区
支付方式:
微信支付
支付宝支付
易百纳技术社区微信支付
易百纳技术社区
打赏成功!

感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~

举报反馈

举报类型

  • 内容涉黄/赌/毒
  • 内容侵权/抄袭
  • 政治相关
  • 涉嫌广告
  • 侮辱谩骂
  • 其他

详细说明

审核成功

发布时间设置
发布时间:
是否关联周任务-专栏模块

审核失败

失败原因
备注
拼手气红包 红包规则
祝福语
恭喜发财,大吉大利!
红包金额
红包最小金额不能低于5元
红包数量
红包数量范围10~50个
余额支付
当前余额:
可前往问答、专栏板块获取收益 去获取
取 消 确 定

小包子的红包

恭喜发财,大吉大利

已领取20/40,共1.6元 红包规则

    易百纳技术社区