Image Resampling (CPU and GPU)

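1 Introduction

While writing an image fusion algorithm some time ago, I inevitably had to resample the multispectral image to the size of the panchromatic image. To avoid holding up the overall development of the fusion algorithm, I implemented the resampling step with the Warp interface of the open-source GDAL library.

I later found two main problems with resampling multispectral imagery to panchromatic size through the GDAL Warp interface: 1 it was incompatible with existing functionality of our platform and caused conflicts; 2 it was slow. I therefore decided to redesign and reimplement this function, which also makes the software system easier to maintain later on.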

 

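2 Image Resampling

In terms of form, image processing falls into two main categories: 1 processing of single pixels or pixel neighborhoods, such as image addition or filtering; 2 geometric transformation of the image, such as resampling and registration.

Resampling relies on a geometric transform that maps the pixel coordinates of one image to those of the other.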
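The mapping can be read off the CPU code in section 3. As a sketch, assuming both images are north-up and have already been cropped to their common overlap, the position $(x, y)$ in the multispectral block that corresponds to output pixel index $(i, j)$ is a pure scaling. The symbols $i$, $j$, $x$, $y$, $r_x$, $r_y$ and $g$ are introduced here only for illustration: $r_x$ and $r_y$ are the values the code calls radioX and radioY, and $g_1$, $g_5$ are the pixel-size entries of the GDAL geotransform.

```latex
x = r_x \, i, \qquad y = r_y \, j, \qquad
r_x = \frac{g^{\mathrm{pan}}_{1}}{g^{\mathrm{mss}}_{1}}, \qquad
r_y = \frac{g^{\mathrm{pan}}_{5}}{g^{\mathrm{mss}}_{5}}
```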

 

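The parameters of this mapping are the transform coefficients. Three resampling algorithms are in common use: 1 nearest neighbor; 2 bilinear; 3 cubic convolution.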

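2.1 Nearest-Neighbor Sampling

The principle of nearest-neighbor sampling, in short, is to use the value of the pixel closest to the sampling position as the value at that position. A brief aside first:

Geometric image transforms are usually implemented in one of two ways, the direct method or the indirect method. Taking resampling as an example:

Direct method: start from the first row and column of the source image, use the transform to compute each pixel's position in the resampled image, and assign the value there. With this method several source pixels can map to the same output pixel, so extra handling is required.

Indirect method: start from the first row and column of the resampled image, compute each output pixel's position in the source image, and derive the output value from the source pixels with a chosen algorithm. This approach is simple and direct, and it is the one used in this article.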
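The difference between the two methods can be illustrated on a single image row. The following is only a sketch: CountHolesForward and CountHolesInverse are hypothetical helpers, not part of the original code, and they use a simple truncating scale mapping.

```cpp
#include <cassert>
#include <vector>

// Direct (forward) mapping: push every source pixel to its scaled position
// in the output. Returns how many output pixels were never written.
int CountHolesForward(int srcSize, int dstSize)
{
    assert(srcSize > 0 && dstSize > 0);
    std::vector<int> hits(dstSize, 0);
    const double scale = static_cast<double>(dstSize) / srcSize;
    for (int s = 0; s < srcSize; ++s) {
        const int d = static_cast<int>(s * scale);
        if (d < dstSize) ++hits[d];   // several s may land on the same d
    }
    int holes = 0;
    for (int d = 0; d < dstSize; ++d)
        if (hits[d] == 0) ++holes;
    return holes;
}

// Indirect (inverse) mapping: for every output pixel, look up the source
// pixel it comes from. Returns how many output pixels have no valid source.
int CountHolesInverse(int srcSize, int dstSize)
{
    assert(srcSize > 0 && dstSize > 0);
    const double scale = static_cast<double>(srcSize) / dstSize;
    int holes = 0;
    for (int d = 0; d < dstSize; ++d) {
        const int s = static_cast<int>(d * scale);
        if (s < 0 || s >= srcSize) ++holes;
    }
    return holes;
}
```

When enlarging 4 pixels to 8, the forward version leaves every second output pixel unwritten, while the inverse version visits each output pixel exactly once. When shrinking, the forward version instead writes some output pixels more than once, which is the collision case described above.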

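Nearest-neighbor sampling is implemented as follows:

First iterate over the sampling points and compute, from the transform, each point's position in the source image; call this position (curX, curY), as in the CPU code in section 3. Then find the source pixel nearest to (curX, curY) and assign its value to the sampling point. In the code this amounts to rounding each coordinate to the nearest integer.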
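The rounding step can be sketched on its own. This is a minimal illustration, assuming a single-band, row-major float image; SampleNearest is a hypothetical helper and not part of the original code, which works on band-interleaved blocks instead.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Nearest-neighbor lookup at a fractional (row, col) position.
// Coordinates are rounded to the nearest integer and clamped to the image.
float SampleNearest(const std::vector<float>& src, int width, int height,
                    float row, float col)
{
    assert(width > 0 && height > 0);
    assert(src.size() == static_cast<std::size_t>(width) * height);
    int r = static_cast<int>(row + 0.5f);   // round half up (row >= 0 assumed)
    int c = static_cast<int>(col + 0.5f);
    r = std::min(std::max(r, 0), height - 1);
    c = std::min(std::max(c, 0), width - 1);
    return src[static_cast<std::size_t>(r) * width + c];
}
```

Clamping keeps positions that fall just outside the last row or column on the border pixel, which mirrors the bounds checks in the CPU kernel.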

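2.2 Bilinear

Unlike nearest-neighbor assignment, bilinear sampling finds the four pixels closest to (curX, curY) and interpolates along the x and y directions in turn, using the distances as weights, to obtain the value at the sampling point. The exact formula is given in the code.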
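As a compact illustration of the weighting, here is a sketch for a single-band, row-major float image. SampleBilinear is a hypothetical helper, not the author's kernel; it interpolates along columns first and then along rows, which gives the same result as the opposite order used in the CPU listing.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Bilinear interpolation at a fractional (row, col) position.
float SampleBilinear(const std::vector<float>& src, int width, int height,
                     float row, float col)
{
    assert(width > 0 && height > 0);
    assert(src.size() == static_cast<std::size_t>(width) * height);
    // Clamp the position so all four neighbors stay inside the image.
    row = std::min(std::max(row, 0.0f), static_cast<float>(height - 1));
    col = std::min(std::max(col, 0.0f), static_cast<float>(width - 1));
    const int r0 = static_cast<int>(std::floor(row));
    const int c0 = static_cast<int>(std::floor(col));
    const int r1 = std::min(r0 + 1, height - 1);
    const int c1 = std::min(c0 + 1, width - 1);
    const float fr = row - r0;   // fractional distance to the upper row
    const float fc = col - c0;   // fractional distance to the left column
    const std::size_t w = static_cast<std::size_t>(width);
    const float top    = (1.0f - fc) * src[r0 * w + c0] + fc * src[r0 * w + c1];
    const float bottom = (1.0f - fc) * src[r1 * w + c0] + fc * src[r1 * w + c1];
    return (1.0f - fr) * top + fr * bottom;
}
```

At the center of a 2x2 image the result is the mean of the four pixels, and at an integer position it reduces to the pixel value itself.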

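2.3 Cubic Convolution

Cubic convolution interpolates from the 16 pixels closest to (curX, curY), again interpolating separately along the x and y directions. The exact formula is given in the code.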
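The four weights per direction can be isolated from the CPU listing in section 3. The polynomials below are copied from that code; as far as I can tell they correspond to the cubic convolution kernel with parameter a = -1. CubicWeights and InterpolateCubic1D are hypothetical helper names used only for this sketch, which handles one direction; the 2D case applies the same weights along rows and then along columns.

```cpp
#include <array>
#include <cassert>

// Weights for the four neighbors at offsets -1, 0, +1, +2 from the base
// pixel, where t in [0, 1] is the fractional distance past the base pixel.
std::array<float, 4> CubicWeights(float t)
{
    assert(t >= 0.0f && t <= 1.0f);
    const float t2 = t * t;
    const float t3 = t2 * t;
    const std::array<float, 4> w = {{
        -t + 2.0f * t2 - t3,     // neighbor at offset -1
        1.0f - 2.0f * t2 + t3,   // base pixel
        t + t2 - t3,             // neighbor at offset +1
        -t2 + t3                 // neighbor at offset +2
    }};
    return w;
}

// Interpolate between p[1] and p[2] using all four samples.
float InterpolateCubic1D(const std::array<float, 4>& p, float t)
{
    const std::array<float, 4> w = CubicWeights(t);
    return w[0] * p[0] + w[1] * p[1] + w[2] * p[2] + w[3] * p[3];
}
```

The weights always sum to 1, and two of them are negative for 0 < t < 1, so the result can overshoot the range of the input samples. That matters when the output is cast back to an unsigned integer type.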

 

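3 Experimental Results

Two versions of the interpolation were implemented: 1 a CPU version; 2 a GPU version (a first attempt). Because the multispectral and panchromatic images do not cover exactly the same extent, the algorithm first computes the region where the two overlap. Without further ado, here is the detailed implementation.

CPU version: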
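The overlap computation can be sketched independently of GDAL. The version below assumes north-up rasters (no rotation terms, negative row pixel size) described by six-element GDAL-style geotransforms; OverlapGrid and ComputeOverlap are hypothetical names, not part of the original code.

```cpp
#include <algorithm>
#include <cassert>

struct OverlapGrid {
    double originX;  // west edge of the overlap, in map units
    double originY;  // north edge of the overlap, in map units
    int width;       // overlap size in PAN pixels
    int height;
};

// gt[0], gt[3]: top-left corner; gt[1]: pixel width; gt[5]: pixel height (< 0).
OverlapGrid ComputeOverlap(const double mssGT[6], int mssW, int mssH,
                           const double panGT[6], int panW, int panH)
{
    // Assumption: both rasters are north-up.
    assert(mssGT[1] > 0.0 && mssGT[2] == 0.0 && mssGT[4] == 0.0 && mssGT[5] < 0.0);
    assert(panGT[1] > 0.0 && panGT[2] == 0.0 && panGT[4] == 0.0 && panGT[5] < 0.0);

    OverlapGrid g;
    g.originX = std::max(mssGT[0], panGT[0]);   // easternmost west edge
    g.originY = std::min(mssGT[3], panGT[3]);   // southernmost north edge
    const double endX = std::min(mssGT[0] + mssW * mssGT[1],
                                 panGT[0] + panW * panGT[1]);
    const double endY = std::max(mssGT[3] + mssH * mssGT[5],
                                 panGT[3] + panH * panGT[5]);
    // The output grid uses the PAN pixel size; round to the nearest pixel.
    g.width  = static_cast<int>((endX - g.originX) / panGT[1] + 0.5);
    g.height = static_cast<int>((endY - g.originY) / panGT[5] + 0.5);
    return g;
}
```

For example, a 50x50 multispectral raster with 4 m pixels starting at (100, 200) and a 100x100 panchromatic raster with 1 m pixels starting at (120, 190) overlap in a 100x100 pixel window at panchromatic resolution.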

#ifndef RESAMPLECPU_H
#define RESAMPLECPU_H

#include <gdal_alg_priv.h>
#include <gdal_priv.h>
#include <assert.h>
#include <stdlib.h>
#include <iostream>

// Clamp an index to the range [0, hi].
static inline int ClampIndex(int v, int hi)
{
    return v < 0 ? 0 : (v > hi ? hi : v);
}

// Resample one block of MSS data onto the PAN grid.
// Both buffers are band-interleaved by line: element (row, band, col) is at
// row*width*bandCount + band*width + col.
// idx/curX run over rows and idy/curY over columns.
// MethodType: 0 nearest neighbor, 1 bilinear, 2 cubic convolution.
template<typename T>
void ReSampleCPUKernel(const float *mssData,
                       T *resampleData,
                       int mssWidth,
                       int mssHeight,
                       int mssBandCount,
                       int mssOffsetX,
                       int mssOffsetY,
                       int panWidth,
                       int panHeight,
                       float radioX,
                       float radioY,
                       double dfDstNoDataValue,
                       int MethodType)
{
    assert(mssData != NULL);
    (void)mssOffsetX;   // offsets are not used by the kernel
    (void)mssOffsetY;
    const int mssRowStride = mssWidth*mssBandCount;   // values per MSS row, all bands
    const int panRowStride = panWidth*mssBandCount;   // values per output row, all bands
    for(int idx = 0;idx < panHeight;idx++){
        for(int idy = 0;idy<panWidth;idy++){
            // Find the corresponding position in the MSS block
            float curX = (float)idx * radioX;
            float curY = (float)idy * radioY;
            int baseX = ClampIndex((int)curX, mssHeight - 1);
            int baseY = ClampIndex((int)curY, mssWidth - 1);
            // Band i of this output pixel is stored at dst[i*panWidth]
            T *dst = resampleData + idx*panRowStride + idy;

            // NoData is tested on the first band only
            if(mssData[baseX*mssRowStride + baseY] == dfDstNoDataValue)
            {
                for(int i = 0;i < mssBandCount;i++){
                    dst[i*panWidth] = T(dfDstNoDataValue);
                }
                continue;
            }
            if(MethodType == 0){  // nearest neighbor
                int nearX = ClampIndex((int)(curX + 0.5f), mssHeight - 1);
                int nearY = ClampIndex((int)(curY + 0.5f), mssWidth - 1);
                for(int i = 0;i < mssBandCount;i++){
                    dst[i*panWidth] = T(mssData[nearX*mssRowStride + i*mssWidth + nearY]);
                }
            }
            else if(MethodType == 1){   // bilinear
                float dataX = curX - (int)curX;
                float dataY = curY - (int)curY;
                int postX = ClampIndex(baseX + 1, mssHeight - 1);
                int postY = ClampIndex(baseY + 1, mssWidth - 1);

                float Wx1 = 1 - dataX;
                float Wx2 = dataX;
                float Wy1 = 1 - dataY;
                float Wy2 = dataY;
                // Core of the bilinear interpolation
                for(int i = 0;i < mssBandCount;i++){
                    const float *band = mssData + i*mssWidth;
                    float v00 = band[baseX*mssRowStride + baseY];
                    float v01 = band[baseX*mssRowStride + postY];
                    float v10 = band[postX*mssRowStride + baseY];
                    float v11 = band[postX*mssRowStride + postY];
                    float tmp = Wy1*(Wx1*v00 + Wx2*v10) + Wy2*(Wx1*v01 + Wx2*v11);
                    dst[i*panWidth] = T(tmp);
                }
            }
            else if(MethodType == 2){  // cubic convolution
                float dataX = curX - (int)curX;
                float dataY = curY - (int)curY;
                // Four neighboring rows and columns, clamped to the block
                int rows[4] = {ClampIndex(baseX - 1, mssHeight - 1), baseX,
                               ClampIndex(baseX + 1, mssHeight - 1),
                               ClampIndex(baseX + 2, mssHeight - 1)};
                int cols[4] = {ClampIndex(baseY - 1, mssWidth - 1), baseY,
                               ClampIndex(baseY + 1, mssWidth - 1),
                               ClampIndex(baseY + 2, mssWidth - 1)};

                float Wx[4] = {-1.0f*dataX + 2*dataX*dataX - dataX*dataX*dataX,
                               1 - 2*dataX*dataX + dataX*dataX*dataX,
                               dataX + dataX*dataX - dataX*dataX*dataX,
                               -1.0f*dataX*dataX + dataX*dataX*dataX};
                float Wy[4] = {-1.0f*dataY + 2*dataY*dataY - dataY*dataY*dataY,
                               1 - 2*dataY*dataY + dataY*dataY*dataY,
                               dataY + dataY*dataY - dataY*dataY*dataY,
                               -1.0f*dataY*dataY + dataY*dataY*dataY};

                for(int i = 0;i < mssBandCount;i++){
                    const float *band = mssData + i*mssWidth;
                    float tmp = 0.0f;
                    for(int n = 0;n < 4;n++){
                        // Interpolate along the rows for column n, then weight by Wy
                        float colSum = 0.0f;
                        for(int m = 0;m < 4;m++){
                            colSum += Wx[m]*band[rows[m]*mssRowStride + cols[n]];
                        }
                        tmp += Wy[n]*colSum;
                    }
                    dst[i*panWidth] = T(tmp);
                }
            }
        }
    }
}

int ReSampleCPUApp(const char *mssfileName,
                      const char *panfileName,
                      const char *resamplefileName,
                      int MethodType = 1)
{
    GDALAllRegister();
    GDALDataset *poPANDS = (GDALDataset*)GDALOpen(panfileName,GA_ReadOnly);
    GDALDataset *poMSSDS = (GDALDataset*)GDALOpen(mssfileName,GA_ReadOnly);
    if(!poPANDS || !poMSSDS){
        if(poPANDS) GDALClose((GDALDatasetH)poPANDS);
        if(poMSSDS) GDALClose((GDALDatasetH)poMSSDS);
        return -1;
    }

    // MSS info
    int mssBandCount = poMSSDS->GetRasterCount();
    int mssWidth = poMSSDS->GetRasterXSize();
    int mssHeight = poMSSDS->GetRasterYSize();
    double adfMssGeoTransform[6] = {0};
    poMSSDS->GetGeoTransform(adfMssGeoTransform);
    GDALDataType mssDT = poMSSDS->GetRasterBand(1)->GetRasterDataType();

    int bSrcHasNoData;
    double dfSrcNoDataValue = 0;
    dfSrcNoDataValue = GDALGetRasterNoDataValue(poMSSDS->GetRasterBand(1),&bSrcHasNoData);
    if(!bSrcHasNoData) dfSrcNoDataValue = 0.0;

    // Band map 1..N (GDAL band numbers are 1-based)
    int *pBandMap= new int[mssBandCount];
    for(int i = 0;i<mssBandCount;i++){
        pBandMap[i] = i+1;
    }

    // PAN info
    int panWidth = poPANDS->GetRasterXSize();
    int panHeight = poPANDS->GetRasterYSize();
    double adfPanGeoTransform[6] = {0};
    poPANDS->GetGeoTransform(adfPanGeoTransform);

    // New dataset ======= geotransform: PAN pixel size, origin of the overlap
    double adfResampleGeoTransform[6] = {0};
    adfResampleGeoTransform[1] = adfPanGeoTransform[1];
    adfResampleGeoTransform[5] = adfPanGeoTransform[5];
    adfResampleGeoTransform[2] = adfPanGeoTransform[2];
    adfResampleGeoTransform[4] = adfPanGeoTransform[4];
    if(adfMssGeoTransform[0] >= adfPanGeoTransform[0]){
        adfResampleGeoTransform[0] = adfMssGeoTransform[0];
    }else{
        adfResampleGeoTransform[0] = adfPanGeoTransform[0];
    }
    if(adfMssGeoTransform[3] > adfPanGeoTransform[3]){
        adfResampleGeoTransform[3] = adfPanGeoTransform[3];
    }else{
        adfResampleGeoTransform[3] = adfMssGeoTransform[3];
    }

    // New dataset ======= image extent (lower-right corner of each image)
    double panEndX = adfPanGeoTransform[0] + panWidth*adfPanGeoTransform[1] +
        panHeight*adfPanGeoTransform[2];
    double panEndY = adfPanGeoTransform[3] + panWidth*adfPanGeoTransform[4] +
        panHeight*adfPanGeoTransform[5];

    double mssEndX = adfMssGeoTransform[0] + mssWidth*adfMssGeoTransform[1] +
        mssHeight*adfMssGeoTransform[2];
    double mssEndY = adfMssGeoTransform[3] + mssWidth*adfMssGeoTransform[4] +
        mssHeight*adfMssGeoTransform[5];
    double resampleEndXY[2] = {0};
    if(panEndX > mssEndX)
        resampleEndXY[0] = mssEndX;
    else
        resampleEndXY[0] = panEndX;
    if(panEndY >= mssEndY)
        resampleEndXY[1] = panEndY;
    else
        resampleEndXY[1] = mssEndY;

    // New dataset ======= size of the overlap in output and MSS pixels
    int resampleWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfResampleGeoTransform[1] + 0.5);
    int resampleHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfResampleGeoTransform[5] + 0.5);
    int mssEffectiveWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfMssGeoTransform[1] + 0.5);

    // New dataset ======= pixel offset of the overlap inside the MSS image
    int mssGainX = static_cast<int>((adfResampleGeoTransform[0] - adfMssGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
    int mssGainY = static_cast<int>((adfResampleGeoTransform[3] - adfMssGeoTransform[3])/adfMssGeoTransform[5] + 0.5);

    // New dataset ======= create the output file
    GDALDriver *poOutDriver = (GDALDriver*)GDALGetDriverByName("GTIFF");
    GDALDataset *poOutDS = NULL;
    if(poOutDriver){
        poOutDS = poOutDriver->Create(resamplefileName,resampleWidth,
            resampleHeight,mssBandCount,mssDT,NULL);
    }
    if(!poOutDS){
        delete []pBandMap;
        GDALClose((GDALDatasetH)poPANDS);
        GDALClose((GDALDatasetH)poMSSDS);
        return -1;
    }
    poOutDS->SetGeoTransform(adfResampleGeoTransform);
    poOutDS->SetProjection(poPANDS->GetProjectionRef());

    // Core resampling code ============ split the output into row blocks
    int iNumRow = 256;
    if(iNumRow > resampleHeight){
        iNumRow = 1;
    }
    int loopNum = (resampleHeight + iNumRow - 1)/iNumRow;  // number of blocks
    // Layout of the MSS read buffer: band-interleaved by line
    int nLineSpace,nPixSpace,nBandSpace;
    nLineSpace = sizeof(float)*mssEffectiveWidth*mssBandCount;
    nPixSpace = 0;   // 0 lets GDAL use the size of the buffer data type
    nBandSpace = sizeof(float)*mssEffectiveWidth;

    // Resampling ratio (PAN pixel size / MSS pixel size)
    float radioX = float(adfPanGeoTransform[1]/adfMssGeoTransform[1]);
    float radioY = float(adfPanGeoTransform[5]/adfMssGeoTransform[5]);

    int mssCurPosX = mssGainX;
    int mssCurPosY = mssGainY;
    int status = 0;

    // Core resampling code ============
    for(int i = 0;i<loopNum;i++){
        int tmpRowNum = iNumRow;
        int startR = i*iNumRow;
        int endR = startR + iNumRow - 1;
        if(endR>resampleHeight -1){
            tmpRowNum = resampleHeight - startR;
        }
        // Size of the MSS region to read; the interpolating methods need extra rows
        int mssCurWidth = mssEffectiveWidth;
        int mssCurHeight = 0;
        if(MethodType == 0)
            mssCurHeight = int(tmpRowNum*radioY);
        else if(MethodType == 1)
            mssCurHeight = int(tmpRowNum*radioY)+1;
        else if(MethodType == 2)
            mssCurHeight = int(tmpRowNum*radioY)+2;

        if(mssCurHeight + mssCurPosY > mssHeight - 1){
            mssCurHeight = mssHeight - mssCurPosY;
        }

        // Allocate the MSS read buffer (zero-initialized)
        size_t mssCount = (size_t)mssCurHeight*mssCurWidth*mssBandCount;
        float *mssBuf = (float *)calloc(mssCount,sizeof(float));
        if(!mssBuf){
            status = -1;
            break;
        }

        // Read the data
        if(poMSSDS->RasterIO(GF_Read,mssCurPosX,mssCurPosY,mssCurWidth,mssCurHeight,
            mssBuf,mssCurWidth,mssCurHeight,GDT_Float32,mssBandCount,pBandMap,nPixSpace,
            nLineSpace,nBandSpace) != CE_None){
            free(mssBuf);
            status = -1;
            break;
        }

        // Advance the read position, keeping the overlap rows for the next block
        if(MethodType == 0)
            mssCurPosY += mssCurHeight;
        else if(MethodType == 1)
            mssCurPosY += mssCurHeight - 1;
        else if(MethodType == 2)
            mssCurPosY += mssCurHeight - 2;

        // Resample into a buffer of the output data type and write the block
        size_t sz = (size_t)tmpRowNum*resampleWidth*mssBandCount;
        switch(mssDT){
            case GDT_Byte: {
                unsigned char *resampleBuf = new unsigned char[sz]();
                ReSampleCPUKernel<unsigned char>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned char),
                    resampleWidth*sizeof(unsigned char));
                delete []resampleBuf;
                break;
            }
            case GDT_UInt16: {
                unsigned short int *resampleBuf = new unsigned short int[sz]();
                ReSampleCPUKernel<unsigned short int>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned short int),
                    resampleWidth*sizeof(unsigned short int));
                delete []resampleBuf;
                break;
            }
            case GDT_Int16: {
                short int *resampleBuf = new short int[sz]();
                ReSampleCPUKernel<short int>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(short int),
                    resampleWidth*sizeof(short int));
                delete []resampleBuf;
                break;
            }
            case GDT_UInt32: {
                unsigned int *resampleBuf = new unsigned int[sz]();
                ReSampleCPUKernel<unsigned int>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(unsigned int),
                    resampleWidth*sizeof(unsigned int));
                delete []resampleBuf;
                break;
            }
            case GDT_Int32: {
                int *resampleBuf = new int[sz]();
                ReSampleCPUKernel<int>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(int),
                    resampleWidth*sizeof(int));
                delete []resampleBuf;
                break;
            }
            case GDT_Float32: {
                float *resampleBuf = new float[sz]();
                ReSampleCPUKernel<float>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(float),
                    resampleWidth*sizeof(float));
                delete []resampleBuf;
                break;
            }
            case GDT_Float64: {
                double *resampleBuf = new double[sz]();
                ReSampleCPUKernel<double>(mssBuf,resampleBuf,mssCurWidth,mssCurHeight,mssBandCount,mssGainX,mssGainY,
                    resampleWidth,tmpRowNum,radioX,radioY,dfSrcNoDataValue,MethodType);
                poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,resampleBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,nPixSpace,resampleWidth*mssBandCount*sizeof(double),
                    resampleWidth*sizeof(double));
                delete []resampleBuf;
                break;
            }
            default:
                // Unsupported data type
                status = -1;
                break;
        }
        free(mssBuf);   // allocated with calloc, so release with free
        if(status != 0)
            break;
        std::cout<<i<<std::endl;   // progress: index of the finished block
    }
    delete []pBandMap;pBandMap = NULL;
    GDALClose((GDALDatasetH)poPANDS);
    GDALClose((GDALDatasetH)poMSSDS);
    GDALClose((GDALDatasetH)poOutDS);
    return status;
}

#endif

GPU version:

#ifndef RESAMPLEOPENCL_H
#define RESAMPLEOPENCL_H

#include <CL/cl.h>
#include <gdal_alg_priv.h>
#include <gdal_priv.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#pragma comment(lib,"OpenCL.lib")

/*
@ Description
    Load the kernel source: read the text source file into memory,
    prefixed with cPreamble. The caller frees the returned string.
*/
char* LoadProgSource(const char* cFilename, const char* cPreamble, size_t* szFinalLength)
{
    FILE* pFileStream = NULL;
    size_t szSourceLength;

    // open the OpenCL source code file
    pFileStream = fopen(cFilename, "rb");
    if(pFileStream == 0)
    {
        return NULL;
    }

    size_t szPreambleLength = strlen(cPreamble);

    // get the length of the source code
    fseek(pFileStream, 0, SEEK_END);
    long lFileLength = ftell(pFileStream);
    fseek(pFileStream, 0, SEEK_SET);
    if(lFileLength < 0)
    {
        fclose(pFileStream);
        return NULL;
    }
    szSourceLength = (size_t)lFileLength;

    // allocate a buffer for the source code string and read it in
    char* cSourceString = (char *)malloc(szSourceLength + szPreambleLength + 1);
    if(cSourceString == NULL)
    {
        fclose(pFileStream);
        return NULL;
    }
    memcpy(cSourceString, cPreamble, szPreambleLength);
    if (fread((cSourceString) + szPreambleLength, szSourceLength, 1, pFileStream) != 1)
    {
        fclose(pFileStream);
        free(cSourceString);
        return NULL;
    }

    // close the file and return the total length of the combined (preamble + source) string
    fclose(pFileStream);
    if(szFinalLength != 0)
    {
        *szFinalLength = szSourceLength + szPreambleLength;
    }
    cSourceString[szSourceLength + szPreambleLength] = '\0';

    return cSourceString;
}

// Convert a float buffer to the output data type T, element by element.
template<typename T>
bool DataTypeTrans(const float *pSrcBuf,T *pDesBuf,long size)
{
    if(pSrcBuf == NULL || pDesBuf == NULL){
        return false;
    }
    for(long k = 0;k < size;k++){
        pDesBuf[k] = T(pSrcBuf[k]);
    }
    return true;
}

int ReSampleOpenCLApp(const char *mssfileName,
                      const char *panfileName,
                      const char *resamplefileName,
                      int MethodType = 1)
{
    GDALAllRegister();
    GDALDataset *poPANDS = (GDALDataset*)GDALOpen(panfileName,GA_ReadOnly);
    GDALDataset *poMSSDS = (GDALDataset*)GDALOpen(mssfileName,GA_ReadOnly);
    if(!poPANDS || !poMSSDS)
        return -1;

    // MSS info
    int mssBandCount = poMSSDS->GetRasterCount();
    int mssWidth = poMSSDS->GetRasterXSize();
    int mssHeight = poMSSDS->GetRasterYSize();
    double adfMssGeoTransform[6] = {0};
    poMSSDS->GetGeoTransform(adfMssGeoTransform);
    GDALDataType mssDT = poMSSDS->GetRasterBand(1)->GetRasterDataType();

    int bSrcHasNoData;
    float dfSrcNoDataValue = 0;
    dfSrcNoDataValue = (float)GDALGetRasterNoDataValue(poMSSDS->GetRasterBand(1),&bSrcHasNoData);
    if(!bSrcHasNoData) dfSrcNoDataValue = 0.0;


    // PAN info
    int panBandCount = poPANDS->GetRasterCount();
    int panWidth = poPANDS->GetRasterXSize();
    int panHeidht = poPANDS->GetRasterYSize();
    double adfPanGeoTransform[6] = {0};
    poPANDS->GetGeoTransform(adfPanGeoTransform);
    GDALDataType panDT = poPANDS->GetRasterBand(1)->GetRasterDataType();

    // New dataset ======= geotransform: PAN pixel size, origin of the overlap
    double adfResampleGeoTransform[6] = {0};
    adfResampleGeoTransform[1] = adfPanGeoTransform[1];
    adfResampleGeoTransform[5] = adfPanGeoTransform[5];
    adfResampleGeoTransform[2] = adfPanGeoTransform[2];
    adfResampleGeoTransform[4] = adfPanGeoTransform[4];
    if(adfMssGeoTransform[0] >= adfPanGeoTransform[0]){
        adfResampleGeoTransform[0] = adfMssGeoTransform[0];
    }else{
        adfResampleGeoTransform[0] = adfPanGeoTransform[0];
    }
    if(adfMssGeoTransform[3] > adfPanGeoTransform[3]){
        adfResampleGeoTransform[3] = adfPanGeoTransform[3];
    }else{
        adfResampleGeoTransform[3] = adfMssGeoTransform[3];
    }

    // New dataset ======= image extent (lower-right corner of each image)
    double panEndX = adfPanGeoTransform[0] + panWidth*adfPanGeoTransform[1] +
        panHeidht*adfPanGeoTransform[2];
    double panEndY = adfPanGeoTransform[3] + panWidth*adfPanGeoTransform[4] +
        panHeidht*adfPanGeoTransform[5];

    double mssEndX = adfMssGeoTransform[0] + mssWidth*adfMssGeoTransform[1] +
        mssHeight*adfMssGeoTransform[2];
    double mssEndY = adfMssGeoTransform[3] + mssWidth*adfMssGeoTransform[4] +
        mssHeight*adfMssGeoTransform[5];
    double resampleEndXY[2] = {0};
    if(panEndX > mssEndX)
        resampleEndXY[0] = mssEndX;
    else
        resampleEndXY[0] = panEndX;
    if(panEndY >= mssEndY)
        resampleEndXY[1] = panEndY;
    else
        resampleEndXY[1] = mssEndY;

    // New dataset ======= effective width and height of MSS and PAN
    int resampleWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfResampleGeoTransform[1] + 0.5);
    int resampleHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfResampleGeoTransform[5] + 0.5);
    int mssEffectiveWidth = static_cast<int>((resampleEndXY[0] - adfResampleGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
    int mssEffectiveHeight = static_cast<int>((resampleEndXY[1] - adfResampleGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
    int panEffectiveWidth = resampleWidth;
    int panEffectiveHeight = resampleHeight;

    // New dataset ======= pixel offsets of the overlap inside each image
    int mssGainX = static_cast<int>((adfResampleGeoTransform[0] - adfMssGeoTransform[0])/adfMssGeoTransform[1] + 0.5);
    int mssGainY = static_cast<int>((adfResampleGeoTransform[3] - adfMssGeoTransform[3])/adfMssGeoTransform[5] + 0.5);
    int panGainX = static_cast<int>((adfResampleGeoTransform[0] - adfPanGeoTransform[0])/adfPanGeoTransform[1] + 0.5);
    int panGainY = static_cast<int>((adfResampleGeoTransform[3] - adfPanGeoTransform[3])/adfPanGeoTransform[5] + 0.5);


    // New dataset ======= create the output file
    GDALDriver *poOutDriver = (GDALDriver*)GDALGetDriverByName("GTIFF");
    if(!poOutDriver){
        return -1;
    }
    GDALDataset *poOutDS = poOutDriver->Create(resamplefileName,resampleWidth,
        resampleHeight,mssBandCount,mssDT,NULL);
    if(!poOutDS){
        return -1;
    }
    poOutDS->SetGeoTransform(adfResampleGeoTransform);
    poOutDS->SetProjection(poPANDS->GetProjectionRef());

    // A NULL band map makes RasterIO use the first mssBandCount bands in order,
    // so the band count is no longer hard-coded to 4.
    int *pBandMap = NULL;
    // Core resampling code ============ split the output into row blocks
    int iNumRow = 256;
    if(iNumRow > resampleHeight){
        iNumRow = 1;
    }
    int loopNum = (resampleHeight + iNumRow - 1)/iNumRow;  // number of blocks
    // Layout of the MSS read buffer: band-interleaved by line
    int nLineSpace,nPixSpace,nBandSpace;
    nLineSpace = sizeof(float)*mssEffectiveWidth*mssBandCount;
    nPixSpace = 0;   // 0 lets GDAL use the size of the buffer data type
    nBandSpace = sizeof(float)*mssEffectiveWidth;

    // Resampling ratio (PAN pixel size / MSS pixel size)
    float radioX = float(adfPanGeoTransform[1]/adfMssGeoTransform[1]);
    float radioY = float(adfPanGeoTransform[5]/adfMssGeoTransform[5]);

    int mssCurPosX = mssGainX;
    int mssCurPosY = mssGainY;
    int mssCurWidth = 0;
    int mssCurHeight = 0;

    // Core resampling code ============
    // OpenCL =============== 1 get the platform
    cl_uint num_platforms;
    cl_int ret = clGetPlatformIDs(0,NULL,&num_platforms);
    if(ret != CL_SUCCESS || num_platforms < 1){
        printf("clGetPlatformIDs Error\n");
        return -1;
    }
    cl_platform_id platform_id = NULL;
    ret = clGetPlatformIDs(1,&platform_id,NULL);
    if(ret != CL_SUCCESS){
        printf("clGetPlatformIDs Error2\n");
        return -1;
    }

    // OpenCL =============== 2 get the device
    cl_uint num_devices;
    ret = clGetDeviceIDs(platform_id,CL_DEVICE_TYPE_GPU,0,NULL,
        &num_devices);
    if(ret != CL_SUCCESS || num_devices < 1){
        printf("clGetDeviceIDs Error\n");
        return -1;
    }
    cl_device_id device_id;
    ret = clGetDeviceIDs(platform_id,CL_DEVICE_TYPE_GPU,1,&device_id,NULL);
    if(ret != CL_SUCCESS){
        printf("clGetDeviceIDs Error2\n");
        return -1;
    }

    // OpenCL =============== 3 create the context
    cl_context_properties props[] = {CL_CONTEXT_PLATFORM,
        (cl_context_properties)platform_id,0};
    cl_context context = NULL;
    context = clCreateContext(props,1,&device_id,NULL,NULL,&ret);
    if(ret != CL_SUCCESS || context == NULL){
        printf("clCreateContext Error\n");
        return -1;
    }

    // OpenCL =============== 4 create the command queue
    cl_command_queue command_queue = NULL;
    command_queue = clCreateCommandQueue(context,device_id,0,&ret);
    if(ret != CL_SUCCESS || command_queue == NULL){
        printf("clCreateCommandQueue Error\n");
        return -1;
    }

    // OpenCL =============== 6 load the program source (it is built per block in the loop below)
    const char *strfile = "D:\\PIE3\\src\\Test\\TextOpecCLResample\\TextOpecCLResample\\ReSampleKernel.txt";
    size_t lenSource = 0;
    char *kernelSource = LoadProgSource(strfile,"",&lenSource);
    if(kernelSource == NULL){
        printf("LoadProgSource Error\n");
        return -1;
    }
    // One program and one kernel handle per block
    cl_program *programs = (cl_program *)malloc(loopNum*sizeof(cl_program));
    cl_kernel *kernels = (cl_kernel*)malloc(loopNum*sizeof(cl_kernel));
    if(programs == NULL || kernels == NULL){
        printf("malloc Error\n");
        return -1;
    }
    memset(programs,0,sizeof(cl_program)*loopNum);
    memset(kernels,0,sizeof(cl_kernel)*loopNum);


    for(int i = 0;i<loopNum;i++){
        // Output rows covered by this strip; the last strip may be shorter
        int tmpRowNum = iNumRow;
        int startR = i*iNumRow;
        int endR = startR + iNumRow - 1;
        if(endR>resampleHeight -1){
            tmpRowNum = resampleHeight - startR;
        }
        // Size of the MSS region to read for this strip. Bilinear and cubic
        // convolution need 1 or 2 extra rows for their neighbourhoods.
        int mssCurWidth = mssEffectiveWidth;
        int mssCurHeight = 0;
        if(MethodType == 0)
            mssCurHeight = int(tmpRowNum*radioY);
        else if(MethodType == 1)
            mssCurHeight = int(tmpRowNum*radioY)+1;
        else if(MethodType == 2)
            mssCurHeight = int(tmpRowNum*radioY)+2;

        // Do not read past the last MSS row
        if(mssCurHeight + mssCurPosY > mssHeight){
            mssCurHeight = mssHeight - mssCurPosY;
        }

        // Allocate the host buffers
        float *resampleBuf = (float *)malloc(sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount);
        float *mssBuf = (float *)malloc(sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount);
        memset(resampleBuf,0,sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount);
        memset(mssBuf,0,sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount);

        // Read the MSS strip as 32-bit floats
        CPLErr readErr = poMSSDS->RasterIO(GF_Read,mssCurPosX,mssCurPosY,mssCurWidth,mssCurHeight,
            mssBuf,mssCurWidth,mssCurHeight,GDT_Float32,mssBandCount,pBandMap,nPixSpace,
            nLineSpace,nBandSpace);
        if(readErr != CE_None){
            printf("RasterIO read Error\n");
            return -1;
        }

        // Advance the read position, stepping back by the extra rows so that
        // consecutive strips overlap where the interpolation needs it
        if(MethodType == 0)
            mssCurPosY += mssCurHeight;
        else if(MethodType == 1)
            mssCurPosY += mssCurHeight - 1;
        else if(MethodType == 2)
            mssCurPosY += mssCurHeight - 2;

        // OpenCL step 5: create the memory objects (they wrap the host buffers)
        cl_mem mem_mss = NULL;
        mem_mss = clCreateBuffer(context,CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR,
            sizeof(cl_float)*mssCurHeight*mssCurWidth*mssBandCount,mssBuf,&ret);
        if(ret != CL_SUCCESS || NULL == mem_mss){
            printf("clCreateBuffer Error (mem_mss)\n");
            return -1;
        }

        cl_mem mem_resample = NULL;
        mem_resample = clCreateBuffer(context,CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR,
            sizeof(cl_float)*resampleWidth*tmpRowNum*mssBandCount,resampleBuf,&ret);
        if(ret != CL_SUCCESS || NULL == mem_resample){
            printf("clCreateBuffer Error (mem_resample)\n");
            return -1;
        }

        // OpenCL step 6: create and build the program.
        // Note: the same source is rebuilt for every strip; building it once
        // before the loop would avoid the repeated compilation.
        programs[i] = clCreateProgramWithSource(context,1,(const char**)&kernelSource,
            NULL,&ret);
        if(ret != CL_SUCCESS || NULL == programs[i]){
            printf("clCreateProgramWithSource Error\n");
            return -1;
        }
        ret = clBuildProgram(programs[i],1,&device_id,NULL,NULL,NULL);
        if(ret != CL_SUCCESS){
            char* build_log;
            size_t log_size;
            // Query the size of the build log
            clGetProgramBuildInfo(programs[i], device_id, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
            build_log = new char[log_size+1];
            // Fetch the build log itself
            ret = clGetProgramBuildInfo(programs[i], device_id, CL_PROGRAM_BUILD_LOG, log_size, build_log, NULL);
            build_log[log_size] = '\0';
            printf("%s\n",build_log);
            printf("Build failed!\n");
            delete []build_log;
            return -1;
        }

        // OpenCL step 7: create the kernel
        kernels[i] = clCreateKernel(programs[i],"ReSampleKernel",&ret);
        if(ret != CL_SUCCESS || NULL == kernels[i]){
            printf("clCreateKernel Error\n");
            return -1;
        }

        // OpenCL step 8: set the kernel arguments
        // The kernel expects a float NoData value, so convert explicitly instead
        // of passing the address of a variable that may be wider than cl_float
        cl_float noDataArg = (cl_float)dfSrcNoDataValue;
        ret = clSetKernelArg(kernels[i],0,sizeof(cl_mem),&mem_mss);
        ret |= clSetKernelArg(kernels[i],1,sizeof(cl_mem),&mem_resample);
        ret |= clSetKernelArg(kernels[i],2,sizeof(cl_int),&mssCurWidth);
        ret |= clSetKernelArg(kernels[i],3,sizeof(cl_int),&mssCurHeight);
        ret |= clSetKernelArg(kernels[i],4,sizeof(cl_int),&mssBandCount);
        ret |= clSetKernelArg(kernels[i],5,sizeof(cl_int),&mssGainX);
        ret |= clSetKernelArg(kernels[i],6,sizeof(cl_int),&mssGainY);
        ret |= clSetKernelArg(kernels[i],7,sizeof(cl_int),&resampleWidth);
        ret |= clSetKernelArg(kernels[i],8,sizeof(cl_int),&tmpRowNum);
        ret |= clSetKernelArg(kernels[i],9,sizeof(cl_float),&radioX);
        ret |= clSetKernelArg(kernels[i],10,sizeof(cl_float),&radioY);
        ret |= clSetKernelArg(kernels[i],11,sizeof(cl_float),&noDataArg);
        ret |= clSetKernelArg(kernels[i],12,sizeof(cl_int),&MethodType);
        if(ret != CL_SUCCESS){
            printf("clSetKernelArg Error\n");
            return -1;
        }

        // OpenCL step 9: set the work sizes (one work-item per output pixel)
        cl_uint work_dim = 2;
        size_t global_work_size[] = {(size_t)resampleWidth,(size_t)tmpRowNum};
        size_t *local_work_size = NULL;  // let the runtime pick the work-group size

        // OpenCL step 10: run the kernel and wait for it to finish
        ret = clEnqueueNDRangeKernel(command_queue,kernels[i],work_dim,NULL,global_work_size,
            local_work_size,0,NULL,NULL);
        ret |= clFinish(command_queue);
        if(ret != CL_SUCCESS){
            printf("clEnqueueNDRangeKernel Error\n");
            return -1;
        }

        // OpenCL step 11: map the result back to the host.
        // The buffer was created with CL_MEM_USE_HOST_PTR, so the mapped pointer
        // is derived from the host buffer allocated above.
        // (clEnqueueReadBuffer into resampleBuf would be the copying alternative.)
        resampleBuf = (float*)clEnqueueMapBuffer(command_queue,mem_resample,CL_TRUE,CL_MAP_READ | CL_MAP_WRITE,
            0,sizeof(cl_float)*tmpRowNum*resampleWidth*mssBandCount,0,NULL,NULL,&ret);
        if(ret != CL_SUCCESS || NULL == resampleBuf){
            printf("clEnqueueMapBuffer Error\n");
            return -1;
        }

        // Convert the float result to the source data type and write the strip.
        // pBuf is allocated with malloc so that it can be freed through a void*.
        long sz = tmpRowNum*resampleWidth*mssBandCount;
        void *pBuf = NULL;
        CPLErr err = CE_None;
        switch(mssDT){
            case GDT_Byte:
                pBuf = malloc(sz*sizeof(unsigned char));
                if(!DataTypeTrans<unsigned char>(resampleBuf,(unsigned char*)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(unsigned char),resampleWidth*mssBandCount*sizeof(unsigned char),
                    resampleWidth*sizeof(unsigned char));
                break;
            case GDT_UInt16:
                pBuf = malloc(sz*sizeof(unsigned short int));
                if(!DataTypeTrans<unsigned short int>(resampleBuf,(unsigned short int*)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(unsigned short int),resampleWidth*mssBandCount*sizeof(unsigned short int),
                    resampleWidth*sizeof(unsigned short int));
                break;
            case GDT_Int16:
                pBuf = malloc(sz*sizeof(short int));
                if(!DataTypeTrans<short int>(resampleBuf,(short int*)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(short int),resampleWidth*mssBandCount*sizeof(short int),
                    resampleWidth*sizeof(short int));
                break;
            case GDT_UInt32:
                pBuf = malloc(sz*sizeof(unsigned int));
                if(!DataTypeTrans<unsigned int>(resampleBuf,(unsigned int*)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(unsigned int),resampleWidth*mssBandCount*sizeof(unsigned int),
                    resampleWidth*sizeof(unsigned int));
                break;
            case GDT_Int32:
                pBuf = malloc(sz*sizeof(int));
                if(!DataTypeTrans<int>(resampleBuf,(int*)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(int),resampleWidth*mssBandCount*sizeof(int),
                    resampleWidth*sizeof(int));
                break;
            case GDT_Float32:
                pBuf = malloc(sz*sizeof(float));
                if(!DataTypeTrans<float>(resampleBuf,(float *)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(float),resampleWidth*mssBandCount*sizeof(float),
                    resampleWidth*sizeof(float));
                break;
            case GDT_Float64:
                pBuf = malloc(sz*sizeof(double));
                if(!DataTypeTrans<double>(resampleBuf,(double *)pBuf,sz))
                {
                    printf("DataTypeTrans Error\n");
                    return -1;
                }
                err = poOutDS->RasterIO(GF_Write,0,startR,resampleWidth,tmpRowNum,pBuf,
                    resampleWidth,tmpRowNum,mssDT,mssBandCount,NULL,sizeof(double),resampleWidth*mssBandCount*sizeof(double),
                    resampleWidth*sizeof(double));
                break;
            default:
                printf("Unsupported data type\n");
                break;
        }
        if(err != CE_None){
            printf("RasterIO write Error\n");
        }
        free(pBuf);pBuf = NULL;

        // OpenCL step 12: unmap and release the per-strip memory objects before
        // freeing the host buffers they wrap (CL_MEM_USE_HOST_PTR)
        clEnqueueUnmapMemObject(command_queue,mem_resample,resampleBuf,0,NULL,NULL);
        clFinish(command_queue);
        if(NULL != mem_mss) clReleaseMemObject(mem_mss);
        if(NULL != mem_resample) clReleaseMemObject(mem_resample);
        free(mssBuf);
        free(resampleBuf);
        std::cout<<i<<std::endl;  // progress: index of the finished strip
    }
    // OpenCL step 12: release the remaining resources
    for(int k = 0;k < loopNum;k++){
        if(NULL != kernels[k]) clReleaseKernel(kernels[k]);
        if(NULL != programs[k]) clReleaseProgram(programs[k]);
    }
    free(kernels);
    free(programs);

    if(NULL != command_queue) clReleaseCommandQueue(command_queue);
    if(NULL != context) clReleaseContext(context);
    GDALClose((GDALDatasetH)poPANDS);
    GDALClose((GDALDatasetH)poMSSDS);
    GDALClose((GDALDatasetH)poOutDS);
    return 0;
}

#endif
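
The strip bookkeeping at the top of the loop above can be isolated into two small helpers. This is only a sketch: `Strip`, `PlanStrips` and `SourceRowsForStrip` are hypothetical names that do not appear in the original code, and the extra-row rule (0 for nearest neighbour, 1 for bilinear, 2 for cubic convolution) is simply copied from the `MethodType` branches of the listing.

```cpp
#include <cassert>
#include <vector>

// One horizontal strip of the output (resampled) image.
struct Strip {
    int startRow;   // first output row of the strip
    int rowCount;   // number of output rows in the strip
};

// Split outHeight output rows into strips of at most rowsPerStrip rows.
// The last strip is shortened so that it ends at the image border.
std::vector<Strip> PlanStrips(int outHeight, int rowsPerStrip)
{
    assert(outHeight >= 0 && rowsPerStrip > 0);
    std::vector<Strip> strips;
    for (int startRow = 0; startRow < outHeight; startRow += rowsPerStrip) {
        int rowCount = rowsPerStrip;
        if (startRow + rowCount > outHeight) {
            rowCount = outHeight - startRow;
        }
        strips.push_back({startRow, rowCount});
    }
    return strips;
}

// Number of source (MSS) rows to read for one strip: the scaled row count
// plus the extra rows the interpolation needs
// (methodType 0 = nearest, 1 = bilinear, 2 = cubic convolution),
// clamped so that the read does not run past the last source row.
int SourceRowsForStrip(int rowCount, float ratio, int methodType,
                       int srcPosY, int srcHeight)
{
    int rows = static_cast<int>(rowCount * ratio) + methodType;
    if (srcPosY + rows > srcHeight) {
        rows = srcHeight - srcPosY;
    }
    return rows;
}
```

Keeping this arithmetic in separate functions makes the edge cases (short last strip, read window at the bottom of the image) easy to check without a GPU.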

The GPU kernel code is as follows:

// Only needed for the commented-out printf debugging on AMD platforms
#pragma OPENCL EXTENSION cl_amd_printf:enable

__kernel void ReSampleKernel(__global const float *mssData,
                             __global float *resampleData,
                             int mssWidth,
                             int mssHeight,
                             int mssBandCount,
                             int mssOffsetX,
                             int mssOffsetY,
                             int panWidth,
                             int panHeight,
                             float radioX,
                             float radioY,
                             float dfDstNoDataValue,
                             int MethodType)
{
    int idx = get_global_id(1);  // output row
    int idy = get_global_id(0);  // output column
    float eps = 0.00001f;
    if(idx < panHeight && idy < panWidth){
        // Locate the corresponding position in the MSS strip
        // (curX is the source row, curY is the source column)
        float curX = (float)idx * radioX;
        float curY = (float)idy * radioY;
        // Clamp the base pixel so the NoData test never reads out of bounds
        int baseX = min((int)curX, mssHeight - 1);
        int baseY = min((int)curY, mssWidth - 1);
        int tmpP = baseX*mssWidth*mssBandCount + baseY;
        // If the first band is NoData here, write NoData to every band
        if(mssData[tmpP] == dfDstNoDataValue)
        {
            int i = 0;
            while(i < mssBandCount){
                resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] = dfDstNoDataValue;
                i++;
            }
            return;
        }
        if(MethodType == 0){  // nearest neighbour
            // Round to the nearest source pixel (coordinates are non-negative)
            int nearX = (int)(curX + 0.5f);
            int nearY = (int)(curY + 0.5f);
            if(nearX >= mssHeight - 1){
                nearX = mssHeight - 1;
            }
            if(nearY >= mssWidth - 1){
                nearY = mssWidth - 1;
            }
            if(nearX < mssHeight && nearY < mssWidth){
                int i = 0;
                while(i < mssBandCount){
                    resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] =
                        mssData[nearX*mssWidth*mssBandCount + i*mssWidth + nearY];
                    i++;
                }
            }
        }
        if(MethodType == 1){  // bilinear
            // Fractional offsets from the upper-left neighbour
            float dataX = curX - (int)curX;
            float dataY = curY - (int)curY;
            if(dataX < eps){
                dataX = eps;
            }
            if(dataY < eps){
                dataY = eps;
            }
            int preX = (int)curX;
            int preY = (int)curY;
            int postX = (int)curX + 1;
            int postY = (int)curY + 1;
            if(postX >= mssHeight - 1){
                postX = mssHeight - 1;
            }
            if(postY >= mssWidth - 1){
                postY = mssWidth - 1;
            }

            // Distance-based weights in the row (X) and column (Y) directions
            float Wx1 = 1 - dataX;
            float Wx2 = dataX;
            float Wy1 = 1 - dataY;
            float Wy2 = dataY;
            // Core of the bilinear interpolation
            int i = 0;
            while(i < mssBandCount){
                float pMssValue[4] = {0,0,0,0};
                pMssValue[0] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + preY];
                pMssValue[1] = mssData[preX*mssWidth*mssBandCount + i*mssWidth + postY];
                pMssValue[2] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + preY];
                pMssValue[3] = mssData[postX*mssWidth*mssBandCount + i*mssWidth + postY];
                resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] =
                    Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[2]) + Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[3]);
                i++;
            }
        }
        if(MethodType == 2){  // cubic convolution (bicubic)
            float dataX = curX - (int)curX;
            float dataY = curY - (int)curY;
            //printf("dataX = %f   dataY = %f\n",dataX,dataY);
            // The 4 x 4 neighbourhood: rows preX1..postX2, columns preY1..postY2
            int preX1 = (int)curX - 1;
            int preX2 = (int)curX;
            int postX1 = (int)curX + 1;
            int postX2 = (int)curX + 2;
            int preY1 = (int)curY - 1;
            int preY2 = (int)curY;
            int postY1 = (int)curY + 1;
            int postY2 = (int)curY + 2;
            // Clamp the neighbourhood to the strip
            if(preX1 < 0) preX1 = 0;
            if(preY1 < 0) preY1 = 0;
            if(preX2 > mssHeight - 1) preX2 = mssHeight - 1;
            if(preY2 > mssWidth - 1) preY2 = mssWidth - 1;
            if(postX1 > mssHeight - 1) postX1 = mssHeight - 1;
            if(postX2 > mssHeight - 1) postX2 = mssHeight - 1;
            if(postY1 > mssWidth - 1) postY1 = mssWidth - 1;
            if(postY2 > mssWidth - 1) postY2 = mssWidth - 1;

            // Cubic convolution weights (kernel parameter a = -1); each set sums to 1
            float Wx1 = -1.0f*dataX + 2*dataX*dataX - dataX*dataX*dataX;
            float Wx2 = 1 - 2*dataX*dataX + dataX*dataX*dataX;
            float Wx3 = dataX + dataX*dataX - dataX*dataX*dataX;
            float Wx4 = -1.0f*dataX*dataX + dataX*dataX*dataX;
            float Wy1 = -1.0f*dataY + 2*dataY*dataY - dataY*dataY*dataY;
            float Wy2 = 1 - 2*dataY*dataY + dataY*dataY*dataY;
            float Wy3 = dataY + dataY*dataY - dataY*dataY*dataY;
            float Wy4 = -1.0f*dataY*dataY + dataY*dataY*dataY;

            //printf("preX1 = %d\n",preX1);
            int i = 0;
            while(i < mssBandCount){
                float pMssValue[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
                pMssValue[0] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY1];
                pMssValue[1] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + preY2];
                pMssValue[2] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY1];
                pMssValue[3] = mssData[preX1*mssWidth*mssBandCount + i*mssWidth + postY2];

                pMssValue[4] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY1];
                pMssValue[5] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + preY2];
                pMssValue[6] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY1];
                pMssValue[7] = mssData[preX2*mssWidth*mssBandCount + i*mssWidth + postY2];

                pMssValue[8] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY1];
                pMssValue[9] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + preY2];
                pMssValue[10] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY1];
                pMssValue[11] = mssData[postX1*mssWidth*mssBandCount + i*mssWidth + postY2];

                pMssValue[12] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY1];
                pMssValue[13] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + preY2];
                pMssValue[14] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY1];
                pMssValue[15] = mssData[postX2*mssWidth*mssBandCount + i*mssWidth + postY2];

                // Interpolate along the rows first, then combine the columns
                resampleData[idx*panWidth*mssBandCount+i*panWidth + idy] =
                    Wy1*(Wx1*pMssValue[0] + Wx2*pMssValue[4] + Wx3*pMssValue[8] + Wx4*pMssValue[12])+
                    Wy2*(Wx1*pMssValue[1] + Wx2*pMssValue[5] + Wx3*pMssValue[9] + Wx4*pMssValue[13])+
                    Wy3*(Wx1*pMssValue[2] + Wx2*pMssValue[6] + Wx3*pMssValue[10] + Wx4*pMssValue[14])+
                    Wy4*(Wx1*pMssValue[3] + Wx2*pMssValue[7] + Wx3*pMssValue[11] + Wx4*pMssValue[15]);
                i++;
            }
        }
    }
}
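
The four cubic convolution weights used in the kernel can be checked on the host. The sketch below re-implements the same polynomials in plain C++. `CubicWeights` and `CubicInterpolate1D` are hypothetical helper names, not part of the original code. The polynomials match the standard cubic convolution kernel with parameter a = -1, so the weights sum to 1 for any fractional offset and reduce to (0, 1, 0, 0) when the offset is 0.

```cpp
#include <cassert>
#include <cmath>

// Cubic convolution weights for the four neighbours at offsets -1, 0, +1, +2,
// where d in [0, 1) is the fractional distance from the sample at offset 0.
// These are the same polynomials as Wx1..Wx4 in the kernel (the a = -1 form).
void CubicWeights(float d, float w[4])
{
    float d2 = d * d;
    float d3 = d2 * d;
    w[0] = -d + 2.0f * d2 - d3;
    w[1] = 1.0f - 2.0f * d2 + d3;
    w[2] = d + d2 - d3;
    w[3] = -d2 + d3;
}

// Interpolate one value from four consecutive samples p[0..3].
float CubicInterpolate1D(const float p[4], float d)
{
    float w[4];
    CubicWeights(d, w);
    return w[0] * p[0] + w[1] * p[1] + w[2] * p[2] + w[3] * p[3];
}
```

The two-dimensional case in the kernel is the same operation applied along the rows and then along the columns, which is why it needs 16 source pixels per output pixel.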

The code above should be usable as is; comments and discussion are welcome.

 

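In addition, I compared the efficiency of the GDAL, CPU and GPU versions of the resampling algorithm. For cubic convolution resampling, the GPU version is markedly faster than the CPU version (the detailed timing figures are not included here).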
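
For anyone who wants to repeat such a comparison, a minimal wall-clock timing helper could look like the following. `MeanTimeMs` is a hypothetical name and this is not the measurement code used for the post. When timing the OpenCL path, the timed function should include the `clFinish` call, since enqueueing a kernel returns before the kernel has run.

```cpp
#include <cassert>
#include <chrono>

// Run fn() `repeats` times and return the mean wall-clock time per run
// in milliseconds.
template <typename Fn>
double MeanTimeMs(Fn fn, int repeats)
{
    assert(repeats > 0);
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < repeats; ++i) {
        fn();
    }
    auto stop = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> elapsed = stop - start;
    return elapsed.count() / repeats;
}
```

Averaging over several runs reduces the influence of one-off costs such as file caching; the OpenCL program build is best timed separately from the resampling itself.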
