Controller
java
/**
 * HTTP entry points for chunked file upload: register a task, upload its chunks, then merge them.
 * Every endpoint simply forwards the bound request parameters to {@link FileUploadService}.
 */
@RestController
public class FileController {

    /** Body returned by the upload and merge endpoints when the service call completes. */
    private static final String OK = "ok";

    @Autowired
    private FileUploadService fileUploadService;

    /**
     * Registers a new upload task for the file described by {@code dto}.
     *
     * @param dto file name, whole-file MD5 and whole-file size
     * @return the task created by the service
     */
    @PostMapping("/createTask")
    public TaskModel createTask(CreateTaskDTO dto) {
        String name = dto.getFileName();
        String digest = dto.getFileMd5();
        Long length = dto.getFileSize();
        return fileUploadService.createUploadTask(name, digest, length);
    }

    /**
     * Stores one chunk of an existing task.
     *
     * @param dto task id, chunk MD5, chunk number and the chunk payload
     * @return {@code "ok"} once the service call returns
     */
    @PostMapping("/uploadFile")
    public String uploadFile(UploadFileDTO dto) {
        String taskId = dto.getTaskId();
        String chunkMd5 = dto.getMd5();
        Integer chunkNo = dto.getNo();
        fileUploadService.uploadFile(taskId, chunkMd5, chunkNo, dto.getFile());
        return OK;
    }

    /**
     * Asks the service to merge the uploaded chunks of a task; only {@code taskId} is read.
     *
     * @param dto request whose {@code taskId} names the task to merge
     * @return {@code "ok"} once the service call returns
     */
    @PostMapping("/mergeFile")
    public String mergeFile(UploadFileDTO dto) {
        String taskId = dto.getTaskId();
        fileUploadService.mergeFile(taskId);
        return OK;
    }
}
複製代碼
Service
web
import org.springframework.stereotype.Service;
import org.springframework.util.DigestUtils;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
/**
 * Chunked file upload.
 *
 * <p>A client registers a task with {@link #createUploadTask}, then uploads numbered chunks with
 * {@link #uploadFile}. Finally it calls {@link #mergeFile}, which concatenates the chunks and
 * checks the result against the MD5 given at registration.
 *
 * <p>Chunks are stored as {@code fileTmp/<taskId>/<no>.tmp}, relative to the working directory.
 * The merged file is written to {@code fileTmp/<taskId><fileName>}. Tasks are kept only in a
 * static in-memory map.
 */
@Service
public class FileUploadService {

    /** Upload tasks keyed by task id; an entry is removed once its file has been merged. */
    private static final Map<String, TaskModel> fileMap = new ConcurrentHashMap<>();

    /** Root directory (relative to the working directory) for chunk folders and merged files. */
    private final Path tmpPath = Paths.get("fileTmp");

    /**
     * Registers a new upload task.
     *
     * @param fileName name of the file being uploaded
     * @param fileMd5  hex MD5 of the whole file, verified by {@link #mergeFile}
     * @param fileSize size of the whole file in bytes
     * @return the new task, including its id and the number of chunks expected
     * @throws OpenApiException if an argument is missing, {@code fileSize} is negative, or the
     *                          temp directory cannot be created
     */
    public TaskModel createUploadTask(String fileName, String fileMd5, Long fileSize) {
        if (isBlank(fileName) || isBlank(fileMd5) || fileSize == null || fileSize < 0) {
            throw new OpenApiException("參數錯誤");
        }
        createTmpPath(tmpPath);
        TaskModel taskModel = new TaskModel();
        long sliceSize = taskModel.getSize();
        // Ceiling division: a trailing partial chunk still counts as one chunk.
        long sliceNum = fileSize / sliceSize + (fileSize % sliceSize == 0 ? 0 : 1);
        taskModel.setFileMd5(fileMd5);
        taskModel.setId(UUID.randomUUID().toString());
        taskModel.setSliceNum(sliceNum);
        taskModel.setFileName(fileName);
        taskModel.setFileSize(fileSize);
        fileMap.put(taskModel.getId(), taskModel);
        return taskModel;
    }

    /**
     * Stores one chunk of a registered task after verifying its MD5.
     *
     * <p>Re-uploading the same chunk number overwrites the previous copy.
     *
     * @param taskId id returned by {@link #createUploadTask}
     * @param md5    hex MD5 of this chunk (compared case-insensitively)
     * @param no     0-based chunk number, below the task's {@code sliceNum}
     * @param file   chunk payload
     * @throws OpenApiException if the task is unknown, an argument is missing or out of range,
     *                          the MD5 does not match, or the chunk cannot be written
     */
    public void uploadFile(String taskId, String md5, Integer no, MultipartFile file) {
        TaskModel taskModel = requireTask(taskId);
        // Only chunk numbers announced by createUploadTask are accepted, so arbitrary files
        // cannot pile up in the task directory.
        if (md5 == null || file == null || no == null || no < 0 || no >= taskModel.getSliceNum()) {
            throw new OpenApiException("參數錯誤");
        }
        Path chunkDir = chunkDir(taskId);
        createTmpPath(chunkDir);
        try {
            byte[] bytes = file.getBytes();
            if (!md5.equalsIgnoreCase(DigestUtils.md5DigestAsHex(bytes))) {
                throw new OpenApiException("MD5校驗失敗");
            }
            Files.write(chunkDir.resolve(no + ".tmp"), bytes);
        } catch (IOException e) {
            e.printStackTrace();
            throw new OpenApiException("文件上傳失敗");
        }
    }

    /**
     * Concatenates chunks {@code 0..sliceNum-1} of a task into the merged file and verifies its MD5.
     *
     * <p>On success, the chunk directory is deleted and the task is forgotten. On failure, the
     * partial merged file is deleted but the chunks are kept, so that bad or missing chunks can be
     * re-uploaded before merging again.
     *
     * @param taskId id returned by {@link #createUploadTask}
     * @throws OpenApiException if the task is unknown, a chunk is missing, the MD5 does not
     *                          match, or an I/O error occurs
     */
    public void mergeFile(String taskId) {
        TaskModel taskModel = requireTask(taskId);
        Path chunkDir = chunkDir(taskId);
        // Separators are stripped from the client-supplied name so the target stays inside tmpPath.
        Path target = tmpPath.resolve(taskId + safeFileName(taskModel.getFileName()));
        boolean merged = false;
        try {
            writeChunks(chunkDir, taskModel.getSliceNum(), target);
            if (!md5Hex(target).equalsIgnoreCase(taskModel.getFileMd5())) {
                throw new OpenApiException("MD5校驗錯誤");
            }
            merged = true;
        } catch (IOException e) {
            e.printStackTrace();
            throw new OpenApiException("文件合併失敗");
        } finally {
            if (!merged) {
                deleteQuietly(target);
            }
        }
        deleteDir(chunkDir);
        fileMap.remove(taskId);
    }

    /** Streams chunks 0..sliceNum-1 from {@code chunkDir}, in numeric order, into {@code target}. */
    private void writeChunks(Path chunkDir, long sliceNum, Path target) throws IOException {
        try (OutputStream out = Files.newOutputStream(target)) {
            for (long no = 0; no < sliceNum; no++) {
                Path chunk = chunkDir.resolve(no + ".tmp");
                if (Files.notExists(chunk)) {
                    throw new OpenApiException("分片缺失: " + no);
                }
                Files.copy(chunk, out);
            }
        }
    }

    /** Returns the lowercase hex MD5 of a file, reading it as a stream. */
    private static String md5Hex(Path file) throws IOException {
        try (InputStream in = Files.newInputStream(file)) {
            return DigestUtils.md5DigestAsHex(in);
        }
    }

    /** Looks up a registered task, failing with OpenApiException when it does not exist. */
    private TaskModel requireTask(String taskId) {
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        TaskModel taskModel = taskId == null ? null : fileMap.get(taskId);
        if (taskModel == null) {
            throw new OpenApiException("任務不存在");
        }
        return taskModel;
    }

    /** Directory holding the chunks of one task; taskId is a known map key at every call site. */
    private Path chunkDir(String taskId) {
        return tmpPath.resolve(taskId);
    }

    /** Replaces path separators so the name cannot introduce extra path components. */
    private static String safeFileName(String fileName) {
        return fileName.replace('/', '_').replace('\\', '_');
    }

    private static boolean isBlank(String s) {
        return s == null || s.trim().isEmpty();
    }

    /** Best-effort removal of a directory and the files directly inside it. */
    private void deleteDir(Path dir) {
        if (Files.notExists(dir)) {
            return; // e.g. a zero-chunk task never created its chunk directory
        }
        try (Stream<Path> entries = Files.list(dir)) {
            entries.forEach(FileUploadService::deleteQuietly);
        } catch (IOException e) {
            e.printStackTrace();
        }
        deleteQuietly(dir);
    }

    /** Best-effort delete: failures are printed, not thrown, because this is cleanup only. */
    private static void deleteQuietly(Path path) {
        try {
            Files.deleteIfExists(path);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns a new array holding {@code data1} followed by {@code data2}.
     * No longer used by {@link #mergeFile}; retained because it is public.
     */
    public byte[] addBytes(byte[] data1, byte[] data2) {
        byte[] data3 = new byte[data1.length + data2.length];
        System.arraycopy(data1, 0, data3, 0, data1.length);
        System.arraycopy(data2, 0, data3, data1.length, data2.length);
        return data3;
    }

    /**
     * Ensures a directory, including missing parents, exists.
     *
     * <p>{@code createDirectories} does nothing when the directory already exists. Concurrent
     * chunk uploads for the same task therefore no longer race between an existence check and
     * {@code createDirectory}.
     */
    private void createTmpPath(Path dir) {
        try {
            Files.createDirectories(dir);
        } catch (IOException e) {
            e.printStackTrace();
            throw new OpenApiException("建立失敗");
        }
    }
}
複製代碼
DTO
類spring
/** Request parameters for {@code /createTask}: describes the whole file to be uploaded. */
public class CreateTaskDTO {
    /** Name of the file being uploaded. */
    private String fileName;
    /** Hex MD5 of the whole file. */
    private String fileMd5;
    /** Size of the whole file in bytes. */
    private Long fileSize;

    public String getFileName() {
        return fileName;
    }

    public void setFileName(String fileName) {
        this.fileName = fileName;
    }

    public String getFileMd5() {
        return fileMd5;
    }

    public void setFileMd5(String fileMd5) {
        this.fileMd5 = fileMd5;
    }

    public Long getFileSize() {
        return fileSize;
    }

    public void setFileSize(Long fileSize) {
        this.fileSize = fileSize;
    }
}
/** An upload task as created by the service and returned from {@code /createTask}. */
public class TaskModel {
    /** Id of the upload task. */
    private String id;
    /** Number of chunks the file is split into. */
    private Long sliceNum;
    /** Size of each chunk in bytes. */
    private Integer size = 1024;
    /** MD5 of the whole file. */
    private String fileMd5;
    /** Name of the file. */
    private String fileName;
    /** Size of the whole file in bytes. */
    private Long fileSize;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public Long getSliceNum() {
        return sliceNum;
    }

    public void setSliceNum(Long sliceNum) {
        this.sliceNum = sliceNum;
    }

    public Integer getSize() {
        return size;
    }

    public void setSize(Integer size) {
        this.size = size;
    }

    public String getFileMd5() {
        return fileMd5;
    }

    public void setFileMd5(String fileMd5) {
        this.fileMd5 = fileMd5;
    }

    public String getFileName() {
        return fileName;
    }

    public void setFileName(String fileName) {
        this.fileName = fileName;
    }

    public Long getFileSize() {
        return fileSize;
    }

    public void setFileSize(Long fileSize) {
        this.fileSize = fileSize;
    }
}
/** Request parameters for {@code /uploadFile} (and {@code /mergeFile}, which reads only taskId). */
public class UploadFileDTO {
    /** Task id. */
    private String taskId;
    /** MD5 of this chunk. */
    private String md5;
    /** Sequence number of this chunk. */
    private Integer no;
    /** Chunk payload. */
    private MultipartFile file;

    public String getTaskId() {
        return taskId;
    }

    public void setTaskId(String taskId) {
        this.taskId = taskId;
    }

    public String getMd5() {
        return md5;
    }

    public void setMd5(String md5) {
        this.md5 = md5;
    }

    public Integer getNo() {
        return no;
    }

    public void setNo(Integer no) {
        this.no = no;
    }

    public MultipartFile getFile() {
        return file;
    }

    public void setFile(MultipartFile file) {
        this.file = file;
    }
}
複製代碼
import com.alibaba.fastjson.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.DigestUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.io.*;
import java.nio.file.Paths;
import java.util.function.BiConsumer;
/**
 * Manual integration tests for the chunked-upload endpoints.
 *
 * <p>The tests post to a server that must already be running on 127.0.0.1:9090. {@code s}
 * holds a {@code /createTask} response captured earlier. Replace it with the output printed by
 * {@link #testCreateTaskId()} before running the upload and merge tests.
 */
@RunWith(SpringRunner.class)
@SpringBootTest
public class FileUploadTest {
    @Autowired
    private RestTemplate restTemplate;

    /** Local file to upload, relative to the working directory. */
    String fileLocal = "01.png";

    @Test
    public void testCreateTaskId() throws IOException {
        String url = "http://127.0.0.1:9090/createTask";
        File file = new File(fileLocal);
        MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
        param.add("fileName", fileLocal);
        try (InputStream in = new FileInputStream(file)) {
            param.add("fileMd5", DigestUtils.md5DigestAsHex(in));
        }
        param.add("fileSize", file.length());
        sendFile(url, param);
    }

    String s = "{\"id\":\"681aff74-ee43-4d2d-9488-7568854315c7\",\"sliceNum\":57,\"size\":1024,\"fileMd5\":\"55b1bfaa8360f333082956790a10ca8f\",\"fileName\":\"01.png\",\"fileSize\":58185}\n";

    @Test
    public void testUploadFile() throws Exception {
        String url = "http://127.0.0.1:9090/uploadFile";
        JSONObject jsonObject = JSONObject.parseObject(s);
        String id = jsonObject.getString("id");
        Integer size = jsonObject.getInteger("size");
        sliceFile(new File(fileLocal), size, (no, bytes) -> {
            // Each chunk is staged on disk so it can be sent as a multipart file part.
            File chunk = new File(id + ".tmp." + no);
            try {
                java.nio.file.Files.write(chunk.toPath(), bytes);
                MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
                param.add("taskId", id);
                param.add("md5", DigestUtils.md5DigestAsHex(bytes));
                param.add("no", no);
                param.add("file", new FileSystemResource(chunk));
                sendFile(url, param);
            } catch (IOException e) {
                // Fail the test instead of sending a chunk that was never written.
                throw new UncheckedIOException(e);
            } finally {
                chunk.delete(); // best-effort cleanup of the staging file
            }
        });
    }

    @Test
    public void mergeFile() {
        JSONObject jsonObject = JSONObject.parseObject(s);
        String id = jsonObject.getString("id");
        MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
        param.add("taskId", id);
        sendFile("http://127.0.0.1:9090/mergeFile", param);
    }

    /** Posts {@code param} to {@code url} and prints the response body. */
    public void sendFile(String url, MultiValueMap<String, Object> param) {
        HttpEntity<MultiValueMap<String, Object>> httpEntity = new HttpEntity<>(param);
        ResponseEntity<String> responseEntity = restTemplate.postForEntity(url, httpEntity, String.class);
        System.out.println(responseEntity.getBody());
    }

    /**
     * Splits {@code file} into consecutive chunks of {@code size} bytes and passes each chunk,
     * with its 0-based index, to {@code consumer}. The last chunk may be shorter. An empty file
     * yields no chunks.
     *
     * @throws IllegalArgumentException if {@code size} is not positive
     * @throws UncheckedIOException     if the file cannot be opened or read
     */
    public static void sliceFile(File file, int size, BiConsumer<Long, byte[]> consumer) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
            long length = randomAccessFile.length();
            // Ceiling division; the old `length % count` check divided by zero for small files.
            long count = length / size + (length % size == 0 ? 0 : 1);
            for (long i = 0; i < count; i++) {
                int chunkLength = (int) Math.min(size, length - i * size);
                byte[] bytes = new byte[chunkLength];
                // readFully, unlike read, never returns a partially filled buffer.
                randomAccessFile.readFully(bytes);
                consumer.accept(i, bytes);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
複製代碼
註：代碼中的異常 OpenApiException 是自己定義的（自定義異常類，未在此列出）。