TensorFlow

http://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn

Tensorflow音频分类插图

官方提供了移动端 demo。

Tensorflow音频分类插图(1)

我不熟悉前端开发，所以只能找找有没有 Java 版本的支持。

Tensorflow音频分类插图(2)

Tensorflow音频分类插图(3)

Tensorflow音频分类插图(4)

注意版本：TensorFlow Java 不同版本之间的 API 不兼容（例如 `Session.Runner.run()` 的返回类型不同），依赖版本必须与下文代码匹配。

Tensorflow音频分类插图(5)

Tensorflow音频分类插图(6)

Tensorflow音频分类插图(7)

注意 JDK 版本：下文代码中注释掉的写法（`run()` 返回 `Result`）对应 JDK 17 可用的依赖版本；实际启用的写法（`run()` 返回 `List`）对应 JDK 8。

package com.example.demo17.controller;


import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;

import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import javax.xml.transform.Result;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

public class Test {


    private static FloatNdArray t1() {
//        String audioFilePath = "D:\ai\cat.wav";
        String audioFilePath = "C:\Users\user\Downloads\output_Wo9KJb-5zuz1_2.wav";
//        String audioFilePath = "D:\ai11.wav";
        // YAMNet期望的采样率
        int sampleRate = 16000;
        // YAMNet帧大小,0.96秒
        int frameSizeInMs = 96;
        // YAMNet帧步长,0.48秒
        int hopSizeInMs = 48;

        try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
            AudioFormat format = audioStream.getFormat();
            if (format.getSampleRate() != sampleRate || format.getChannels() != 1) {
                System.out.println("Warning: Audio must be 16kHz mono. Consider preprocessing.");
            }
            int frameSize = (int) (sampleRate * frameSizeInMs / 1000);
            int hopSize = (int) (sampleRate * hopSizeInMs / 1000);

            byte[] buffer = new byte[frameSize * format.getFrameSize()];
            short[] audioSamples = new short[frameSize];
            // 存储每个帧的音频数据
            List floatList = new ArrayList();
            while (true) {
                int bytesRead = audioStream.read(buffer);
                if (bytesRead == -1) {
                    break;
                }
                // 将读取的字节转换为short数组(假设16位精度)
                for (int i = 0; i < bytesRead / format.getFrameSize(); i++) {
                    audioSamples[i] = (short) ((buffer[i * 2] & 0xFF) | (buffer[i * 2 + 1] << 8));
                }
                // 对当前帧进行处理(例如,归一化和准备送入模型)
                float[] floats = processFrame(audioSamples);
                for (float aFloat : floats) {
                    floatList.add(aFloat);
                }
                // 移动到下一个帧
                System.arraycopy(audioSamples, hopSize, audioSamples, 0, frameSize - hopSize);
            }

            // 将List转换为float[]
            float[] floatArray = new float[floatList.size()];
            for (int i = 0; i < floatList.size(); i++) {
                floatArray[i] = floatList.get(i);
            }

            return StdArrays.ndCopyOf(floatArray);
        } catch (UnsupportedAudioFileException | IOException e) {
            e.printStackTrace();
        }
        return null;
    }


    private static float[] processFrame(short[] frame) {
        // 示例:归一化音频数据到[-1.0, 1.0]
        float[] normalizedFrame = new float[frame.length];
        for (int i = 0; i < frame.length; i++) {
            // short的最大值为32767,故除以32768得到[-1.0, 1.0]
            normalizedFrame[i] = frame[i] / 32768f;
        }
        return normalizedFrame;
    }

    static Map map=new ConcurrentHashMap();

    public static void main(String[] args) throws Exception {
        FloatNdArray floatNdArray = t1();
        TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray);

        //SavedModelBundle savedModelBundle = SavedModelBundle.load("D:\saved_model", "serve");
        SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\Users\user\Downloads\archive", "serve");
        Map signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
        /**
         * 获取基本定义信息
         */
        SignatureDef modelSig = signatureDefMap.get("serving_default");
        String inputTensorName = modelSig.getInputsMap().get("waveform").getName();
        String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();
        savedModelBundle.graph();
        try (Session session = savedModelBundle.session()) {
            /*JDK 17*/
//            Result run = session.runner()
//                    .feed(inputTensorName, tFloat32)
//                    .fetch(outputTensorName)
//                    .run();
//            Tensor out = run.get(0);
//            Shape shape = out.shape();
//
//            System.out.println(shape);
            /*JDK 8*/
            List run = session.runner()
                    .feed(inputTensorName, tFloat32)
                    .fetch(outputTensorName)
                    .run();
            Tensor tensor = run.get(0);
            Shape shape = tensor.shape();
            System.out.println(shape.asArray());
            String l=String.valueOf(shape.asArray()[0]);
            //读取CSV文件
            String csvFile = "C:\Users\user\Downloads\archive\assets\yamnet_class_map.csv";
            try {
                List lines = Files.readAllLines(Paths.get(csvFile));
                for (String line : lines) {
                    String[] values = line.split(",");
                    map.put(values[0], values[2]);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            String s = map.get(l);
            System.out.println(s);
        }
    }
}
本站无任何商业行为
个人在线分享 » Tensorflow音频分类