package com.baidu.hugegraph.computer.core.receiver.message;

import com.baidu.hugegraph.computer.core.combiner.DoubleValueSumCombiner;
import com.baidu.hugegraph.computer.core.config.ComputerOptions;
import com.baidu.hugegraph.computer.core.config.Config;
import com.baidu.hugegraph.computer.core.config.Null;
import com.baidu.hugegraph.computer.core.graph.id.BytesId;
import com.baidu.hugegraph.computer.core.graph.id.Id;
import com.baidu.hugegraph.computer.core.graph.value.DoubleValue;
import com.baidu.hugegraph.computer.core.graph.value.IdList;
import com.baidu.hugegraph.computer.core.network.buffer.ManagedBuffer;
import com.baidu.hugegraph.computer.core.receiver.ReceiverUtil;
import com.baidu.hugegraph.computer.core.sort.flusher.PeekableIterator;
import com.baidu.hugegraph.computer.core.sort.sorting.RecvSortManager;
import com.baidu.hugegraph.computer.core.store.FileManager;
import com.baidu.hugegraph.computer.core.store.SuperstepFileGenerator;
import com.baidu.hugegraph.computer.core.store.hgkvfile.entry.KvEntry;
import com.baidu.hugegraph.computer.suite.unit.UnitTestBase;
import com.baidu.hugegraph.testutil.Assert;
import java.io.File;
import java.io.IOException;
import java.util.function.Consumer;
import org.apache.commons.io.FileUtils;
import org.junit.Test;

/**
 * Tests for {@link ComputeMessageRecvPartition}: messages handed to
 * {@code addBuffer} are read back through {@code iterator()}, both with a
 * combiner ({@link DoubleValueSumCombiner}) and without one ({@link Null}).
 */
public class ComputeMessageRecvPartitionTest extends UnitTestBase {

    /**
     * A check run against a freshly created partition. It may throw
     * {@link IOException} so the message helpers below can be called directly.
     */
    @FunctionalInterface
    private interface PartitionCheck {
        void run(ComputeMessageRecvPartition partition) throws IOException;
    }

    @Test
    public void testCombineMessageRecvPartition() throws IOException {
        Config config = UnitTestBase.updateWithRequiredOptions(
            ComputerOptions.JOB_ID, "local_001",
            ComputerOptions.JOB_WORKERS_COUNT, "1",
            ComputerOptions.JOB_PARTITIONS_COUNT, "1",
            ComputerOptions.WORKER_COMBINER_CLASS,
            DoubleValueSumCombiner.class.getName(),
            ComputerOptions.WORKER_DATA_DIRS, "[data_dir1, data_dir2]",
            // Listed once only; a duplicate key would simply be overridden.
            ComputerOptions.WORKER_RECEIVED_BUFFERS_BYTES_LIMIT, "10",
            ComputerOptions.ALGORITHM_MESSAGE_CLASS,
            DoubleValue.class.getName()
        );
        this.runWithPartition(config, partition -> {
            addTwentyCombineMessageBuffer(partition::addBuffer);
            checkTenCombineMessages(partition.iterator());
        });
    }

    @Test
    public void testNotCombineMessageRecvPartition() throws IOException {
        Config config = UnitTestBase.updateWithRequiredOptions(
            ComputerOptions.JOB_ID, "local_001",
            ComputerOptions.JOB_WORKERS_COUNT, "1",
            ComputerOptions.JOB_PARTITIONS_COUNT, "1",
            ComputerOptions.WORKER_COMBINER_CLASS, Null.class.getName(),
            ComputerOptions.WORKER_DATA_DIRS, "[data_dir1, data_dir2]",
            ComputerOptions.WORKER_RECEIVED_BUFFERS_BYTES_LIMIT, "10",
            ComputerOptions.ALGORITHM_MESSAGE_CLASS, IdList.class.getName()
        );
        this.runWithPartition(config, partition -> {
            addTwentyDuplicateIdValueListMessageBuffer(partition::addBuffer);
            checkIdValueListMessages(partition.iterator());
        });
    }

    /**
     * Wipes the test data directories, builds a message partition from
     * {@code config}, asserts its type is {@code "msg"} and runs
     * {@code check} on it. The sort manager and file manager are closed
     * (in reverse order of initialization) even when the check fails.
     */
    private void runWithPartition(Config config, PartitionCheck check)
                                  throws IOException {
        FileUtils.deleteQuietly(new File("data_dir1"));
        FileUtils.deleteQuietly(new File("data_dir2"));
        FileManager fileManager = new FileManager();
        fileManager.init(config);
        try {
            RecvSortManager sortManager = new RecvSortManager(context());
            sortManager.init(config);
            try {
                ComputeMessageRecvPartition partition =
                    new ComputeMessageRecvPartition(
                        context(),
                        new SuperstepFileGenerator(fileManager, 0),
                        sortManager);
                Assert.assertEquals("msg", partition.type());
                check.run(partition);
            } finally {
                sortManager.close(config);
            }
        } finally {
            fileManager.close(config);
        }
    }

    /**
     * Feeds 20 messages to {@code consumer}: ids 0..9, each sent twice with a
     * {@link DoubleValue} equal to the id.
     */
    public static void addTwentyCombineMessageBuffer(
            Consumer<ManagedBuffer> consumer) throws IOException {
        for (long id = 0L; id < 10L; id++) {
            for (int copy = 0; copy < 2; copy++) {
                ReceiverUtil.consumeBuffer(
                    ReceiverUtil.writeMessage(BytesId.of(id),
                                              new DoubleValue(id)),
                    consumer);
            }
        }
    }

    /**
     * Checks messages produced by {@link #addTwentyCombineMessageBuffer}.
     * Consecutive entries with the same id are summed, and every id's total
     * must equal {@code 2 * id}. This holds whether the two copies were
     * combined into one entry or still arrive as separate entries.
     * Assumes ids are {@link BytesId}s built from longs, as the writer does.
     */
    public static void checkTenCombineMessages(PeekableIterator<KvEntry> it)
                                               throws IOException {
        Assert.assertTrue(it.hasNext());
        KvEntry first = it.next();
        Id currentId = ReceiverUtil.readId(first.key());
        DoubleValue sum = new DoubleValue();
        ReceiverUtil.readValue(first.value(), sum);

        while (it.hasNext()) {
            KvEntry entry = it.next();
            Id id = ReceiverUtil.readId(entry.key());
            DoubleValue value = new DoubleValue();
            // Read the value of the entry just fetched, not of the first one.
            ReceiverUtil.readValue(entry.value(), value);
            if (currentId.equals(id)) {
                sum.value(sum.value().doubleValue() +
                          value.value().doubleValue());
            } else {
                assertCombinedSum(currentId, sum);
                // Start accumulating the new id; otherwise every later
                // assertion would keep re-checking the first id.
                currentId = id;
                sum = value;
            }
        }
        assertCombinedSum(currentId, sum);
    }

    /** Asserts that {@code sum} is twice the numeric value of {@code id}. */
    private static void assertCombinedSum(Id id, DoubleValue sum) {
        Assert.assertEquals(((Long) id.asObject()).longValue() * 2.0D,
                            sum.value().doubleValue(), 0.0D);
    }

    /**
     * Feeds 20 messages to {@code consumer}: ids 0..9, each sent twice with an
     * {@link IdList} containing just that id.
     */
    private static void addTwentyDuplicateIdValueListMessageBuffer(
            Consumer<ManagedBuffer> consumer) throws IOException {
        for (long i = 0L; i < 10L; i++) {
            for (int copy = 0; copy < 2; copy++) {
                BytesId id = BytesId.of(i);
                IdList value = new IdList();
                value.add(id);
                ReceiverUtil.consumeBuffer(
                    ReceiverUtil.writeMessage(id, value), consumer);
            }
        }
    }

    /**
     * Checks messages produced by
     * {@link #addTwentyDuplicateIdValueListMessageBuffer}. Every id 0..9 must
     * appear exactly twice, in ascending order, each time with an
     * {@link IdList} holding only that id. Nothing may follow.
     */
    private static void checkIdValueListMessages(PeekableIterator<KvEntry> it)
                                                 throws IOException {
        for (long i = 0L; i < 10L; i++) {
            BytesId expectedId = BytesId.of(i);
            IdList expectedValue = new IdList();
            expectedValue.add(expectedId);
            // No combiner is configured, so both copies are expected.
            for (int copy = 0; copy < 2; copy++) {
                Assert.assertTrue(it.hasNext());
                KvEntry entry = it.next();
                Assert.assertEquals(expectedId,
                                    ReceiverUtil.readId(entry.key()));
                IdList value = new IdList();
                ReceiverUtil.readValue(entry.value(), value);
                Assert.assertEquals(expectedValue, value);
            }
        }
        Assert.assertFalse(it.hasNext());
    }
}
