Spring Boot 實現文件分片上傳

思路

  1. 根據文件參數(MD5,大小,名稱)建立一個上傳任務,分配一個ID
  2. 循環上傳每個分片,保存成臨時文件
  3. 合併分片文件,校驗MD5值,最後清理臨時文件

服務端代碼

  1. Controller

    /**
     * REST endpoints of the chunked upload flow ({@code /createTask}, {@code /uploadFile},
     * {@code /mergeFile}); each one simply delegates to {@link FileUploadService}.
     */
    @RestController
    public class FileController {

      /** Body returned by the upload and merge endpoints once the service call has returned. */
      private static final String OK = "ok";

      @Autowired
      private FileUploadService fileUploadService;

      /**
       * Registers a new upload task.
       *
       * @param dto file name, whole-file MD5 and file size sent by the client
       * @return the task created by the service (id, slice count, slice size, ...)
       */
      @PostMapping("/createTask")
      public TaskModel createTask(CreateTaskDTO dto) {
        String name = dto.getFileName();
        String digest = dto.getFileMd5();
        Long length = dto.getFileSize();
        return fileUploadService.createUploadTask(name, digest, length);
      }

      /**
       * Accepts one slice of an upload task.
       *
       * @param dto task id, slice MD5, slice number and slice content
       * @return {@code "ok"} when the service did not throw
       */
      @PostMapping("/uploadFile")
      public String uploadFile(UploadFileDTO dto) {
        String taskId = dto.getTaskId();
        String sliceMd5 = dto.getMd5();
        Integer sliceNo = dto.getNo();
        fileUploadService.uploadFile(taskId, sliceMd5, sliceNo, dto.getFile());
        return OK;
      }

      /**
       * Asks the service to merge the uploaded slices of a task.
       *
       * @param dto only {@code taskId} is read
       * @return {@code "ok"} when the service did not throw
       */
      @PostMapping("/mergeFile")
      public String mergeFile(UploadFileDTO dto) {
        String taskId = dto.getTaskId();
        fileUploadService.mergeFile(taskId);
        return OK;
      }
    }
  2. Service

    import org.springframework.stereotype.Service;
    import org.springframework.util.DigestUtils;
    import org.springframework.util.FileCopyUtils;
    import org.springframework.web.multipart.MultipartFile;
    
    import java.io.File;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.Comparator;
    import java.util.List;
    import java.util.Map;
    import java.util.Objects;
    import java.util.UUID;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    
    /**
     * Server side of the chunked upload. Tasks are kept in an in-memory map. Each slice is stored as
     * {@code fileTmp/<taskId>/<no>.tmp}, and merging writes {@code fileTmp/<taskId><fileName>}.
     * NOTE(review): tasks are lost on restart and abandoned (never merged) tasks are never evicted.
     */
    @Service
    public class FileUploadService {

      /** Upload tasks by id; a task is removed once its slices have been merged successfully. */
      private static final Map<String, TaskModel> fileMap = new ConcurrentHashMap<>();

      /** Directory (relative to the process working directory) for slices and merged files. */
      private final String tmpPath = "fileTmp/";

      /**
       * Registers an upload task and works out how many slices the client has to send.
       *
       * @param fileName name of the original file
       * @param fileMd5  hex MD5 of the whole file, verified by {@link #mergeFile(String)}
       * @param fileSize file size in bytes
       * @return the registered task (random UUID id, slice count, slice size, ...)
       * @throws OpenApiException if the MD5 or size is missing or the size is negative
       */
      public TaskModel createUploadTask(String fileName, String fileMd5, Long fileSize) {
        if (fileMd5 == null || fileSize == null || fileSize < 0) {
          throw new OpenApiException("參數錯誤");
        }
        createTmpPath(tmpPath);
        TaskModel taskModel = new TaskModel();
        long sliceSize = taskModel.getSize();
        // Round up: a trailing partial slice still counts as one slice.
        long sliceNum = fileSize / sliceSize;
        if (fileSize % sliceSize != 0) {
          sliceNum += 1;
        }
        taskModel.setFileMd5(fileMd5);
        taskModel.setId(UUID.randomUUID().toString());
        taskModel.setSliceNum(sliceNum);
        taskModel.setFileName(fileName);
        taskModel.setFileSize(fileSize);
        fileMap.put(taskModel.getId(), taskModel);
        return taskModel;
      }

      /**
       * Stores one slice of a task as {@code fileTmp/<taskId>/<no>.tmp} after checking its MD5.
       *
       * @param taskId id returned by {@link #createUploadTask}
       * @param md5    hex MD5 of this slice (compared case-insensitively)
       * @param no     zero-based slice number, {@code 0 <= no < sliceNum}
       * @param file   the slice content
       * @throws OpenApiException if the task is unknown, the slice number is out of range,
       *     the MD5 does not match, or the slice cannot be written
       */
      public void uploadFile(String taskId, String md5, Integer no, MultipartFile file) {
        // Only accept ids issued by createUploadTask: taskId is client input and becomes a path
        // segment, so an arbitrary value (e.g. "../..") could write outside the temp directory.
        TaskModel taskModel = taskId == null ? null : fileMap.get(taskId);
        if (taskModel == null) {
          throw new OpenApiException("任務不存在");
        }
        if (no == null || no < 0 || no >= taskModel.getSliceNum()) {
          throw new OpenApiException("分片序號錯誤");
        }
        String tmpPath0 = tmpPath + taskId + "/";
        createTmpPath(tmpPath0);
        String tmpFileName = tmpPath0 + no + ".tmp";
        try {
          byte[] bytes = file.getBytes();
          // Hash the bytes already read instead of calling getBytes() a second time.
          String md5Value = DigestUtils.md5DigestAsHex(bytes);
          if (!md5Value.equalsIgnoreCase(md5)) {
            throw new OpenApiException("MD5校驗失敗");
          }
          FileCopyUtils.copy(bytes, new File(tmpFileName));
        } catch (IOException e) {
          e.printStackTrace();
          throw new OpenApiException("文件上傳失敗");
        }
      }

      /**
       * Concatenates a task's slices in slice-number order into {@code fileTmp/<taskId><fileName>}
       * and verifies the whole-file MD5. On success it deletes the slice directory and forgets the
       * task. Unknown task ids are ignored.
       *
       * @param taskId id returned by {@link #createUploadTask}
       * @throws OpenApiException if the MD5 does not match or an I/O error occurs; in both cases
       *     the partially merged file is deleted and the slices are kept
       */
      public void mergeFile(String taskId) {
        TaskModel taskModel = taskId == null ? null : fileMap.get(taskId);
        if (taskModel == null) {
          return;
        }
        Path taskDir = Paths.get(tmpPath, taskId);
        Path target = Paths.get(tmpPath, taskId + safeFileName(taskModel.getFileName()));
        try {
          List<Path> slices = listSlicesInOrder(taskDir);
          // Stream slices straight into the target instead of building the whole file in memory.
          try (OutputStream out = Files.newOutputStream(target)) {
            for (Path slice : slices) {
              Files.copy(slice, out);
            }
          }
          String actualMd5;
          try (InputStream in = Files.newInputStream(target)) {
            actualMd5 = DigestUtils.md5DigestAsHex(in);
          }
          if (!actualMd5.equalsIgnoreCase(taskModel.getFileMd5())) {
            deleteQuietly(target);
            throw new OpenApiException("MD5校驗錯誤");
          }
        } catch (IOException e) {
          e.printStackTrace();
          deleteQuietly(target);
          throw new OpenApiException("文件合併失敗");
        }
        deleteDir(taskDir.toString());
        fileMap.remove(taskId);
      }

      /**
       * Lists the {@code <no>.tmp} slice files of a task directory ordered by slice number.
       * Ordering uses {@link Path#getFileName()}, so it does not depend on the path separator.
       */
      private List<Path> listSlicesInOrder(Path taskDir) throws IOException {
        // The stream returned by Files.list holds an open directory handle and must be closed.
        try (Stream<Path> files = Files.list(taskDir)) {
          return files
              .filter(p -> p.getFileName().toString().endsWith(".tmp"))
              .sorted(Comparator.comparingLong(this::sliceIndex))
              .collect(Collectors.toList());
        }
      }

      /** Parses the slice number from a {@code <no>.tmp} file name. */
      private long sliceIndex(Path slice) {
        String name = slice.getFileName().toString();
        return Long.parseLong(name.substring(0, name.indexOf('.')));
      }

      /**
       * Keeps only the last path component of a client-supplied file name, so the merged file
       * cannot be placed in another directory; {@code null} or a bare root becomes "".
       */
      private String safeFileName(String fileName) {
        if (fileName == null) {
          return "";
        }
        Path name = Paths.get(fileName).getFileName();
        return name == null ? "" : name.toString();
      }

      /** Best-effort removal of a directory and the files directly inside it; failures are printed. */
      private void deleteDir(String path) {
        Path dir = Paths.get(path);
        try (Stream<Path> children = Files.list(dir)) {
          children.forEach(this::deleteQuietly);
        } catch (IOException e) {
          e.printStackTrace();
          return;
        }
        deleteQuietly(dir);
      }

      /** Deletes {@code p} if it exists, printing (not propagating) any I/O failure. */
      private void deleteQuietly(Path p) {
        try {
          Files.deleteIfExists(p);
        } catch (IOException e) {
          e.printStackTrace();
        }
      }

      /** Returns a new array holding {@code data1} followed by {@code data2}. */
      public byte[] addBytes(byte[] data1, byte[] data2) {
        byte[] data3 = new byte[data1.length + data2.length];
        System.arraycopy(data1, 0, data3, 0, data1.length);
        System.arraycopy(data2, 0, data3, data1.length, data2.length);
        return data3;
      }

      /**
       * Creates {@code tmpPath} and any missing parents. Unlike a notExists/createDirectory pair,
       * this does not fail when concurrent slice uploads race to create the same directory.
       *
       * @throws OpenApiException if the directory cannot be created
       */
      private void createTmpPath(String tmpPath) {
        try {
          Files.createDirectories(Paths.get(tmpPath));
        } catch (IOException e) {
          e.printStackTrace();
          throw new OpenApiException("建立失敗");
        }
      }
    }
  3. DTO

    /** Form parameters of {@code POST /createTask}: describes the whole file to upload. */
    public class CreateTaskDTO {
      private String fileName; // original file name
      private String fileMd5; // hex MD5 of the whole file
      private Long fileSize; // file size in bytes

      public String getFileName() {
        return fileName;
      }

      public void setFileName(String fileName) {
        this.fileName = fileName;
      }

      public String getFileMd5() {
        return fileMd5;
      }

      public void setFileMd5(String fileMd5) {
        this.fileMd5 = fileMd5;
      }

      public Long getFileSize() {
        return fileSize;
      }

      public void setFileSize(Long fileSize) {
        this.fileSize = fileSize;
      }
    }
    
    /** An upload task as created by the server and returned to the client as JSON. */
    public class TaskModel {
      private String id; // upload task ID
      private Long sliceNum; // number of slices
      private Integer size = 1024; // size of each slice in bytes
      private String fileMd5; // MD5 of the whole file
      private String fileName; // file name
      private Long fileSize; // file size in bytes

      public String getId() {
        return id;
      }

      public void setId(String id) {
        this.id = id;
      }

      public Long getSliceNum() {
        return sliceNum;
      }

      public void setSliceNum(Long sliceNum) {
        this.sliceNum = sliceNum;
      }

      public Integer getSize() {
        return size;
      }

      public void setSize(Integer size) {
        this.size = size;
      }

      public String getFileMd5() {
        return fileMd5;
      }

      public void setFileMd5(String fileMd5) {
        this.fileMd5 = fileMd5;
      }

      public String getFileName() {
        return fileName;
      }

      public void setFileName(String fileName) {
        this.fileName = fileName;
      }

      public Long getFileSize() {
        return fileSize;
      }

      public void setFileSize(Long fileSize) {
        this.fileSize = fileSize;
      }
    }
    
    /** Form parameters of {@code POST /uploadFile} (and {@code /mergeFile}, which reads only taskId). */
    public class UploadFileDTO {
      private String taskId; // task ID
      private String md5; // MD5 of this slice
      private Integer no; // slice number (the client numbers slices from 0)
      private MultipartFile file; // slice content

      public String getTaskId() {
        return taskId;
      }

      public void setTaskId(String taskId) {
        this.taskId = taskId;
      }

      public String getMd5() {
        return md5;
      }

      public void setMd5(String md5) {
        this.md5 = md5;
      }

      public Integer getNo() {
        return no;
      }

      public void setNo(Integer no) {
        this.no = no;
      }

      public MultipartFile getFile() {
        return file;
      }

      public void setFile(MultipartFile file) {
        this.file = file;
      }
    }

客戶端代碼（以 JUnit 測試的形式編寫，需手動依序執行建立任務、上傳分片、合併文件三個測試）

import com.alibaba.fastjson.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.DigestUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

import java.io.*;
import java.nio.file.Paths;
import java.util.function.BiConsumer;

/**
 * Manual client for the chunked-upload endpoints of a server on 127.0.0.1:9090. The task JSON is
 * hard-coded in {@link #s}; presumably it is pasted from the output of {@link #testCreateTaskId()}
 * before running {@link #testUploadFile()} and then {@link #mergeFile()}.
 */
@RunWith(SpringRunner.class)
@SpringBootTest
public class FileUploadTest {

  @Autowired
  private RestTemplate restTemplate;
  String fileLocal = "01.png";

  /** Creates an upload task for {@link #fileLocal} (name, whole-file MD5, size) and prints the reply. */
  @Test
  public void testCreateTaskId() {
    String url = "http://127.0.0.1:9090/createTask";
    File file = new File(fileLocal);
    MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
    param.add("fileName", fileLocal);
    // Close the stream once hashed, and fail instead of posting a task without an MD5.
    try (InputStream in = new FileInputStream(file)) {
      param.add("fileMd5", DigestUtils.md5DigestAsHex(in));
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    param.add("fileSize", file.length());
    sendFile(url, param);
  }

  String s = "{\"id\":\"681aff74-ee43-4d2d-9488-7568854315c7\",\"sliceNum\":57,\"size\":1024,\"fileMd5\":\"55b1bfaa8360f333082956790a10ca8f\",\"fileName\":\"01.png\",\"fileSize\":58185}\n";

  /** Uploads {@link #fileLocal} slice by slice, using the task id and slice size from {@link #s}. */
  @Test
  public void testUploadFile() throws Exception {
    String url = "http://127.0.0.1:9090/uploadFile";
    JSONObject jsonObject = JSONObject.parseObject(s);
    String id = jsonObject.getString("id");
    Integer size = jsonObject.getInteger("size");
    sliceFile(new File(fileLocal), size, (no, bytes) -> {
      // Each slice is written to a temporary file so it can be sent as a FileSystemResource part.
      File file = new File(id + ".tmp." + no);
      try {
        java.nio.file.Files.write(file.toPath(), bytes);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
      try {
        MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
        param.add("taskId", id);
        param.add("md5", DigestUtils.md5DigestAsHex(bytes));
        param.add("no", no);
        param.add("file", new FileSystemResource(file));
        sendFile(url, param);
      } finally {
        // Best-effort cleanup of the temporary slice file.
        file.delete();
      }
    });
  }

  /** Asks the server to merge the slices of the task in {@link #s}. */
  @Test
  public void mergeFile() {
    JSONObject jsonObject = JSONObject.parseObject(s);
    String id = jsonObject.getString("id");
    MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
    param.add("taskId", id);
    sendFile("http://127.0.0.1:9090/mergeFile", param);
  }

  /** Posts {@code param} to {@code url} and prints the response body. */
  public void sendFile(String url, MultiValueMap<String, Object> param) {
    HttpEntity<MultiValueMap<String, Object>> httpEntity = new HttpEntity<>(param);
    ResponseEntity<String> responseEntity = restTemplate.postForEntity(url, httpEntity, String.class);
    System.out.println(responseEntity.getBody());
  }

  /**
   * Splits {@code file} into consecutive slices of {@code size} bytes (the last one may be shorter)
   * and passes each slice with its zero-based index to {@code consumer}. An empty file yields no
   * slices.
   *
   * @throws IllegalArgumentException if {@code size} is not positive
   * @throws UncheckedIOException     if the file cannot be opened or read
   */
  public static void sliceFile(File file, int size, BiConsumer<Long, byte[]> consumer) {
    if (size <= 0) {
      throw new IllegalArgumentException("size must be positive: " + size);
    }
    try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
      long length = randomAccessFile.length();
      // Ceiling division, the same slice count the server computes in createUploadTask.
      long count = length / size + (length % size == 0 ? 0 : 1);
      for (long i = 0; i < count; i++) {
        int sliceLength = (int) Math.min(size, length - i * size);
        byte[] bytes = new byte[sliceLength];
        // readFully: a plain read() may return fewer bytes than requested.
        randomAccessFile.readFully(bytes);
        consumer.accept(i, bytes);
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
文中的 OpenApiException 是自定義的異常類，需要自行實現。由於拋出它的方法都沒有 throws 聲明，它必須是非受檢異常，例如繼承 RuntimeException，並提供一個接收錯誤信息（String）的構造方法。

相關文章
相關標籤/搜索