Java 程序员学深度学习：DJL 上手 2 —— Spring Boot 集成
一、准备环境
- windows
- idea
- jdk11
- maven
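本文使用 DJL model zoo（模型库）中的预训练模型运行目标检测任务。
注意：这里的 model zoo 指 DJL 项目自带的预训练模型仓库（代码中通过 ModelZoo.loadModel 按 Criteria 筛选并加载模型）。它与新加坡的许靖宇建立的、汇总了许多深度学习模型的网站 Model Zoo（modelzoo.co）不是同一个东西。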
二、新建项目
最终目录结构如下:
代码地址在:https://examples.javacodegeeks.com/djl-spring-boot-example/
三、pom.xml<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>ai.djl</groupId> <artifactId>image-object-detection</artifactId> <version>1.0.0-SNAPSHOT</version> <properties> <maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.target>1.8</maven.compiler.target> <djl.version>0.11.0</djl.version> </properties> <repositories> <repository> <id>djl.ai</id> <url>https://oss.sonatype.org/content/repositories/snapshots/</url> </repository> </repositories> <dependencyManagement> <dependencies> <dependency> <groupId>ai.djl</groupId> <artifactId>bom</artifactId> <version>${djl.version}</version> <type>pom</type> <scope>import</scope> </dependency> </dependencies> </dependencyManagement> <dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> <version>2.3.4.RELEASE</version> </dependency> <dependency> <groupId>ai.djl</groupId> <artifactId>model-zoo</artifactId> <version>${djl.version}</version> </dependency> <dependency> <groupId>ai.djl.mxnet</groupId> <artifactId>mxnet-model-zoo</artifactId> <version>${djl.version}</version> </dependency> <dependency> <groupId>ai.djl.mxnet</groupId> <artifactId>mxnet-engine</artifactId> <version>${djl.version}</version> </dependency> <dependency> <groupId>ai.djl.mxnet</groupId> <artifactId>mxnet-native-auto</artifactId> <version>1.6.0</version> <scope>runtime</scope> </dependency> <dependency> <groupId>ch.qos.logback</groupId> <artifactId>logback-classic</artifactId> <version>1.2.3</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.projectlombok</groupId> <artifactId>lombok</artifactId> <version>RELEASE</version> <scope>compile</scope> </dependency> </dependencies> <build> <finalName>${project.artifactId}</finalName> <plugins> <plugin> 
<groupId>org.springframework.boot</groupId> <artifactId>spring-boot-maven-plugin</artifactId> <version>2.0.1.RELEASE</version> <executions> <execution> <goals> <goal>repackage</goal> </goals> </execution> </executions> </plugin> </plugins> </build></project>四、源代码1. SpringBoot 入口package com.jcg.djl;import org.springframework.boot.SpringApplication;import org.springframework.boot.autoconfigure.SpringBootApplication;@SpringBootApplicationpublic class ImageObjectDetectionApplication { public static void main(String[] args) { SpringApplication.run(ImageObjectDetectionApplication.class, args); }}2. Controllerpackage com.jcg.djl;import ai.djl.Application;import ai.djl.ModelException;import ai.djl.inference.Predictor;import ai.djl.modality.cv.Image;import ai.djl.modality.cv.ImageFactory;import ai.djl.modality.cv.output.DetectedObjects;import ai.djl.repository.zoo.Criteria;import ai.djl.repository.zoo.ModelZoo;import ai.djl.repository.zoo.ZooModel;import ai.djl.training.util.ProgressBar;import ai.djl.translate.TranslateException;import lombok.extern.slf4j.Slf4j;import org.apache.commons.compress.utils.IOUtils;import org.springframework.core.io.ClassPathResource;import org.springframework.http.MediaType;import org.springframework.http.ResponseEntity;import org.springframework.web.bind.annotation.*;import org.springframework.web.multipart.MultipartFile;import org.springframework.web.servlet.support.ServletUriComponentsBuilder;import java.io.IOException;import java.io.InputStream;import java.nio.file.Files;import java.nio.file.Path;import java.nio.file.Paths;import java.util.Objects;@Slf4j@RestControllerpublic class ImageDetectController { @PostMapping(value = "/upload", produces = MediaType.IMAGE_PNG_VALUE) public ResponseEntity<String> diagnose(@RequestParam("file") MultipartFile file) throws ModelException, TranslateException, IOException { byte[] bytes = file.getBytes(); Path imageFile = Paths.get(Objects.requireNonNull(file.getOriginalFilename())); Files.write(imageFile, 
bytes); return predict(imageFile); } public ResponseEntity<String> predict(Path imageFile) throws IOException, ModelException, TranslateException { Image img = ImageFactory.getInstance().fromFile(imageFile); Criteria<Image, DetectedObjects> criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) .setTypes(Image.class, DetectedObjects.class) .optFilter("backbone", "resnet50") .optProgress(new ProgressBar()) .build(); try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) { try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) { DetectedObjects detection = predictor.predict(img); return saveBoundingBoxImage(img, detection); } } } private ResponseEntity<String> saveBoundingBoxImage(Image img, DetectedObjects detection) throws IOException { Path outputDir = Paths.get("src/main/resources"); Files.createDirectories(outputDir); // Make image copy with alpha channel because original image was jpg Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB); newImage.drawBoundingBoxes(detection); Path imagePath = outputDir.resolve("detected.png"); // OpenJDK can't save jpg with alpha channel newImage.save(Files.newOutputStream(imagePath), "png"); log.info("Detected objects image has been saved in:{}" , imagePath); String fileDownloadUri = ServletUriComponentsBuilder.fromCurrentContextPath() .path("get") .toUriString(); return ResponseEntity.ok(fileDownloadUri); } @GetMapping( value = "/get", produces = MediaType.IMAGE_PNG_VALUE ) public @ResponseBody byte[] getImageWithMediaType() throws IOException { InputStream in = new ClassPathResource( "detected.png").getInputStream(); return IOUtils.toByteArray(in); }}3. 
application.yml
djl:
  # 设定应用种类
  application-type: OBJECT_DETECTION
  # 设定输入数据格式（DJL 0.11 中图片类型为 ai.djl.modality.cv.Image），有的模型支持多种数据格式
  input-class: ai.djl.modality.cv.Image
  # 设定输出数据格式
  output-class: ai.djl.modality.cv.output.DetectedObjects
  # 设定一个筛选器来筛选你的模型
  model-filter:
    size: 512
    # backbone: mobilenet1.0
  # 覆写 Translator 的默认参数
  arguments:
    threshold: 0.5 # 只展示置信度大于等于 0.5 的预测结果
注：上面的 Controller 并不会读取这份配置。它在代码中自行构建 Criteria，且筛选条件是 backbone=resnet50。djl.* 配置项只有在引入 DJL 的 Spring Boot starter 依赖后才会被读取。
五、使用方式
1. 运行程序: mvn spring-boot:run
2. 打开网页
http://localhost:8080
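3. 上传要识别的图片：以 multipart/form-data 方式 POST 到 /upload，表单字段名为 file，例如 curl -F "file=@dog.jpg" http://localhost:8080/upload。返回内容是识别结果图片的下载地址。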
4. 下载识别结果
打开地址: http://localhost:8080/get
评论