/**
 * Describes an input source by answering two questions: how the input is divided
 * into splits, and how key/value records are read out of one split.
 *
 * @param <K> type of the keys produced by the record reader
 * @param <V> type of the values produced by the record reader
 */
public abstract class InputFormat<K, V> {

    /**
     * Computes the split (partition) layout of the input. Every split is described
     * by an {@link InputSplit}, and all of them are returned as a list.
     *
     * <p>Splits are logical, not physical. For example, a file holding 100 characters
     * that is cut into 10-character pieces yields 10 splits.
     *
     * @param context the job context
     * @return one {@code InputSplit} per logical split
     * @throws IOException          if the split information cannot be obtained
     * @throws InterruptedException if the calling thread is interrupted
     */
    public abstract List<InputSplit> getSplits(JobContext context)
            throws IOException, InterruptedException;

    /**
     * Creates the reader for a single split. A {@link RecordReader} is essentially an
     * enhanced iterator that yields key/value pairs. Only one {@code InputSplit} is
     * passed in, so iteration happens within that one split.
     *
     * @param split   the split to read
     * @param context the task-attempt context
     * @return a reader over the records of {@code split}
     * @throws IOException          if the reader cannot be created
     * @throws InterruptedException if the calling thread is interrupted
     */
    public abstract RecordReader<K, V> createRecordReader(InputSplit split,
                                                          TaskAttemptContext context)
            throws IOException, InterruptedException;
}
這樣大概就理解了這個接口的定位:一個負責 how to define partitions(如何劃分分區),一個負責 how to get data from a partition(如何從分區中讀取數據)。下面再具體到 Spark 的應用場景。
// Build an RDD over an HBase table through Spark's new-API Hadoop input support.
//   hbaseConfig             - the HBaseConfiguration describing what to read
//   TableInputFormat        - the InputFormat subclass acting as the data source
//   ImmutableBytesWritable  - the key type produced by the source
//   Result                  - the value type produced by the source
// Anyone who has written a MapReduce job will recognise this setup. The difference is
// that the output is always an RDD, so no output types have to be declared.
val hBaseRDD = sc.newAPIHadoopRDD(
  conf = hbaseConfig,
  fClass = classOf[TableInputFormat],
  kClass = classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable],
  vClass = classOf[org.apache.hadoop.hbase.client.Result]
)
// Excerpt, presumably from the body of SparkContext.newAPIHadoopRDD (enclosing method not
// shown). It wraps the InputFormat class, the key and value classes and the job
// configuration `jconf` into a NewHadoopRDD; `this` is the enclosing object (assumed to be
// the SparkContext).
new NewHadoopRDD(this, fClass, kClass, vClass, jconf)
/**
 * Derives this RDD's partitions from the configured Hadoop `InputFormat` (for example
 * HBase's `TableInputFormat`; any InputFormat class may be used).
 *
 * A new format instance is created and, when it is `Configurable`, handed the RDD's
 * configuration. Then every split returned by `getSplits` becomes one
 * `NewHadoopPartition`, whose index is the split's position in that list.
 *
 * @return one partition per input split, in `getSplits` order
 */
override def getPartitions: Array[Partition] = {
  // Instantiate the user-supplied InputFormat class.
  val format = inputFormatClass.newInstance
  format match {
    case withConf: Configurable => withConf.setConf(_conf)
    case _ =>
  }
  // Ask the format for all of its (logical) splits.
  val splits = format.getSplits(new JobContextImpl(_conf, jobId)).toArray
  // Wrap each split as a Spark partition; the split's position is the partition index.
  Array.tabulate[Partition](splits.length) { idx =>
    new NewHadoopPartition(id, idx, splits(idx).asInstanceOf[InputSplit with Writable])
  }
}
// Members of the per-split record iterator. This is an excerpt: the enclosing definition is
// not shown, but hasNext/next(): (K, V) suggest an anonymous Iterator[(K, V)] (TODO
// confirm). `havePair`, `close()`, `updateBytesRead()`, `inputMetrics`, `split`, `conf`,
// `jobTrackerId` and `ignoreCorruptFiles` are defined outside this excerpt.

// As in getPartitions, instantiate the InputFormat and pass it the configuration when it
// is Configurable.
private val format = inputFormatClass.newInstance
format match {
  case configurable: Configurable =>
    configurable.setConf(conf)
  case _ =>
}
// Task-attempt identity and context that the MapReduce RecordReader API requires.
private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0)
private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
// Must be declared before `reader`: reader's initializer may set `finished = true`, and a
// field declared later would be reset to false by its own initializer.
private var finished = false
// The RecordReader that does the actual reading. When ignoreCorruptFiles is set and
// creating or initializing it throws IOException, it is left null and the iterator is
// marked finished, so hasNext never touches it.
private var reader =
  try {
    val _reader = format.createRecordReader(
      split.serializableHadoopSplit.value, hadoopAttemptContext)
    _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
    _reader
  } catch {
    case e: IOException if ignoreCorruptFiles =>
      logWarning(
        s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}",
        e)
      finished = true
      null
  }

// Advances the reader only when no fetched-but-unconsumed pair is pending (`havePair`).
// An IOException while reading is treated as end of input when ignoreCorruptFiles is set.
// Returns true while a pair is available.
override def hasNext: Boolean = {
  if (!finished && !havePair) {
    try {
      finished = !reader.nextKeyValue
    } catch {
      case e: IOException if ignoreCorruptFiles =>
        logWarning(
          s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}",
          e)
        finished = true
    }
    if (finished) {
      // Close and release the reader here; close() will also be called when the task
      // completes, but for tasks that read from many files, it helps to release the
      // resources early.
      close()
    }
    havePair = !finished
  }
  !finished
}

// Returns the current key/value pair and marks it consumed.
// Throws NoSuchElementException when the input is exhausted.
override def next(): (K, V) = {
  if (!hasNext) {
    throw new java.util.NoSuchElementException("End of stream")
  }
  havePair = false
  if (!finished) {
    inputMetrics.incRecordsRead(1)
  }
  // Refresh the bytes-read metric only every UPDATE_INPUT_METRICS_INTERVAL_RECORDS records.
  if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
    updateBytesRead()
  }
  (reader.getCurrentKey, reader.getCurrentValue)
}
// Excerpt, presumably from an HBase TableInputFormatBase getSplits implementation. The
// enclosing method is not shown, and the excerpt ends inside the `if` branch: its closing
// brace and whatever follows are omitted.
// Used below to look up the size of a region.
RegionSizeCalculator sizeCalculator = new RegionSizeCalculator(getRegionLocator(), getAdmin());
TableName tableName = getTable().getName();
// Presumably the start/end row keys of the table's regions (TODO confirm).
Pair<byte[][], byte[][]> keys = getStartEndKeys();
// No start keys available: build a single split from the region that holds the empty row key.
if (keys == null || keys.getFirst() == null || keys.getFirst().length == 0) {
  HRegionLocation regLoc = getRegionLocator().getRegionLocation(HConstants.EMPTY_BYTE_ARRAY, false);
  if (null == regLoc) {
    throw new IOException("Expecting at least one region.");
  }
  List<InputSplit> splits = new ArrayList<>(1);
  // The SIZE of that region as reported by RegionSizeCalculator (not a region count).
  // It is passed to the TableSplit below, presumably as the split length.
  long regionSize = sizeCalculator.getRegionSize(regLoc.getRegionInfo().getRegionName());
  // Create the TableSplit (an InputSplit) with empty start and end rows. Only the host part
  // of the region's "host<sep>port" string is kept, presumably as a location hint.
  TableSplit split = new TableSplit(tableName, scan, HConstants.EMPTY_BYTE_ARRAY,
      HConstants.EMPTY_BYTE_ARRAY,
      regLoc.getHostnamePort().split(Addressing.HOSTNAME_PORT_SEPARATOR)[0], regionSize);
  splits.add(split);
// Excerpt, presumably the body of TableInputFormatBase#createRecordReader. `tSplit`,
// `scan` and the enclosing method signature are defined outside this excerpt.
// Reuse a pre-set TableRecordReader when one exists; otherwise create a new one.
final TableRecordReader trr =
    this.tableRecordReader != null ? this.tableRecordReader : new TableRecordReader();
// Copy the configured scan and restrict it to this split's start/end row.
Scan sc = new Scan(this.scan);
sc.setStartRow(tSplit.getStartRow());
sc.setStopRow(tSplit.getEndRow());
trr.setScan(sc);
trr.setTable(getTable());
// Expose the TableRecordReader through the mapreduce RecordReader API by pure delegation.
return new RecordReader<ImmutableBytesWritable, Result>() {

  /** Closes the delegate reader, then the table. */
  @Override
  public void close() throws IOException {
    trr.close();
    closeTable();
  }

  @Override
  public ImmutableBytesWritable getCurrentKey() throws IOException, InterruptedException {
    return trr.getCurrentKey();
  }

  @Override
  public Result getCurrentValue() throws IOException, InterruptedException {
    return trr.getCurrentValue();
  }

  @Override
  public float getProgress() throws IOException, InterruptedException {
    return trr.getProgress();
  }

  // The annotation was misspelled "@Overrid", which is an unknown annotation type and a
  // compile error.
  @Override
  public void initialize(InputSplit inputsplit, TaskAttemptContext context)
      throws IOException, InterruptedException {
    trr.initialize(inputsplit, context);
  }

  @Override
  public boolean nextKeyValue() throws IOException, InterruptedException {
    return trr.nextKeyValue();
  }
};