A Custom Protocol over Java Sockets: Implementing a Persistent Connection between Android and a Server (Part 2)

Original article: http://blog.csdn.net/u0108184...

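Before reading this article you should have a basic understanding of sockets and custom protocols. If not, first read the previous article, "A Custom Protocol over Java Sockets: Implementing a Persistent Connection between Android and a Server (Part 1)", which covers the basics.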

1. Protocol Definition

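In the previous article we took a brief look at socket programming and custom protocols. This article builds on that to implement a persistent connection between Android and a server. The protocol is defined as follows: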

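  • Data message (Data)

    • length (length, 32 bit)

    • version (version, 8 bit; the high 3 bits are reserved, the low 5 bits hold the actual version number)

    • message type (type, 8 bit; 0 = data)

    • business type (pattion, 8 bit; 0 = push, other values not yet defined)

    • data format (dtype, 8 bit; 0 = json, other values not yet defined)

    • message ID (msgId, 32 bit)

    • payload (data)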

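  • Data acknowledgement message (DataAck)

    • length (length, 32 bit)

    • version (version, 8 bit; the high 3 bits are reserved, the low 5 bits hold the actual version number)

    • message type (type, 8 bit; 1 = data ack)

    • acknowledged message ID (ackMsgId, 32 bit)

    • reserved data (unused)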

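  • Heartbeat message (ping)

    • length (length, 32 bit)

    • version (version, 8 bit; the high 3 bits are reserved, the low 5 bits hold the actual version number)

    • message type (type, 8 bit; 2 = heartbeat)

    • heartbeat ID (pingId, 32 bit; odd when sent by the client, i.e. 1, 3, 5..., even when sent by the server, i.e. 0, 2, 4...)

    • reserved data (unused)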

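  • Heartbeat acknowledgement message (pingAck)

    • length (length, 32 bit)

    • version (version, 8 bit; the high 3 bits are reserved, the low 5 bits hold the actual version number)

    • message type (type, 8 bit; 3 = heartbeat ack)

    • acknowledged heartbeat ID (pingId, 32 bit; odd for client-sent heartbeats, i.e. 1, 3, 5..., even for server-sent ones, i.e. 0, 2, 4...)

    • reserved data (unused)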
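To make the byte layout concrete, here is a minimal sketch that encodes a heartbeat (ping) message field by field, following the definition above. It assumes big-endian byte order (what `SocketUtil.int2ByteArrays` produces), reserved bits set to 0, and UTF-8 for the `unused` string; the class name `PingFrameSketch` is hypothetical and not part of the article's code. It covers only the message itself: `write2Stream` in the implementation below additionally prefixes every message with its own 4-byte length.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch: encodes a ping message as laid out in the protocol definition.
class PingFrameSketch {

    static final byte TYPE_PING = 2;

    // length(4) + version(1) + type(1) + pingId(4) + unused(n)
    static byte[] encode(int version, int pingId, String unused) {
        byte[] unusedBytes = unused.getBytes(StandardCharsets.UTF_8);
        int total = 4 + 1 + 1 + 4 + unusedBytes.length;
        ByteBuffer buf = ByteBuffer.allocate(total); // big-endian by default
        buf.putInt(total);                           // length of the whole message
        buf.put((byte) (version & 0x1F));            // high 3 reserved bits = 0, version in low 5 bits
        buf.put(TYPE_PING);                          // 2 = ping
        buf.putInt(pingId);                          // odd ids are client-originated
        buf.put(unusedBytes);                        // reserved payload
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] frame = encode(1, 1, "ping...");
        System.out.println(frame.length); // 17
    }
}
```

With version 1, pingId 1 and `"ping..."` the message is 4 + 1 + 1 + 4 + 7 = 17 bytes.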

2. Protocol Implementation

From the definitions above, all four messages share three elements: length, version and message type. So we can first abstract a base protocol class, as follows:
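Because the shared fields always come first, a receiver can read the message type from a fixed offset before choosing a concrete class, which is what `SocketUtil.parseContentMsg` does later. A minimal sketch of reading those 6 shared bytes, assuming big-endian order and the 3/5 bit split of the version byte; `CommonHeader` is a hypothetical name:

```java
import java.nio.ByteBuffer;

// Hypothetical sketch: reads the 6-byte header shared by all four message types.
class CommonHeader {
    final int length;   // whole-message length in bytes
    final int reserved; // high 3 bits of the version byte
    final int version;  // low 5 bits of the version byte
    final int type;     // 0 = data, 1 = data ack, 2 = ping, 3 = ping ack

    private CommonHeader(int length, int reserved, int version, int type) {
        this.length = length;
        this.reserved = reserved;
        this.version = version;
        this.type = type;
    }

    static CommonHeader parse(byte[] message) {
        ByteBuffer buf = ByteBuffer.wrap(message); // big-endian by default
        int length = buf.getInt();
        int ver = buf.get() & 0xFF;  // mask first so the sign bit is not extended
        int type = buf.get() & 0xFF;
        return new CommonHeader(length, ver >>> 5, ver & 0x1F, type);
    }

    public static void main(String[] args) {
        CommonHeader h = parse(new byte[]{0, 0, 0, 6, (byte) 0x21, 2});
        System.out.println(h.reserved + " " + h.version + " " + h.type); // 1 1 2
    }
}
```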

1. BasicProtocol

import com.shandiangou.sdgprotocol.lib.Config;
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;

import java.io.ByteArrayOutputStream;

/**
 * Created by meishan on 16/12/1.
 * <p>
 * Message types: 0 = data, 1 = data ack, 2 = ping, 3 = ping ack
 */
public abstract class BasicProtocol {

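    // All lengths are in bytes
    public static final int LENGTH_LEN = 4;       // size of the field holding the whole message length
    protected static final int VER_LEN = 1;       // size of the version field (high 3 bits reserved, low 5 bits version number)
    protected static final int TYPE_LEN = 1;      // size of the message type field

    private int reserved = 0;                     // reserved bits
    private int version = Config.VERSION;         // version number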

    /**
     * Returns the length of the whole message
     * Unit: bytes
     *
     * @return
     */
    protected int getLength() {
        return LENGTH_LEN + VER_LEN + TYPE_LEN;
    }

    public int getReserved() {
        return reserved;
    }

    public void setReserved(int reserved) {
        this.reserved = reserved;
    }

    public int getVersion() {
        return version;
    }

    public void setVersion(int version) {
        this.version = version;
    }

    /**
     * Returns the message type; implemented by subclasses
     *
     * @return
     */
    public abstract int getProtocolType();
    
    /**
     * Packs the reserved bits and the version number into the single version byte
     *
     * @return
     */
    private int getVer(byte r, byte v, int vLen) {
        int num = 0;
        int rLen = 8 - vLen;
        for (int i = 0; i < rLen; i++) {
            num += (((r >> (rLen - 1 - i)) & 0x1) << (7 - i));
        }
        return num + v;
    }

    /**
     * Builds the outgoing bytes. This writes the common fields (length, version and type) in order;
     * subclasses append their own fields.
     *
     * @return
     */
    public byte[] genContentData() {
        byte[] length = SocketUtil.int2ByteArrays(getLength());
        byte reserved = (byte) getReserved();
        byte version = (byte) getVersion();
        byte[] ver = {(byte) getVer(reserved, version, 5)};
        byte[] type = {(byte) getProtocolType()};

        ByteArrayOutputStream baos = new ByteArrayOutputStream(LENGTH_LEN + VER_LEN + TYPE_LEN);
        baos.write(length, 0, LENGTH_LEN);
        baos.write(ver, 0, VER_LEN);
        baos.write(type, 0, TYPE_LEN);
        return baos.toByteArray();
    }

    /**
     * Parses the whole-message length
     *
     * @param data
     * @return
     */
    protected int parseLength(byte[] data) {
        return SocketUtil.byteArrayToInt(data, 0, LENGTH_LEN);
    }

    /**
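     * Parses the reserved bits
     *
     * @param data
     * @return
     */
    protected int parseReserved(byte[] data) {
        byte r = data[LENGTH_LEN]; // bytes 0-3 hold the int length; this byte holds the reserved bits and the version
        return (r & 0xFF) >> 5;    // mask first: a negative byte would otherwise be sign-extended by >>
    }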

    /**
     * Parses the version number
     *
     * @param data
     * @return
     */
    protected int parseVersion(byte[] data) {
        byte v = data[LENGTH_LEN]; // shares this byte with the reserved bits
        return ((v << 3) & 0xFF) >> 3;
    }

    /**
     * Parses the message type
     *
     * @param data
     * @return
     */
    public static int parseType(byte[] data) {
        byte t = data[LENGTH_LEN + VER_LEN]; // bytes 0-3 hold the int length and byte 4 the version byte
        return t & 0xFF;
    }

    /**
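     * Parses received bytes. This handles the common fields (length, version and type); subclasses parse their own fields.
     *
     * @param data
     * @return the offset at which the subclass fields start
     * @throws ProtocolException if the version does not match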
     */
    public int parseContentData(byte[] data) throws ProtocolException {
        int reserved = parseReserved(data);
        int version = parseVersion(data);
        int protocolType = parseType(data);
        if (version != getVersion()) {
            throw new ProtocolException("input version is error: " + version);
        }
        return LENGTH_LEN + VER_LEN + TYPE_LEN;
    }

    @Override
    public String toString() {
        return "Version: " + getVersion() + ", Type: " + getProtocolType();
    }
}
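The loop in `getVer` copies the low `8 - vLen` bits of the reserved value into the top of the byte, one bit at a time, which amounts to a shift and mask. The sketch below (hypothetical class `VersionByte`) checks the two forms against each other and shows a related pitfall when unpacking: Java promotes a `byte` to `int` with sign extension, so `(b >> 5) & 0xFF` returns 255 instead of 7 when the top bit is set, while `(b & 0xFF) >> 5` is correct.

```java
// Hypothetical sketch: packing/unpacking the version byte (3 reserved bits + 5 version bits).
class VersionByte {

    // Same algorithm as BasicProtocol.getVer: copy the low rLen bits of r, bit by bit, into the high bits.
    static int packLoop(byte r, byte v, int vLen) {
        int num = 0;
        int rLen = 8 - vLen;
        for (int i = 0; i < rLen; i++) {
            num += (((r >> (rLen - 1 - i)) & 0x1) << (7 - i));
        }
        return num + v;
    }

    // Equivalent shift/mask form (v is masked to vLen bits here).
    static int packShift(int r, int v, int vLen) {
        int rLen = 8 - vLen;
        return ((r & ((1 << rLen) - 1)) << vLen) | (v & ((1 << vLen) - 1));
    }

    static int reservedSignExtended(byte b) { return (b >> 5) & 0xFF; } // wrong when b < 0
    static int reservedMasked(byte b)       { return (b & 0xFF) >> 5; } // correct

    public static void main(String[] args) {
        byte packed = (byte) packShift(7, 1, 5);         // 0xE1
        System.out.println(reservedSignExtended(packed)); // 255
        System.out.println(reservedMasked(packed));       // 7
    }
}
```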

The Config and SocketUtil classes referenced above are as follows:

/**
 * Created by meishan on 16/12/2.
 */
public class Config {

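    public static final int VERSION = 1;                 // protocol version
    public static final String ADDRESS = "10.17.64.237"; // server address
    public static final int PORT = 9013;                 // server port

}

import com.shandiangou.sdgprotocol.lib.protocol.BasicProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.DataAckProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.PingAckProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.PingProtocol;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;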

/**
 * Created by meishan on 16/12/1.
 */
public class SocketUtil {

    private static Map<Integer, String> msgImp = new HashMap<>();

    static {
        msgImp.put(DataProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.DataProtocol");       //0
        msgImp.put(DataAckProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.DataAckProtocol"); //1
        msgImp.put(PingProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.PingProtocol");       //2
        msgImp.put(PingAckProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.PingAckProtocol"); //3
    }

    /**
     * Parses message content: uses the type byte to pick the protocol class, which then parses the rest
     *
     * @param data
     * @return
     */
    public static BasicProtocol parseContentMsg(byte[] data) {
        int protocolType = BasicProtocol.parseType(data);
        String className = msgImp.get(protocolType);
        BasicProtocol basicProtocol;
        try {
            basicProtocol = (BasicProtocol) Class.forName(className).newInstance();
            basicProtocol.parseContentData(data);
        } catch (Exception e) {
            basicProtocol = null;
            e.printStackTrace();
        }
        return basicProtocol;
    }

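    /**
     * Reads one frame: a 4-byte length header followed by that many bytes of content.
     * It reads directly from the caller's stream. Wrapping it in a new BufferedInputStream on every
     * call could read ahead into the next frame and silently discard those bytes.
     *
     * @param inputStream the socket input stream
     * @return the parsed message, or null if the stream ended or an error occurred
     */
    public static BasicProtocol readFromStream(InputStream inputStream) {
        // the header holds the content length in 4 bytes; write2Stream below writes it first
        byte[] header = new byte[BasicProtocol.LENGTH_LEN];

        try {
            if (!readFully(inputStream, header)) {
                inputStream.close(); // end of stream: close it so callers see the socket as closed
                return null;
            }

            int length = byteArrayToInt(header); // content length
            if (length < 0) {
                return null; // corrupt header
            }
            byte[] content = new byte[length];
            if (!readFully(inputStream, content)) {
                inputStream.close(); // stream ended in the middle of a frame
                return null;
            }

            return parseContentMsg(content);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Fills buf completely, looping over partial reads.
     *
     * @return false if the stream ended before buf was filled
     */
    private static boolean readFully(InputStream in, byte[] buf) throws IOException {
        int len = 0;
        while (len < buf.length) {
            int temp = in.read(buf, len, buf.length - len);
            if (temp == -1) {
                return false;
            }
            len += temp;
        }
        return true;
    }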

    /**
     * Writes one frame: a 4-byte length header followed by the message bytes
     *
     * @param protocol
     * @param outputStream
     */
    public static void write2Stream(BasicProtocol protocol, OutputStream outputStream) {
        BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(outputStream);
        byte[] buffData = protocol.genContentData();
        byte[] header = int2ByteArrays(buffData.length);
        try {
            bufferedOutputStream.write(header);
            bufferedOutputStream.write(buffData);
            bufferedOutputStream.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Closes an input stream, ignoring any IOException
     *
     * @param is
     */
    public static void closeInputStream(InputStream is) {
        try {
            if (is != null) {
                is.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Closes an output stream, ignoring any IOException
     *
     * @param os
     */
    public static void closeOutputStream(OutputStream os) {
        try {
            if (os != null) {
                os.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static byte[] int2ByteArrays(int i) {
        byte[] result = new byte[4];
        result[0] = (byte) ((i >> 24) & 0xFF);
        result[1] = (byte) ((i >> 16) & 0xFF);
        result[2] = (byte) ((i >> 8) & 0xFF);
        result[3] = (byte) (i & 0xFF);
        return result;
    }

    public static int byteArrayToInt(byte[] b) {
        int intValue = 0;
        for (int i = 0; i < b.length; i++) {
            intValue += (b[i] & 0xFF) << (8 * (3 - i)); // an int takes 4 bytes (0, 1, 2, 3), big-endian
        }
        return intValue;
    }

    public static int byteArrayToInt(byte[] b, int byteOffset, int byteCount) {
        int intValue = 0;
        for (int i = byteOffset; i < (byteOffset + byteCount); i++) {
            intValue += (b[i] & 0xFF) << (8 * (3 - (i - byteOffset)));
        }
        return intValue;
    }

    public static int bytes2Int(byte[] b, int byteOffset) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(Integer.SIZE / Byte.SIZE);
        byteBuffer.put(b, byteOffset, 4); // 4 bytes
        byteBuffer.flip();
        return byteBuffer.getInt();
    }
}
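`write2Stream` and `readFromStream` implement length-prefixed framing: a 4-byte big-endian length, then that many bytes of content. For comparison, the sketch below (hypothetical class `FramingSketch`) does the same with only `java.io`: `DataOutputStream.writeInt` writes the header in the same byte order as `int2ByteArrays`, and `DataInputStream.readFully` keeps reading until the array is full or throws `EOFException`. The `main` method writes two frames back to back and reads them through one stream object, which is the situation a persistent socket is in.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch of 4-byte length-prefixed framing using only java.io.
class FramingSketch {

    static void writeFrame(DataOutputStream out, byte[] content) throws IOException {
        out.writeInt(content.length); // big-endian, same byte order as SocketUtil.int2ByteArrays
        out.write(content);
        out.flush();
    }

    static byte[] readFrame(DataInputStream in) throws IOException {
        int length = in.readInt();    // throws EOFException if the stream ends
        if (length < 0) {
            throw new IOException("negative frame length: " + length);
        }
        byte[] content = new byte[length];
        in.readFully(content);        // loops over partial reads until the array is full
        return content;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream wire = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(wire);
        writeFrame(out, "first".getBytes(StandardCharsets.UTF_8));
        writeFrame(out, "second".getBytes(StandardCharsets.UTF_8));

        // One DataInputStream for the whole "connection", not a new wrapper per frame.
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire.toByteArray()));
        System.out.println(new String(readFrame(in), StandardCharsets.UTF_8)); // first
        System.out.println(new String(readFrame(in), StandardCharsets.UTF_8)); // second
    }
}
```

On a real connection, the single `DataInputStream` would wrap `socket.getInputStream()` once and be reused for every frame.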

Next, we implement the concrete messages.

2. DataProtocol

import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;

import java.io.ByteArrayOutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;

/**
 * Created by meishan on 16/12/1.
 */
public class DataProtocol extends BasicProtocol implements Serializable {

    public static final int PROTOCOL_TYPE = 0;

    private static final int PATTION_LEN = 1;
    private static final int DTYPE_LEN = 1;
    private static final int MSGID_LEN = 4;

    private int pattion;
    private int dtype;
    private int msgId;

    private String data;

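    @Override
    public int getLength() {
        // count UTF-8 bytes, the charset parseContentData decodes with
        return super.getLength() + PATTION_LEN + DTYPE_LEN + MSGID_LEN + data.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
    }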

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getPattion() {
        return pattion;
    }

    public void setPattion(int pattion) {
        this.pattion = pattion;
    }

    public int getDtype() {
        return dtype;
    }

    public void setDtype(int dtype) {
        this.dtype = dtype;
    }

    public void setMsgId(int msgId) {
        this.msgId = msgId;
    }

    public int getMsgId() {
        return msgId;
    }

    public String getData() {
        return data;
    }

    public void setData(String data) {
        this.data = data;
    }

    /**
     * Builds the outgoing bytes
     *
     * @return
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] pattion = {(byte) this.pattion};
        byte[] dtype = {(byte) this.dtype};
        byte[] msgid = SocketUtil.int2ByteArrays(this.msgId);
        byte[] data = this.data.getBytes(java.nio.charset.StandardCharsets.UTF_8); // same charset as parseContentData

        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length);          // length + version + type
        baos.write(pattion, 0, PATTION_LEN);       // business type
        baos.write(dtype, 0, DTYPE_LEN);           // data format
        baos.write(msgid, 0, MSGID_LEN);           // message id
        baos.write(data, 0, data.length);          // payload
        return baos.toByteArray();
    }

    /**
     * Parses received bytes, field by field in order
     *
     * @param data
     * @return
     * @throws ProtocolException
     */
    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);

        // parse pattion
        pattion = data[pos] & 0xFF;
        pos += PATTION_LEN;

        // parse dtype
        dtype = data[pos] & 0xFF;
        pos += DTYPE_LEN;

        // parse msgId
        msgId = SocketUtil.byteArrayToInt(data, pos, MSGID_LEN);
        pos += MSGID_LEN;

        // parse data (UTF-8)
        try {
            this.data = new String(data, pos, data.length - pos, "utf-8");
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }

        return pos;
    }

    @Override
    public String toString() {
        return "data: " + data;
    }
}
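Since `data` is variable-length text, its size on the wire is its encoded byte count, and both ends must agree on the charset: `parseContentData` decodes with UTF-8, while `getBytes()` with no argument uses the platform default, which is UTF-8 on Android but not guaranteed on server JVMs older than JDK 18. For Chinese text the byte count also differs from the character count (a 4-character string is 12 bytes in UTF-8). A minimal sketch of the data-message body (everything after the 6 shared bytes) with an explicit charset; `DataBodySketch` is a hypothetical name:

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch: the body that follows the 6-byte common header in a data message.
class DataBodySketch {
    int pattion; // 0 = push
    int dtype;   // 0 = json
    int msgId;
    String data;

    byte[] encode() {
        byte[] text = data.getBytes(StandardCharsets.UTF_8); // count bytes, not chars
        ByteBuffer buf = ByteBuffer.allocate(1 + 1 + 4 + text.length);
        buf.put((byte) pattion).put((byte) dtype).putInt(msgId).put(text);
        return buf.array();
    }

    static DataBodySketch decode(byte[] body) {
        ByteBuffer buf = ByteBuffer.wrap(body);
        DataBodySketch d = new DataBodySketch();
        d.pattion = buf.get() & 0xFF;
        d.dtype = buf.get() & 0xFF;
        d.msgId = buf.getInt();
        // the remaining bytes are the text; decode with the same charset used to encode
        d.data = new String(body, buf.position(), buf.remaining(), StandardCharsets.UTF_8);
        return d;
    }

    public static void main(String[] args) {
        DataBodySketch d = new DataBodySketch();
        d.msgId = 1;
        d.data = "{\"k\":1}";
        System.out.println(decode(d.encode()).data); // {"k":1}
    }
}
```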

3. DataAckProtocol

import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;

import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;

/**
 * Created by meishan on 16/12/1.
 */
public class DataAckProtocol extends BasicProtocol {

    public static final int PROTOCOL_TYPE = 1;

    private static final int ACKMSGID_LEN = 4;

    private int ackMsgId;

    private String unused;

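    @Override
    public int getLength() {
        // count UTF-8 bytes, the charset parseContentData decodes with
        return super.getLength() + ACKMSGID_LEN + unused.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
    }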

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getAckMsgId() {
        return ackMsgId;
    }

    public void setAckMsgId(int ackMsgId) {
        this.ackMsgId = ackMsgId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /**
     * Builds the outgoing bytes
     *
     * @return
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] ackMsgId = SocketUtil.int2ByteArrays(this.ackMsgId);
        byte[] unused = this.unused.getBytes(java.nio.charset.StandardCharsets.UTF_8); // same charset as parseContentData

        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length);              // length + version + type
        baos.write(ackMsgId, 0, ACKMSGID_LEN);         // acknowledged message id
        baos.write(unused, 0, unused.length);          // unused
        return baos.toByteArray();
    }

    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);

        // parse ackMsgId
        ackMsgId = SocketUtil.byteArrayToInt(data, pos, ACKMSGID_LEN);
        pos += ACKMSGID_LEN;

        // parse unused
        try {
            unused = new String(data, pos, data.length - pos, "utf-8");
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }

        return pos;
    }

}
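`ackMsgId` lets the sender match an acknowledgement to the message it answers. The client code in this article only forwards received acks to its callback, so the following is just one possible way to use them, under that assumption: remember each sent `msgId` with its send time, drop it when a `DataAckProtocol` with the same id arrives, and treat whatever stays too long as a candidate for resending. `PendingAcks` and its methods are hypothetical names.

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch: remembers data messages that were sent but not yet acknowledged.
class PendingAcks {
    private final Map<Integer, Long> sentAt = new ConcurrentHashMap<>(); // msgId -> send time in ms

    void onSent(int msgId, long nowMillis) {
        sentAt.put(msgId, nowMillis);
    }

    // Call with DataAckProtocol.getAckMsgId(); returns true if it matched an outstanding message.
    boolean onAck(int ackMsgId) {
        return sentAt.remove(ackMsgId) != null;
    }

    // Ids sent at least timeoutMillis ago and still unacknowledged.
    Set<Integer> overdue(long nowMillis, long timeoutMillis) {
        Set<Integer> result = new HashSet<>();
        for (Map.Entry<Integer, Long> e : sentAt.entrySet()) {
            if (nowMillis - e.getValue() >= timeoutMillis) {
                result.add(e.getKey());
            }
        }
        return result;
    }
}
```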

4. PingProtocol

import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;

import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;

/**
 * Created by meishan on 16/12/1.
 */
public class PingProtocol extends BasicProtocol {

    public static final int PROTOCOL_TYPE = 2;

    private static final int PINGID_LEN = 4;

    private int pingId;

    private String unused;

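    @Override
    public int getLength() {
        // count UTF-8 bytes, the charset parseContentData decodes with
        return super.getLength() + PINGID_LEN + unused.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
    }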

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getPingId() {
        return pingId;
    }

    public void setPingId(int pingId) {
        this.pingId = pingId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /**
     * Builds the outgoing bytes
     *
     * @return
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] pingId = SocketUtil.int2ByteArrays(this.pingId);
        byte[] unused = this.unused.getBytes(java.nio.charset.StandardCharsets.UTF_8); // same charset as parseContentData

        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length);          // length + version + type
        baos.write(pingId, 0, PINGID_LEN);         // heartbeat id
        baos.write(unused, 0, unused.length);      // unused
        return baos.toByteArray();
    }

    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);

        // parse pingId
        pingId = SocketUtil.byteArrayToInt(data, pos, PINGID_LEN);
        pos += PINGID_LEN;

        try {
            unused = new String(data, pos, data.length - pos, "utf-8");
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }

        return pos;
    }

}
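The definition reserves odd `pingId`s for the client and even ones for the server, so either side can tell who originated a heartbeat from the id alone (`HeartBeatTask` in the client below starts at 1 and adds 2). A small sketch of that rule; `PingIds` is a hypothetical helper, and `AtomicInteger` is used only so that ids stay unique if several threads send pings:

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch of the pingId numbering rule: client ids are odd, server ids are even.
class PingIds {
    private final AtomicInteger next;

    private PingIds(int first) {
        this.next = new AtomicInteger(first);
    }

    static PingIds forClient() { return new PingIds(1); } // 1, 3, 5, ...
    static PingIds forServer() { return new PingIds(0); } // 0, 2, 4, ...

    int nextId() {
        return next.getAndAdd(2);
    }

    // (id & 1) also works for negative ids after int overflow
    static boolean isFromClient(int pingId) {
        return (pingId & 1) == 1;
    }
}
```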

5. PingAckProtocol

import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;

import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;

/**
 * Created by meishan on 16/12/1.
 */
public class PingAckProtocol extends BasicProtocol {

    public static final int PROTOCOL_TYPE = 3;

    private static final int ACKPINGID_LEN = 4;

    private int ackPingId;

    private String unused;

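    @Override
    public int getLength() {
        // count UTF-8 bytes, the charset parseContentData decodes with
        return super.getLength() + ACKPINGID_LEN + unused.getBytes(java.nio.charset.StandardCharsets.UTF_8).length;
    }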

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getAckPingId() {
        return ackPingId;
    }

    public void setAckPingId(int ackPingId) {
        this.ackPingId = ackPingId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /**
     * Builds the outgoing bytes
     *
     * @return
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] ackPingId = SocketUtil.int2ByteArrays(this.ackPingId);
        byte[] unused = this.unused.getBytes(java.nio.charset.StandardCharsets.UTF_8); // same charset as parseContentData

        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length);                // length + version + type
        baos.write(ackPingId, 0, ACKPINGID_LEN);         // acknowledged heartbeat id
        baos.write(unused, 0, unused.length);            // unused
        return baos.toByteArray();
    }

    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);

        // parse ackPingId
        ackPingId = SocketUtil.byteArrayToInt(data, pos, ACKPINGID_LEN);
        pos += ACKPINGID_LEN;

        // parse unused
        try {
            unused = new String(data, pos, data.length - pos, "utf-8");
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }

        return pos;
    }

}
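The client below detects a dead link with `Socket.sendUrgentData` and simply forwards received `PingAckProtocol`s to its callback. An alternative, sketched here as an assumption rather than the article's approach, is to record when the last ping ack arrived and treat the connection as dead once too many heartbeat intervals pass without one. `HeartbeatMonitor` is hypothetical; the 5-second interval matches `REPEATTIME`, and a limit of 3 missed intervals is only an example value.

```java
// Hypothetical sketch: declares the link dead after too many heartbeat intervals without a ping ack.
class HeartbeatMonitor {
    private final long intervalMillis;
    private final int maxMissed;
    private volatile long lastAckMillis;

    HeartbeatMonitor(long intervalMillis, int maxMissed, long nowMillis) {
        this.intervalMillis = intervalMillis;
        this.maxMissed = maxMissed;
        this.lastAckMillis = nowMillis; // treat connect time as the first "ack"
    }

    // Call when a PingAckProtocol arrives.
    void onPingAck(long nowMillis) {
        lastAckMillis = nowMillis;
    }

    boolean isAlive(long nowMillis) {
        return nowMillis - lastAckMillis <= intervalMillis * maxMissed;
    }
}
```

For example, `new HeartbeatMonitor(5000, 3, System.currentTimeMillis())` would report the link dead once more than 15 seconds pass without a ping ack.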

3. Task Scheduling

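With all four messages implemented, we now use them for communication between the app and the server. Sending, receiving and heartbeats each run on their own thread, as follows: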
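The send thread below sleeps on `dataQueue` with `wait()` and is woken by `notifyAll()` when a request is added. For that pattern to be reliable, the emptiness check and the `wait()` must happen while holding the same lock; otherwise a message added between `poll()` returning null and `wait()` starting is only sent when the next one arrives. A minimal sketch of the guarded-wait form, with a hypothetical `SendQueue` class that uses its own monitor:

```java
import java.util.ArrayDeque;
import java.util.Queue;

// Hypothetical sketch: a send queue whose consumer waits only while the queue is empty.
class SendQueue<T> {
    private final Queue<T> queue = new ArrayDeque<>();
    private boolean closed;

    synchronized void add(T item) {
        queue.add(item);
        notifyAll(); // wake the sender thread
    }

    synchronized void close() {
        closed = true;
        notifyAll();
    }

    // Blocks until an item is available; returns null once closed and drained.
    synchronized T take() throws InterruptedException {
        while (queue.isEmpty() && !closed) { // re-check after every wake-up
            wait();
        }
        return queue.poll();
    }
}
```

`java.util.concurrent.LinkedBlockingQueue` offers the same blocking `take()` without hand-written `wait`/`notifyAll`.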

1. Client

import android.os.Handler;
import android.os.Looper;
import android.os.Message;
import android.util.Log;

import com.shandiangou.sdgprotocol.lib.protocol.BasicProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.PingProtocol;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ConnectException;
import java.net.Socket;
import java.util.concurrent.ConcurrentLinkedQueue;

import javax.net.SocketFactory;

/**
 * Writing runs in an endless loop: wait() when there is no data, notify() when a new message arrives
 * <p>
 * Created by meishan on 16/12/1.
 */
public class ClientRequestTask implements Runnable {

    private static final int SUCCESS = 100;
    private static final int FAILED = -1;

    private boolean isLongConnection = true;
    private Handler mHandler;
    private SendTask mSendTask;
    private ReciveTask mReciveTask;
    private HeartBeatTask mHeartBeatTask;
    private Socket mSocket;

    private volatile boolean isSocketAvailable; // volatile: read and written by several threads
    private volatile boolean closeSendTask;

    protected volatile ConcurrentLinkedQueue<BasicProtocol> dataQueue = new ConcurrentLinkedQueue<>();

    public ClientRequestTask(RequestCallBack requestCallBacks) {
        mHandler = new MyHandler(requestCallBacks);
    }

    @Override
    public void run() {
        try {
            try {
                mSocket = SocketFactory.getDefault().createSocket(Config.ADDRESS, Config.PORT);
//                mSocket.setSoTimeout(10);
            } catch (ConnectException e) {
                failedMessage(-1, "Failed to connect to the server, please check the network");
                return;
            }

            isSocketAvailable = true;

            // start the receive thread
            mReciveTask = new ReciveTask();
            mReciveTask.inputStream = mSocket.getInputStream();
            mReciveTask.start();

            // start the send thread
            mSendTask = new SendTask();
            mSendTask.outputStream = mSocket.getOutputStream();
            mSendTask.start();

            // start the heartbeat thread
            if (isLongConnection) {
                mHeartBeatTask = new HeartBeatTask();
                mHeartBeatTask.outputStream = mSocket.getOutputStream();
                mHeartBeatTask.start();
            }
        } catch (IOException e) {
            failedMessage(-1, "Network error, please try again later");
            e.printStackTrace();
        }
    }

    public void addRequest(DataProtocol data) {
        dataQueue.add(data);
        toNotifyAll(dataQueue); // wake the send thread: there is new data to send
    }

    public synchronized void stop() {

        // stop the receive thread
        closeReciveTask();

        // stop the send thread
        closeSendTask = true;
        toNotifyAll(dataQueue);

        // stop the heartbeat thread
        closeHeartBeatTask();

        // close the socket
        closeSocket();

        // clear pending data
        clearData();

        failedMessage(-1, "Disconnected");
    }

    /**
     * Stops the receive thread
     */
    private void closeReciveTask() {
        if (mReciveTask != null) {
            mReciveTask.interrupt();
            mReciveTask.isCancle = true;
            if (mReciveTask.inputStream != null) {
                try {
                    if (isSocketAvailable && !mSocket.isClosed() && mSocket.isConnected()) {
                        mSocket.shutdownInput(); // shut down input first to avoid a java.net.SocketException
                    }
                } catch (IOException e) {
                    e.printStackTrace();
                }
                SocketUtil.closeInputStream(mReciveTask.inputStream);
                mReciveTask.inputStream = null;
            }
            mReciveTask = null;
        }
    }

    /**
     * Stops the send thread
     */
    private void closeSendTask() {
        if (mSendTask != null) {
            mSendTask.isCancle = true;
            mSendTask.interrupt();
            if (mSendTask.outputStream != null) {
                synchronized (mSendTask.outputStream) { // don't stop in the middle of a write; stop once it finishes
                    SocketUtil.closeOutputStream(mSendTask.outputStream);
                    mSendTask.outputStream = null;
                }
            }
            mSendTask = null;
        }
    }

    /**
     * Stops the heartbeat thread
     */
    private void closeHeartBeatTask() {
        if (mHeartBeatTask != null) {
            mHeartBeatTask.isCancle = true;
            if (mHeartBeatTask.outputStream != null) {
                SocketUtil.closeOutputStream(mHeartBeatTask.outputStream);
                mHeartBeatTask.outputStream = null;
            }
            mHeartBeatTask = null;
        }
    }

    /**
     * Closes the socket
     */
    private void closeSocket() {
        if (mSocket != null) {
            try {
                mSocket.close();
                isSocketAvailable = false;
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Clears pending data
     */
    private void clearData() {
        dataQueue.clear();
        isLongConnection = false;
    }

    private void toWait(Object o) {
        synchronized (o) {
            try {
                o.wait();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * After notify() is called the lock is not released immediately; it is released only when the enclosing synchronized(){} block finishes
     *
     * @param o
     */
    protected void toNotifyAll(Object o) {
        synchronized (o) {
            o.notifyAll();
        }
    }

    private void failedMessage(int code, String msg) {
        Message message = mHandler.obtainMessage(FAILED);
        message.what = FAILED;
        message.arg1 = code;
        message.obj = msg;
        mHandler.sendMessage(message);
    }

    private void successMessage(BasicProtocol protocol) {
        Message message = mHandler.obtainMessage(SUCCESS);
        message.what = SUCCESS;
        message.obj = protocol;
        mHandler.sendMessage(message);
    }

    private boolean isConnected() {
        if (mSocket.isClosed() || !mSocket.isConnected()) {
            ClientRequestTask.this.stop();
            return false;
        }
        return true;
    }

    /**
     * Handles server responses; runs on the main thread
     */
    public class MyHandler extends Handler {

        private RequestCallBack mRequestCallBack;

        public MyHandler(RequestCallBack callBack) {
            super(Looper.getMainLooper());
            this.mRequestCallBack = callBack;
        }

        @Override
        public void handleMessage(Message msg) {
            super.handleMessage(msg);
            switch (msg.what) {
                case SUCCESS:
                    mRequestCallBack.onSuccess((BasicProtocol) msg.obj);
                    break;
                case FAILED:
                    mRequestCallBack.onFailed(msg.arg1, (String) msg.obj);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Receive thread
     */
    public class ReciveTask extends Thread {

        private volatile boolean isCancle = false; // volatile: set by stop(), which may run on another thread
        private InputStream inputStream;

        @Override
        public void run() {
            while (!isCancle) {
                if (!isConnected()) {
                    break;
                }

                if (inputStream != null) {
                    BasicProtocol reciverData = SocketUtil.readFromStream(inputStream);
                    if (reciverData != null) {
                        if (reciverData.getProtocolType() == 1 || reciverData.getProtocolType() == 3) { // 1 = DataAck, 3 = PingAck
                            successMessage(reciverData);
                        }
                    } else {
                        break;
                    }
                }
            }

            SocketUtil.closeInputStream(inputStream); // loop ended: close the input stream
        }
    }

    /**
     * Send thread
     * Waits while there is no data to send
     */
    public class SendTask extends Thread {

        private volatile boolean isCancle = false; // volatile: set by stop(), which may run on another thread
        private OutputStream outputStream;

        @Override
        public void run() {
            while (!isCancle) {
                if (!isConnected()) {
                    break;
                }

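                BasicProtocol dataContent = dataQueue.poll();
                if (dataContent == null) {
                    synchronized (dataQueue) {
                        // re-check under the lock, otherwise an addRequest() between poll() and wait() is missed
                        if (dataQueue.isEmpty() && !closeSendTask) {
                            toWait(dataQueue); // nothing to send: wait
                        }
                    }
                    if (closeSendTask) {
                        closeSendTask(); // the lock is not released right after notify(), so stop the send thread here
                    }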
                } else if (outputStream != null) {
                    synchronized (outputStream) {
                        SocketUtil.write2Stream(dataContent, outputStream);
                    }
                }
            }

            SocketUtil.closeOutputStream(outputStream); // loop ended: close the output stream
        }
    }

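    /**
     * Heartbeat thread: sends a ping every 5 seconds
     * Created by meishan on 16/12/1.
     */
    public class HeartBeatTask extends Thread {

        private static final int REPEATTIME = 5000;
        private volatile boolean isCancle = false; // volatile: set by stop(), which may run on another thread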
        private OutputStream outputStream;
        private int pingId;

        @Override
        public void run() {
            pingId = 1;
            while (!isCancle) {
                if (!isConnected()) {
                    break;
                }

                try {
                    mSocket.sendUrgentData(0xFF);
                } catch (IOException e) {
                    isSocketAvailable = false;
                    ClientRequestTask.this.stop();
                    break;
                }

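                OutputStream out = outputStream; // local copy: stop() may set the field to null concurrently
                if (out != null) {
                    PingProtocol pingProtocol = new PingProtocol();
                    pingProtocol.setPingId(pingId);
                    pingProtocol.setUnused("ping...");
                    // SendTask locks the same stream object (both come from mSocket.getOutputStream()),
                    // so lock it here too to keep a ping from interleaving with a data frame
                    synchronized (out) {
                        SocketUtil.write2Stream(pingProtocol, out);
                    }
                    pingId = pingId + 2; // client ping ids are odd: 1, 3, 5, ...
                }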

                try {
                    Thread.sleep(REPEATTIME);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }

            SocketUtil.closeOutputStream(outputStream);
        }
    }
}

The RequestCallBack interface used above is as follows:

/**
 * Created by meishan on 16/12/1.
 */
public interface RequestCallBack {

    void onSuccess(BasicProtocol msg);

    void onFailed(int errorCode, String msg);
}

2. Server

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.net.Socket;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

/**
 * Created by meishan on 16/12/1.
 */
public class ServerResponseTask implements Runnable {

    private ReciveTask reciveTask;
    private SendTask sendTask;
    private Socket socket;
    private ResponseCallback tBack;

    private volatile ConcurrentLinkedQueue<BasicProtocol> dataQueue = new ConcurrentLinkedQueue<>();
    private static ConcurrentHashMap<String, Socket> onLineClient = new ConcurrentHashMap<>();

    private String userIP;

    public String getUserIP() {
        return userIP;
    }

    public ServerResponseTask(Socket socket, ResponseCallback tBack) {
        this.socket = socket;
        this.tBack = tBack;
        this.userIP = socket.getInetAddress().getHostAddress();
        System.out.println("Client IP address: " + userIP);
    }

    @Override
    public void run() {
        try {
            // start the receive thread
            reciveTask = new ReciveTask();
            reciveTask.inputStream = new DataInputStream(socket.getInputStream());
            reciveTask.start();

            // start the send thread
            sendTask = new SendTask();
            sendTask.outputStream = new DataOutputStream(socket.getOutputStream());
            sendTask.start();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void stop() {
        if (reciveTask != null) {
            reciveTask.isCancle = true;
            reciveTask.interrupt();
            if (reciveTask.inputStream != null) {
                SocketUtil.closeInputStream(reciveTask.inputStream);
                reciveTask.inputStream = null;
            }
            reciveTask = null;
        }

        if (sendTask != null) {
            sendTask.isCancle = true;
            sendTask.interrupt();
            if (sendTask.outputStream != null) {
                synchronized (sendTask.outputStream) { // don't stop in the middle of a write; stop once it finishes
                    sendTask.outputStream = null;
                }
            }
            sendTask = null;
        }
    }

    public void addMessage(BasicProtocol data) {
        if (!isConnected()) {
            return;
        }

        dataQueue.offer(data);
        toNotifyAll(dataQueue); // wake the send thread: there is new data to send
    }

    public Socket getConnectdClient(String clientID) {
        return onLineClient.get(clientID);
    }

    /**
     * Print all connected clients
     */
    public static void printAllClient() {
        if (onLineClient == null) {
            return;
        }
        Iterator<String> inter = onLineClient.keySet().iterator();
        while (inter.hasNext()) {
            System.out.println("client:" + inter.next());
        }
    }

    public void toWaitAll(Object o) {
        synchronized (o) {
            try {
                o.wait();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    public void toNotifyAll(Object obj) {
        synchronized (obj) {
            obj.notifyAll();
        }
    }

    private boolean isConnected() {
        if (socket.isClosed() || !socket.isConnected()) {
            onLineClient.remove(userIP);
            ServerResponseTask.this.stop();
            System.out.println("socket closed...");
            return false;
        }
        return true;
    }

    public class ReciveTask extends Thread {

        private DataInputStream inputStream;
        private volatile boolean isCancle; // written by stop() from another thread

        @Override
        public void run() {
            while (!isCancle) {
                if (!isConnected()) {
                    isCancle = true;
                    break;
                }

                BasicProtocol clientData = SocketUtil.readFromStream(inputStream);

                if (clientData != null) {
                    if (clientData.getProtocolType() == 0) {
                        DataProtocol data = (DataProtocol) clientData;
                        System.out.println("dtype: " + data.getDtype() + ", pattion: " + data.getPattion()
                                + ", msgId: " + data.getMsgId() + ", data: " + data.getData());

                        DataAckProtocol dataAck = new DataAckProtocol();
                        dataAck.setUnused("Received message: " + data.getData());
                        dataQueue.offer(dataAck);
                        toNotifyAll(dataQueue); // wake the send thread

                        tBack.targetIsOnline(userIP);
                    } else if (clientData.getProtocolType() == 2) {
                        System.out.println("pingId: " + ((PingProtocol) clientData).getPingId());

                        PingAckProtocol pingAck = new PingAckProtocol();
                        pingAck.setUnused("Heartbeat received");
                        dataQueue.offer(pingAck);
                        toNotifyAll(dataQueue); // wake the send thread

                        tBack.targetIsOnline(userIP);
                    }
                } else {
                    System.out.println("client is offline...");
                    break;
                }
            }

            SocketUtil.closeInputStream(inputStream);
        }
    }

    public class SendTask extends Thread {

        private DataOutputStream outputStream;
        private volatile boolean isCancle; // written by stop() from another thread

        @Override
        public void run() {
            while (!isCancle) {
                if (!isConnected()) {
                    isCancle = true;
                    break;
                }

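                BasicProtocol protocol = dataQueue.poll();
                if (protocol == null) {
                    synchronized (dataQueue) {
                        // Re-check under the lock: a notifyAll() issued between poll() and wait()
                        // would otherwise be lost, leaving the new message stuck in the queue.
                        if (dataQueue.isEmpty() && !isCancle) {
                            toWaitAll(dataQueue);
                        }
                    }
                } else if (outputStream != null) {
                    synchronized (outputStream) {
                        SocketUtil.write2Stream(protocol, outputStream);
                    }
                }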
            }

            SocketUtil.closeOutputStream(outputStream);
        }
    }
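SendTask, toWaitAll and toNotifyAll together form a hand-rolled producer/consumer queue built on wait()/notifyAll(). The standard library already provides this: LinkedBlockingQueue.poll(timeout, unit) blocks until an element arrives or the timeout expires, so no separate lock or notification bookkeeping is needed. The following is a minimal sketch under that assumption; SendQueue and its methods are hypothetical names, and the type parameter stands in for BasicProtocol.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical replacement for dataQueue + toWaitAll/toNotifyAll.
// T stands in for BasicProtocol so the sketch compiles on its own.
public class SendQueue<T> {
    private final BlockingQueue<T> queue = new LinkedBlockingQueue<>();

    // Producer side (addMessage / ReciveTask): enqueueing also wakes a waiting consumer.
    public void add(T message) {
        queue.offer(message); // unbounded queue, so offer always succeeds
    }

    // Consumer side (SendTask): waits up to timeoutMillis for the next message.
    // Returns null on timeout or interruption so the caller can re-check its cancel flag.
    public T next(long timeoutMillis) {
        try {
            return queue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
            return null;
        }
    }
}
```

With this in place, the body of SendTask's loop could shrink to taking next(1000) and writing it when it is not null; the timeout bounds how long the thread can go without noticing that isCancle has been set.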

The ResponseCallback interface used above is defined as follows:

/**
 * Created by meishan on 16/12/1.
 */
public interface ResponseCallback {

    void targetIsOffline(DataProtocol reciveMsg);

    void targetIsOnline(String clientIp);
}
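ReciveTask answers a data frame (type 0) with a data ack and a ping (type 2) with a ping ack. According to the protocol definition in section 1, the data ack carries an ackMsgId and the ping ack a pingId, presumably echoing the id of the frame being acknowledged. The loop above only sets the unused field, so whether those ids get filled in depends on DataAckProtocol/PingAckProtocol code not shown here. The mapping can be sketched as follows, assuming Java 16+ for the record; Frame and AckDispatcher are hypothetical, not classes from the library.

```java
import java.util.Optional;

// Hypothetical sketch of the "which ack answers which frame" rule.
public class AckDispatcher {
    // Protocol type codes from the protocol definition.
    public static final int TYPE_DATA = 0;
    public static final int TYPE_DATA_ACK = 1;
    public static final int TYPE_PING = 2;
    public static final int TYPE_PING_ACK = 3;

    // Simplified frame: type code plus its id field (msgId for data, pingId for ping).
    public record Frame(int type, int id) {}

    // Returns the ack to send back, or empty for frames that need no reply.
    public static Optional<Frame> ackFor(Frame incoming) {
        switch (incoming.type()) {
            case TYPE_DATA:
                return Optional.of(new Frame(TYPE_DATA_ACK, incoming.id())); // ackMsgId = msgId
            case TYPE_PING:
                return Optional.of(new Frame(TYPE_PING_ACK, incoming.id())); // echo the pingId
            default:
                return Optional.empty(); // acks themselves are never acknowledged
        }
    }
}
```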

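The code above handles exceptions in several situations. For example, if the server stops running after a connection has been established while the client's input stream is still blocked on a read, the client should not throw an exception. See the SocketUtil class for how these cases are handled.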
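SocketUtil itself is not part of this section. One common way to get the behaviour described above, where a disconnect surfaces as a null result instead of an exception, is to read the 4-byte length with DataInputStream.readInt(), read the rest with readFully(), and map any IOException (including EOFException) to null. The sketch below is hypothetical and is not SocketUtil's real implementation. It pairs that reader with an encoder for a ping frame in the field order of the protocol definition. It assumes the length field counts the whole frame including itself, that the "first 3 bits" of the version byte are the high bits, that the unused trailer is UTF-8, and a 1 MiB frame limit. DataOutputStream writes ints big-endian.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

// Hypothetical framing helpers; not SocketUtil's actual code.
public class FrameCodec {
    static final int LENGTH_LEN = 4;          // mirrors BasicProtocol.LENGTH_LEN
    static final int MAX_FRAME_LEN = 1 << 20; // assumed cap (1 MiB) to reject bogus lengths
    static final int TYPE_PING = 2;

    // length(4) | version(1) | type(1) | pingId(4) | unused(rest, UTF-8)
    public static byte[] encodePing(int version, int pingId, String unused) {
        byte[] trailer = unused.getBytes(StandardCharsets.UTF_8);
        int total = LENGTH_LEN + 1 + 1 + 4 + trailer.length;
        ByteArrayOutputStream bytes = new ByteArrayOutputStream(total);
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(total);           // length of the whole frame, including this field
            out.writeByte(version & 0x1F); // high 3 bits reserved (0), low 5 bits = version
            out.writeByte(TYPE_PING);      // type: 2 = ping
            out.writeInt(pingId);
            out.write(trailer);
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for an in-memory stream
        }
        return bytes.toByteArray();
    }

    // Reads one frame and returns everything after the length field, or null if the peer
    // disconnected, the stream was closed, or the length field is out of range.
    public static byte[] readFrame(DataInputStream in) {
        try {
            int total = in.readInt();      // EOFException on a clean disconnect
            if (total < LENGTH_LEN || total > MAX_FRAME_LEN) {
                return null;
            }
            byte[] body = new byte[total - LENGTH_LEN];
            in.readFully(body);            // blocks until the whole body has arrived
            return body;
        } catch (IOException e) {          // covers EOFException and SocketException
            return null;
        }
    }
}
```

A receive loop built on readFrame can then treat null exactly as ReciveTask treats a null result from readFromStream: log that the peer went offline and leave the loop.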

4、Call Wrappers

1. Client

import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;

/**
 * Created by meishan on 16/12/1.
 */
public class ConnectionClient {

    private boolean isClosed;

    private ClientRequestTask mClientRequestTask;

    public ConnectionClient(RequestCallBack requestCallBack) {
        mClientRequestTask = new ClientRequestTask(requestCallBack);
        new Thread(mClientRequestTask).start();
    }

    public void addNewRequest(DataProtocol data) {
        if (mClientRequestTask != null && !isClosed)
            mClientRequestTask.addRequest(data);
    }

    public void closeConnect() {
        isClosed = true;
        mClientRequestTask.stop();
    }
}
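The protocol definition keeps client- and server-originated heartbeats apart by parity: the client sends odd pingIds (1, 3, 5, ...) and the server even ones (0, 2, 4, ...). Heartbeat sending lives in ClientRequestTask, which is not part of this section. A hypothetical, thread-safe helper that enforces the parity rule might look like this; PingIdGenerator is an illustrative name, not part of the library.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical helper: client pings get odd ids (1, 3, 5, ...), server pings even ids (0, 2, 4, ...).
public class PingIdGenerator {
    private final AtomicInteger next;

    public PingIdGenerator(boolean isClient) {
        this.next = new AtomicInteger(isClient ? 1 : 0);
    }

    // Returns the current id and advances by 2. Parity is preserved even if the int wraps around.
    public int nextId() {
        return next.getAndAdd(2);
    }

    // Lets the receiver tell which side originated a ping from its id alone
    // (works for negative ids too, since (-1 & 1) == 1).
    public static boolean isFromClient(int pingId) {
        return (pingId & 1) == 1;
    }
}
```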

2. Server

import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Created by meishan on 16/12/1.
 */
public class ConnectionServer {

    private static boolean isStart = true;
    private static ServerResponseTask serverResponseTask;

    public ConnectionServer() {

    }

    public static void main(String[] args) {

        ServerSocket serverSocket = null;
        ExecutorService executorService = Executors.newCachedThreadPool();
        try {
            serverSocket = new ServerSocket(Config.PORT);
            while (isStart) {
                Socket socket = serverSocket.accept();
                serverResponseTask = new ServerResponseTask(socket,
                        new ResponseCallback() {

                            @Override
                            public void targetIsOffline(DataProtocol reciveMsg) {// the peer is offline
                                if (reciveMsg != null) {
                                    System.out.println(reciveMsg.getData());
                                }
                            }

                            @Override
                            public void targetIsOnline(String clientIp) {
                                System.out.println(clientIp + " is onLine");
                                System.out.println("-----------------------------------------");
                            }
                        });

                if (socket.isConnected()) {
                    executorService.execute(serverResponseTask);
                }
            }

            serverSocket.close();

        } catch (IOException e) {
            e.printStackTrace();
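        } finally {
            isStart = false;
            if (serverSocket != null) {
                try {
                    serverSocket.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            // only the most recently accepted connection is tracked in serverResponseTask
            if (serverResponseTask != null) {
                serverResponseTask.stop();
            }
            executorService.shutdown();
        }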
    }
}
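ConnectionServer keeps only the most recently accepted ServerResponseTask in a static field, so its shutdown path can stop at most one connection. The sketch below tracks every handler so shutdown reaches all of them. HandlerRegistry and its Handler interface are hypothetical stand-ins. To use them as-is, ServerResponseTask would need to implement Handler, which only requires the stop() method it already has.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical registry that remembers every live connection handler.
public class HandlerRegistry {
    // Stand-in for ServerResponseTask: runnable on the pool and stoppable.
    public interface Handler extends Runnable {
        void stop();
    }

    private final Set<Handler> active = ConcurrentHashMap.newKeySet();
    private final ExecutorService pool = Executors.newCachedThreadPool();

    // Use instead of executorService.execute(serverResponseTask) for each accepted socket.
    public void start(Handler handler) {
        active.add(handler);
        pool.execute(handler);
    }

    // Call when a connection ends (e.g. where the client is removed from onLineClient).
    public void remove(Handler handler) {
        active.remove(handler);
    }

    public int activeCount() {
        return active.size();
    }

    // Stops every tracked handler, then waits briefly for the pool threads to finish.
    // Returns true if the pool terminated within the timeout.
    public boolean shutdown() {
        for (Handler handler : active) {
            handler.stop();
        }
        active.clear();
        pool.shutdown();
        try {
            return pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            return false;
        }
    }
}
```

A registry like this also decouples connection bookkeeping from the accept loop, so a static field no longer has to double as "the" connection.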