如何在本地請求部署在 SKIL 上的 MNIST 模型

  • 在SKIL中部署MNIST模型
  • 在本地使用 DL4J 客戶端請求 MNIST 模型
    package io.skymind;
    
    import javax.imageio.ImageIO;
    import javax.swing.*;
    import java.awt.*;
    import java.awt.event.MouseAdapter;
    import java.awt.event.MouseEvent;
    import java.awt.event.MouseMotionAdapter;
    import java.awt.geom.Line2D;
    import java.awt.geom.Rectangle2D;
    import java.awt.image.BufferedImage;
    import java.io.File;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Arrays;
    
    /**
     * Swing client that lets the user draw a digit and sends it to a SKIL-hosted MNIST endpoint
     * for classification.
     *
     * <p>The drawing is rendered onto a {@value #CANVAS_SIZE}x{@value #CANVAS_SIZE} black canvas,
     * downscaled to {@value #MODEL_SIZE}x{@value #MODEL_SIZE}, converted to a row-major array of
     * luminance values in [0, 1] and passed to {@code SkilClient#classify}.
     */
    public class clientapp extends JFrame {
        /** Side length, in pixels, of the square image sent to the model. */
        private static final int MODEL_SIZE = 28;
        /** Side length, in pixels, of the off-screen canvas the drawing is rendered onto. */
        private static final int CANVAS_SIZE = 300;

        private TextField endpoint;
        private TextField username;
        private TextField password;
        private PaintSurface paintSurface;

        /** Launches the client; Swing components must be created on the event-dispatch thread. */
        public static void main(String[] args) {
            SwingUtilities.invokeLater(clientapp::new);
        }

        /** Builds the drawing surface, the credential/endpoint fields and the recognize button. */
        public clientapp() {
            this.setSize(300, 350);
            this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            this.setTitle("手寫數字識別");

            paintSurface = new PaintSurface();
            this.add(paintSurface, BorderLayout.CENTER);

            username = new TextField();
            username.setText("***");
            password = new TextField();
            // Mask the password instead of echoing it on screen in clear text.
            password.setEchoChar('*');
            password.setText("***");

            endpoint = new TextField();
            endpoint.setText("http://ip:9008/endpoints/mnist/model/mnist/default/");

            Button recognizeButton = new Button("識別");
            recognizeButton.addActionListener(e -> recognize());

            Box controls = Box.createVerticalBox();
            controls.add(new Label("用戶名:"));
            controls.add(username);
            controls.add(new Label("密碼:"));
            controls.add(password);
            controls.add(new Label("接口地址:"));
            controls.add(endpoint);
            controls.add(recognizeButton);

            this.add(controls, BorderLayout.SOUTH);
            this.setVisible(true);
        }

        /**
         * Rasterizes the current drawing, sends it to SKIL, shows the top prediction (or the error)
         * in a dialog, then clears the canvas.
         *
         * <p>NOTE(review): the HTTP call runs on the event-dispatch thread, so the UI freezes until
         * SKIL answers; consider a {@code SwingWorker} if latency matters.
         */
        private void recognize() {
            BufferedImage digit = renderDigit();
            saveDebugImage(digit);

            float[] pixels = toLuminance(digit);
            printPixels(pixels);
            System.out.println("發送: " + Arrays.toString(pixels));

            SkilClient client = new SkilClient(username.getText(), password.getText());
            try {
                ClassifyResult result = client.classify(endpoint.getText(), pixels);

                // NOTE(review): assumes getResults()/getProbabilities() are Lists whose first
                // element is the top prediction (as the original .get(0) implies) — confirm.
                if (result.getResults().isEmpty() || result.getProbabilities().isEmpty()) {
                    JOptionPane.showMessageDialog(this, "Error: empty result from SKIL");
                } else {
                    int top = result.getResults().get(0);
                    float topProb = result.getProbabilities().get(0);
                    JOptionPane.showMessageDialog(
                            this,
                            "SKIL 識別結果爲: " + top + " (機率: " + topProb + ")"
                    );
                }
            } catch (IOException e) {
                e.printStackTrace();
                JOptionPane.showMessageDialog(this, "Error: " + e.getMessage());
            }

            paintSurface.shapes.clear();
            paintSurface.repaint();
        }

        /** Renders the strokes (without the helper grid) and downsamples them to MODEL_SIZE square. */
        private BufferedImage renderDigit() {
            BufferedImage canvas = new BufferedImage(CANVAS_SIZE, CANVAS_SIZE, BufferedImage.TYPE_INT_RGB);
            Graphics2D canvasGraphics = canvas.createGraphics();
            try {
                canvasGraphics.setPaint(Color.BLACK);
                canvasGraphics.fillRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
                // Hide the grid so it does not leak into the model input.
                paintSurface.drawGrid = false;
                paintSurface.paint(canvasGraphics);
            } finally {
                // Always restore the on-screen grid and release the graphics context.
                paintSurface.drawGrid = true;
                canvasGraphics.dispose();
            }

            Image scaled = canvas.getScaledInstance(MODEL_SIZE, MODEL_SIZE, Image.SCALE_SMOOTH);
            BufferedImage digit = new BufferedImage(MODEL_SIZE, MODEL_SIZE, BufferedImage.TYPE_INT_RGB);
            Graphics2D digitGraphics = digit.createGraphics();
            try {
                digitGraphics.drawImage(scaled, 0, 0, null);
            } finally {
                digitGraphics.dispose();
            }
            return digit;
        }

        /**
         * Best-effort dump of the model input as a JPEG for debugging; failures are printed and
         * otherwise ignored.
         *
         * <p>NOTE(review): the output path is a hard-coded Windows path.
         */
        private static void saveDebugImage(BufferedImage digit) {
            try {
                ImageIO.write(digit, "jpg", new File("D://tmp/0.jpg"));
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        /**
         * Converts an RGB image to a row-major array of luminance values in [0, 1], weighting
         * channels with the Rec. 709 coefficients 0.2126 / 0.7152 / 0.0722.
         */
        private static float[] toLuminance(BufferedImage image) {
            int width = image.getWidth();
            int height = image.getHeight();
            float[] pixels = new float[width * height];
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    int rgb = image.getRGB(x, y);
                    int red = (rgb >>> 16) & 0xFF;
                    int green = (rgb >>> 8) & 0xFF;
                    int blue = rgb & 0xFF;
                    pixels[y * width + x] = (red * 0.2126f + green * 0.7152f + blue * 0.0722f) / 255;
                }
            }
            return pixels;
        }

        /** Prints the MODEL_SIZE x MODEL_SIZE pixel grid to stdout, one image row per line. */
        private static void printPixels(float[] pixels) {
            for (int y = 0; y < MODEL_SIZE; y++) {
                for (int x = 0; x < MODEL_SIZE; x++) {
                    System.out.print(pixels[y * MODEL_SIZE + x]);
                }
                System.out.println();
            }
        }

        /**
         * Drawing canvas: dragging the mouse lays down filled white rectangles (at least 20x20)
         * on a black background with an optional 10-pixel light-gray grid.
         */
        private static class PaintSurface extends JComponent {
            Color color = Color.WHITE;
            ArrayList<Shape> shapes = new ArrayList<Shape>();
            Point startDrag, endDrag;
            /** Whether the helper grid is drawn; turned off while rendering the model input. */
            boolean drawGrid = true;

            public PaintSurface() {
                this.addMouseListener(new MouseAdapter() {
                    @Override
                    public void mousePressed(MouseEvent e) {
                        startDrag = new Point(e.getX(), e.getY());
                        endDrag = startDrag;
                        repaint();
                    }

                    @Override
                    public void mouseReleased(MouseEvent e) {
                        startDrag = null;
                        endDrag = null;
                        repaint();
                    }
                });

                this.addMouseMotionListener(new MouseMotionAdapter() {
                    @Override
                    public void mouseDragged(MouseEvent e) {
                        endDrag = new Point(e.getX(), e.getY());

                        // Commit a stroke segment once the pointer has moved more than 10 px.
                        if (endDrag.distance(startDrag) > 10) {
                            Rectangle2D.Float r = makeRectangle(startDrag.x, startDrag.y, e.getX(), e.getY());
                            // Enforce a minimum stroke thickness of 20 px in each direction.
                            if (r.width < 20) {
                                r.width = 20;
                            }
                            if (r.height < 20) {
                                r.height = 20;
                            }

                            shapes.add(r);

                            startDrag = new Point(e.getX(), e.getY());
                            endDrag = startDrag;
                        }

                        repaint();
                    }
                });
            }

            /** Fills the component black and, if enabled, overlays the 10-pixel grid. */
            private void paintBackground(Graphics2D g2) {
                g2.setPaint(Color.BLACK);
                g2.fillRect(0, 0, getSize().width, getSize().height);

                if (drawGrid) {
                    g2.setPaint(Color.LIGHT_GRAY);
                    for (int i = 0; i < getSize().width; i += 10) {
                        Shape line = new Line2D.Float(i, 0, i, getSize().height);
                        g2.draw(line);
                    }

                    for (int i = 0; i < getSize().height; i += 10) {
                        Shape line = new Line2D.Float(0, i, getSize().width, i);
                        g2.draw(line);
                    }
                }
            }

            /** Paints background, committed strokes and the in-progress drag outline. */
            @Override
            public void paint(Graphics g) {
                Graphics2D g2 = (Graphics2D) g;
                g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
                paintBackground(g2);

                g2.setStroke(new BasicStroke(2));
                g2.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 1.0f));

                for (Shape s : shapes) {
                    g2.setPaint(color);
                    g2.fill(s);
                }

                if (startDrag != null && endDrag != null) {
                    g2.setPaint(Color.LIGHT_GRAY);
                    Shape r = makeRectangle(startDrag.x, startDrag.y, endDrag.x, endDrag.y);
                    g2.draw(r);
                }
            }

            /** Returns the axis-aligned rectangle spanned by two corner points. */
            private Rectangle2D.Float makeRectangle(int x1, int y1, int x2, int y2) {
                return new Rectangle2D.Float(Math.min(x1, x2), Math.min(y1, y2), Math.abs(x1 - x2), Math.abs(y1 - y2));
            }
        }
    }
相關文章
相關標籤/搜索