/*
 * Decompiled with CFR 0.152.
 */
package ghidra.util.bytesearch;

import ghidra.util.bytesearch.BytePattern;
import ghidra.util.bytesearch.ExtendedByteSequence;
import ghidra.util.bytesearch.InputStreamBufferByteSequence;
import ghidra.util.bytesearch.Match;
import ghidra.util.task.TaskMonitor;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Queue;

/**
 * Searches byte data for a whole collection of {@link BytePattern}s at once.
 * <p>
 * The constructor compiles the patterns into a state machine. Each {@code SearchState} holds:
 * <ul>
 * <li>the patterns still consistent with the bytes consumed so far;</li>
 * <li>the number of bytes consumed (its {@code level});</li>
 * <li>a 256-entry transition table indexed by the next byte value.</li>
 * </ul>
 * States with the same pattern set and level are shared. A search walks the machine from every
 * candidate start offset. It reports each pattern whose full length has been consumed, and stops
 * a walk as soon as a byte has no transition.
 *
 * @param <T> the pattern type
 */
public class BulkPatternSearcher<T extends BytePattern> {
    /** Default chunk size, in bytes, used when searching an {@link InputStream}. */
    private static final int DEFAULT_BUFFER_SIZE = 4096;

    private final List<T> patterns;
    private final SearchState<T> startState;
    private int bufferSize = DEFAULT_BUFFER_SIZE;
    private int uniqueStateCount;

    /**
     * Creates a searcher for the given patterns and builds its state machine.
     *
     * @param patterns the patterns to search for. The list is copied, so later changes to it do
     *        not affect this searcher.
     */
    public BulkPatternSearcher(List<T> patterns) {
        // The start state keeps this list, and getLongestPatternLength() reads it later.
        // Copying keeps both consistent with the machine that was actually built.
        this.patterns = new ArrayList<>(patterns);
        this.startState = buildStateMachine();
    }

    /**
     * Returns a lazy iterator over every match in {@code input}.
     *
     * @param input the bytes to search
     * @return an iterator of matches, in order of increasing start offset
     */
    public Iterator<Match<T>> search(byte[] input) {
        return new ByteArrayMatchIterator(input);
    }

    /**
     * Returns a lazy iterator over every match in the first {@code length} bytes of {@code input}.
     * A {@code length} larger than the array is clamped to the array length.
     *
     * @param input the bytes to search
     * @param length the number of leading bytes to search
     * @return an iterator of matches, in order of increasing start offset
     */
    public Iterator<Match<T>> search(byte[] input, int length) {
        return new ByteArrayMatchIterator(input, length);
    }

    /**
     * Appends every match found anywhere in {@code input} to {@code results}.
     *
     * @param input the bytes to search
     * @param results receives the matches
     */
    public void search(byte[] input, List<Match<T>> results) {
        search(input, input.length, results);
    }

    /**
     * Appends every match that starts and ends within the first {@code numBytes} bytes of
     * {@code input} to {@code results}.
     *
     * @param input the bytes to search
     * @param numBytes the number of leading bytes to search; must not exceed {@code input.length}
     * @param results receives the matches
     */
    public void search(byte[] input, int numBytes, List<Match<T>> results) {
        for (int patternStart = 0; patternStart < numBytes; patternStart++) {
            collectMatchesFrom(input, patternStart, numBytes, results, patternStart);
        }
    }

    /**
     * Appends the patterns that match at offset 0 of {@code input}, within the first
     * {@code numBytes} bytes, to {@code results}.
     *
     * @param input the bytes to test
     * @param numBytes the number of leading bytes available to the patterns
     * @param results receives the matches (each with start offset 0)
     */
    public void matches(byte[] input, int numBytes, List<Match<T>> results) {
        collectMatchesFrom(input, 0, numBytes, results, 0);
    }

    /**
     * Searches one chunk of a larger byte source.
     * <p>
     * Only start offsets in {@code [0, bytes.getLength())} are tried. A pattern that starts in
     * that range may still consume bytes up to {@code bytes.getExtendedLength()}.
     *
     * @param bytes the chunk, plus any extension bytes that follow it
     * @param results receives the matches
     * @param chunkOffset added to every reported start offset
     */
    public void search(ExtendedByteSequence bytes, List<Match<T>> results, int chunkOffset) {
        int length = bytes.getLength();
        int extendedLength = bytes.getExtendedLength();
        for (int patternStart = 0; patternStart < length; patternStart++) {
            SearchState<T> state = startState;
            for (int j = patternStart; j < extendedLength; j++) {
                state = state.nextStates[bytes.getByte(j) & 0xFF];
                if (state == null) {
                    break;
                }
                state.addMatchesForCompletedPatterns(results, patternStart + chunkOffset);
            }
        }
    }

    /**
     * Searches an entire input stream.
     *
     * @param is the stream to search
     * @param results receives the matches
     * @param monitor used for cancellation and progress
     * @throws IOException if reading the stream fails
     */
    public void search(InputStream is, List<Match<T>> results, TaskMonitor monitor) throws IOException {
        search(is, -1L, results, monitor);
    }

    /**
     * Searches an input stream chunk by chunk.
     * <p>
     * Each chunk is searched with the chunk before it and the chunk after it available as
     * extension bytes, so patterns can straddle chunk boundaries.
     *
     * @param inputStream the stream to search
     * @param maxRead the maximum number of bytes to consume for candidate start positions, or a
     *        negative value for no limit
     * @param results receives the matches
     * @param monitor checked for cancellation before each chunk and advanced by each chunk's size
     * @throws IOException if reading the stream fails
     */
    public void search(InputStream inputStream, long maxRead, List<Match<T>> results, TaskMonitor monitor) throws IOException {
        RestrictedStream restrictedStream = new RestrictedStream(inputStream, maxRead);
        int maxPatternLength = getLongestPatternLength();
        int bufSize = Math.max(maxPatternLength, bufferSize);
        // NOTE(review): offsets are ints, so streams over 2 GiB would overflow reported starts.
        int offset = 0;
        InputStreamBufferByteSequence pre = new InputStreamBufferByteSequence(bufSize);
        InputStreamBufferByteSequence main = new InputStreamBufferByteSequence(bufSize);
        InputStreamBufferByteSequence post = new InputStreamBufferByteSequence(bufSize);
        main.load(restrictedStream, bufSize);
        post.load(restrictedStream, bufSize);
        while (main.getLength() > 0 && post.getLength() > 0) {
            if (monitor.isCancelled()) {
                return;
            }
            search(new ExtendedByteSequence(main, pre, post, maxPatternLength), results, offset);
            monitor.incrementProgress(main.getLength());
            offset += main.getLength();
            // Rotate the three buffers: the current chunk becomes the look-behind, the
            // look-ahead becomes the current chunk, and the old look-behind is refilled.
            InputStreamBufferByteSequence recycled = pre;
            pre = main;
            main = post;
            post = recycled;
            post.load(restrictedStream, bufSize);
        }
        // The restricted stream is exhausted. The tail is read from the underlying stream,
        // presumably so that patterns starting inside the maxRead limit can finish past it.
        post.load(inputStream, maxPatternLength);
        search(new ExtendedByteSequence(main, pre, post, maxPatternLength), results, offset);
        monitor.incrementProgress(main.getLength());
    }

    /**
     * Sets the chunk size used when searching an {@link InputStream}. The size actually used is
     * never smaller than the longest pattern.
     *
     * @param bufferSize the preferred chunk size in bytes
     */
    public void setBufferSize(int bufferSize) {
        this.bufferSize = bufferSize;
    }

    /**
     * Walks the state machine over {@code input[start, end)}. Each pattern that completes is
     * recorded, and the walk stops at the first byte with no transition.
     *
     * @param input the bytes to walk
     * @param start the first index to consume
     * @param end one past the last index that may be consumed
     * @param results receives the matches
     * @param reportedStart the start offset recorded in each {@link Match}
     */
    private void collectMatchesFrom(byte[] input, int start, int end, Collection<Match<T>> results, int reportedStart) {
        SearchState<T> state = startState;
        for (int i = start; i < end; i++) {
            state = state.nextStates[input[i] & 0xFF];
            if (state == null) {
                return;
            }
            state.addMatchesForCompletedPatterns(results, reportedStart);
        }
    }

    /**
     * Builds the state machine breadth-first from a start state that holds every pattern at
     * level 0. Equal states (same pattern list and level) are shared through a cache.
     *
     * @return the start state
     */
    private SearchState<T> buildStateMachine() {
        Queue<SearchState<T>> unprocessed = new ArrayDeque<>();
        Map<SearchState<T>, SearchState<T>> dedupCache = new HashMap<>();
        SearchState<T> start = new SearchState<>(patterns, 0);
        unprocessed.add(start);
        while (!unprocessed.isEmpty()) {
            unprocessed.remove().computeTransitions(unprocessed, dedupCache);
        }
        // The start state (level 0) is never placed in the cache, so count it separately.
        uniqueStateCount = dedupCache.size() + 1;
        return start;
    }

    /** Returns the largest {@link BytePattern#getSize()} among the patterns, or 0 if there are none. */
    private int getLongestPatternLength() {
        int maxLength = 0;
        for (T pattern : patterns) {
            maxLength = Math.max(maxLength, pattern.getSize());
        }
        return maxLength;
    }

    /**
     * Returns the number of distinct states in the compiled machine, including the start state.
     *
     * @return the state count
     */
    public int getUniqueStateCount() {
        return uniqueStateCount;
    }

    /**
     * One node of the search state machine.
     * <p>
     * A state holds the patterns that agree with the bytes consumed so far and the number of
     * bytes consumed ({@code level}). Its identity for deduplication is the pair
     * (pattern list, level).
     */
    private static class SearchState<T extends BytePattern> {
        private static final int BYTE_VALUE_COUNT = 256;

        private final List<T> activePatterns;
        private final int level;
        private final int hash;
        /** Patterns whose size equals {@code level}; null when there are none. */
        private List<T> completedPatterns;
        /** Transition per byte value; a null entry means no pattern continues with that byte. */
        private SearchState<T>[] nextStates;

        SearchState(List<T> activePatterns, int level) {
            this.activePatterns = activePatterns;
            this.level = level;
            this.hash = Objects.hash(activePatterns, level);
        }

        /**
         * Fills in this state's completed patterns and transitions. Successor states not yet
         * in the cache are added to the cache and to {@code unresolved}.
         */
        void computeTransitions(Queue<SearchState<T>> unresolved, Map<SearchState<T>, SearchState<T>> cache) {
            completedPatterns = buildFullyMatchedPatternsList();
            nextStates = createTransitionArray();
            if (completedPatterns != null && completedPatterns.size() == activePatterns.size()) {
                // Every active pattern has ended here, so no byte can extend any of them.
                return;
            }
            for (int inputValue = 0; inputValue < BYTE_VALUE_COUNT; inputValue++) {
                List<T> matchedPatterns = getMatchedPatterns(inputValue);
                if (!matchedPatterns.isEmpty()) {
                    nextStates[inputValue] = getSearchState(matchedPatterns, cache, unresolved);
                }
            }
        }

        @Override
        public int hashCode() {
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SearchState<?> other = (SearchState<?>) obj;
            // Cheap hash comparison first; the list comparison is the expensive part.
            return hash == other.hash && level == other.level
                && Objects.equals(activePatterns, other.activePatterns);
        }

        @SuppressWarnings("unchecked") // generic array creation; elements are always SearchState<T>
        private SearchState<T>[] createTransitionArray() {
            return (SearchState<T>[]) new SearchState<?>[BYTE_VALUE_COUNT];
        }

        /** Returns the cached state for (patterns, level + 1), creating and queueing it if new. */
        private SearchState<T> getSearchState(List<T> patterns, Map<SearchState<T>, SearchState<T>> cache, Queue<SearchState<T>> unresolved) {
            SearchState<T> candidate = new SearchState<>(patterns, level + 1);
            SearchState<T> existing = cache.get(candidate);
            if (existing != null) {
                return existing;
            }
            cache.put(candidate, candidate);
            unresolved.add(candidate);
            return candidate;
        }

        /** Returns the active patterns whose byte at position {@code level} accepts {@code inputValue}. */
        private List<T> getMatchedPatterns(int inputValue) {
            List<T> matched = new ArrayList<>();
            for (T pattern : activePatterns) {
                if (pattern.isMatch(level, inputValue)) {
                    matched.add(pattern);
                }
            }
            return matched;
        }

        /** Adds a {@link Match} starting at {@code start} for each pattern completed in this state. */
        private void addMatchesForCompletedPatterns(Collection<Match<T>> results, int start) {
            if (completedPatterns == null) {
                return;
            }
            for (T pattern : completedPatterns) {
                results.add(new Match<>(pattern, start, pattern.getSize()));
            }
        }

        /** Returns the active patterns whose size equals this state's level, or null if none. */
        private List<T> buildFullyMatchedPatternsList() {
            List<T> completed = new ArrayList<>();
            for (T pattern : activePatterns) {
                if (pattern.getSize() == level) {
                    completed.add(pattern);
                }
            }
            return completed.isEmpty() ? null : completed;
        }
    }

    /**
     * Lazy iterator over the matches in a byte array. Start offsets are tried one at a time, and
     * only until at least one match is buffered.
     */
    private class ByteArrayMatchIterator implements Iterator<Match<T>> {
        private final byte[] bytes;
        private final int length;
        private int patternStart = 0;
        private final Queue<Match<T>> resultBuffer = new ArrayDeque<>();

        ByteArrayMatchIterator(byte[] input) {
            this(input, input.length);
        }

        ByteArrayMatchIterator(byte[] input, int length) {
            this.bytes = input;
            this.length = Math.min(length, input.length);
            findNext();
        }

        /** Advances through start offsets until the buffer is non-empty or the input is exhausted. */
        private void findNext() {
            while (patternStart < length && resultBuffer.isEmpty()) {
                collectMatchesFrom(bytes, patternStart, length, resultBuffer, patternStart);
                patternStart++;
            }
        }

        @Override
        public boolean hasNext() {
            return !resultBuffer.isEmpty();
        }

        @Override
        public Match<T> next() {
            if (resultBuffer.isEmpty()) {
                // The Iterator contract requires an exception here, not a null return.
                throw new NoSuchElementException();
            }
            Match<T> nextResult = resultBuffer.poll();
            if (resultBuffer.isEmpty()) {
                findNext();
            }
            return nextResult;
        }
    }

    /**
     * InputStream wrapper that stops supplying bytes after {@code maxRead} bytes have been read
     * through it. A negative {@code maxRead} means no limit.
     */
    private static class RestrictedStream extends InputStream {
        private final long maxRead;
        private long totalRead;
        private final InputStream is;

        RestrictedStream(InputStream is, long maxRead) {
            this.is = is;
            this.maxRead = maxRead;
        }

        @Override
        public int read(byte[] buf) throws IOException {
            return read(buf, 0, buf.length);
        }

        /**
         * Reads up to {@code amount} bytes, capped by the remaining allowance.
         * <p>
         * NOTE(review): this returns 0, not -1, once the limit or the end of the underlying
         * stream is reached. That departs from the InputStream contract. It is kept as-is
         * because the buffer loader's handling of -1 is not visible here.
         */
        @Override
        public int read(byte[] buf, int offset, int amount) throws IOException {
            int amountToRead = amount;
            if (maxRead >= 0L) {
                long remaining = maxRead - totalRead;
                amountToRead = (int) Math.min(remaining, (long) amount);
            }
            int n = is.read(buf, offset, amountToRead);
            if (n <= 0) {
                return 0;
            }
            totalRead += n;
            return n;
        }

        @Override
        public int read() throws IOException {
            // Only enforce the limit when one is set. Without the maxRead >= 0 guard, an
            // unlimited stream (maxRead == -1) would report end-of-stream immediately.
            if (maxRead >= 0L && totalRead >= maxRead) {
                return -1;
            }
            int value = is.read();
            if (value < 0) {
                return -1;
            }
            totalRead++;
            return value;
        }
    }
}

