import React from 'react';
import styled from 'styled-components';
import { ResponsiveLine } from '@nivo/line';
import { Theme, belowOrEqualTo } from '@allenai/varnish';
import { ThemeProvider, SvgWrapper } from '@nivo/core';
import { BoxLegendSvg } from '@nivo/legends';

import { SubHeader } from './shared';

/**
 * Series plotted by the chart, one per pre-training method. Each point's `x`
 * is a training-sample bucket (the bottom axis, "Number of Training Samples")
 * and `y` is the value shown on the left axis ("Average Performance").
 *
 * Kept at module scope so the chart receives the same array on every render.
 */
const benchmarkSeries = [
    {
        id: 'Random Init',
        label: 'Random Init',
        data: [
            { x: '50', y: 0.12 },
            { x: '100', y: 0.167 },
            { x: '250', y: 0.253 },
            { x: '500', y: 0.317 },
            { x: '1k', y: 0.393 },
            { x: '2.5k', y: 0.47 },
            { x: '5k', y: 0.513 },
            { x: 'All', y: 0.54 },
        ],
    },
    {
        id: 'ImageNet',
        label: 'ImageNet',
        data: [
            { x: '50', y: 0.133 },
            { x: '100', y: 0.193 },
            { x: '250', y: 0.29 },
            { x: '500', y: 0.37 },
            { x: '1k', y: 0.45 },
            { x: '2.5k', y: 0.547 },
            { x: '5k', y: 0.633 },
            { x: 'All', y: 0.657 },
        ],
    },
    {
        id: 'BigEarthNet',
        label: 'BigEarthNet',
        data: [
            { x: '50', y: 0.153 },
            { x: '100', y: 0.223 },
            { x: '250', y: 0.317 },
            { x: '500', y: 0.413 },
            { x: '1k', y: 0.49 },
            { x: '2.5k', y: 0.577 },
            { x: '5k', y: 0.63 },
            { x: 'All', y: 0.663 },
        ],
    },
    {
        id: 'Million-AID',
        label: 'Million-AID',
        data: [
            { x: '50', y: 0.213 },
            { x: '100', y: 0.317 },
            { x: '250', y: 0.433 },
            { x: '500', y: 0.503 },
            { x: '1k', y: 0.567 },
            { x: '2.5k', y: 0.637 },
            { x: '5k', y: 0.67 },
            { x: 'All', y: 0.71 },
        ],
    },
    {
        id: 'DOTA',
        label: 'DOTA',
        data: [
            { x: '50', y: 0.213 },
            { x: '100', y: 0.29 },
            { x: '250', y: 0.43 },
            { x: '500', y: 0.517 },
            { x: '1k', y: 0.577 },
            { x: '2.5k', y: 0.64 },
            { x: '5k', y: 0.677 },
            { x: 'All', y: 0.693 },
        ],
    },
    {
        id: 'iSaid',
        label: 'iSaid',
        data: [
            { x: '50', y: 0.223 },
            { x: '100', y: 0.31 },
            { x: '250', y: 0.433 },
            { x: '500', y: 0.533 },
            { x: '1k', y: 0.59 },
            { x: '2.5k', y: 0.66 },
            { x: '5k', y: 0.69 },
            { x: 'All', y: 0.71 },
        ],
    },
    {
        id: 'MoCo-v2',
        label: 'MoCo-v2',
        data: [
            { x: '50', y: 0.047 },
            { x: '100', y: 0.05 },
            { x: '250', y: 0.047 },
            { x: '500', y: 0.073 },
            { x: '1k', y: 0.087 },
            { x: '2.5k', y: 0.073 },
            { x: '5k', y: 0.073 },
            { x: 'All', y: 0.08 },
        ],
    },
    {
        id: 'SeCo',
        label: 'SeCo',
        data: [
            { x: '50', y: 0.167 },
            { x: '100', y: 0.237 },
            { x: '250', y: 0.353 },
            { x: '500', y: 0.43 },
            { x: '1k', y: 0.5 },
            { x: '2.5k', y: 0.577 },
            { x: '5k', y: 0.6 },
            { x: 'All', y: 0.633 },
        ],
    },
    {
        id: 'SatlasPretrain',
        label: 'SatlasPretrain',
        data: [
            { x: '50', y: 0.277 },
            { x: '100', y: 0.36 },
            { x: '250', y: 0.483 },
            { x: '500', y: 0.553 },
            { x: '1k', y: 0.593 },
            { x: '2.5k', y: 0.65 },
            { x: '5k', y: 0.69 },
            { x: 'All', y: 0.73 },
        ],
    },
];

/** Pixel width of each legend column. */
const legendWidth = 110;

// The legend lists the series in reverse declaration order, split into a
// five-entry column followed by a column holding the remaining four.
const legendOrder = [...benchmarkSeries].reverse();
const legendFirstHalf = legendOrder.slice(0, 5);
const legendSecondHalf = legendOrder.slice(5, 9);

/** Props shared by the legend/tooltip swatch components. */
type SeriesSymbolProps = { x: number; y: number; id: string | number };

/**
 * Returns the varnish nivo defaults with `theme.curve` and
 * `theme.tooltip.basic.whitespace` removed.
 *
 * The removal is applied to private copies of the nested objects, so the
 * shared `Theme.nivo.theme.defaults` object is never mutated (the previous
 * implementation ran `delete` on it during every render).
 */
const buildChartDefaults = () => {
    const shared = Theme.nivo.theme.defaults as any;
    const theme = {
        ...shared.theme,
        tooltip: {
            ...shared.theme.tooltip,
            basic: { ...shared.theme.tooltip.basic },
        },
    };
    delete theme.curve;
    delete theme.tooltip.basic.whitespace;
    return { ...shared, theme };
};

/**
 * Swatch for one series: a 17-unit-wide bar in the series colour whose height
 * is the series stroke width, vertically centred within a 14-unit box at (x, y).
 * For series styled with `dashed0`, nine 1-unit-wide white bars spaced 2 units
 * apart are drawn on top so the swatch reads as a dashed line.
 */
const SymbolShape = ({ x, y, id }: SeriesSymbolProps) => {
    const s = getFromKey(id);
    return (
        <>
            <rect
                x={x}
                y={y + (14 - s.strokeWidth) / 2}
                fill={s.color}
                width={17}
                height={s.strokeWidth}
            />
            {s.strokeDasharray === dashed0
                ? Array.from({ length: 9 }, (_, v) => (
                      <rect key={v} x={x + v * 2} y={y} fill={'white'} width={1} height={20} />
                  ))
                : null}
        </>
    );
};

/** Wraps `SymbolShape` in its own 17x12 `<svg>` so it can be used inside HTML (the tooltip). */
const SvgSymbolShape = (props: SeriesSymbolProps) => {
    return (
        <svg width="17" height="12">
            <SymbolShape {...props} />
        </svg>
    );
};

/**
 * One column of the legend, rendered as a standalone SVG.
 *
 * @param data Legend entries to list in this column.
 */
const BoxLegendHalf = ({ data }: { data: any[] }) => {
    // Height of one legend row: the item itself plus the spacing after it.
    const rowHeight = Theme.spacing.sm.getValue() + Theme.spacing.xs2.getValue();
    return (
        <SvgWrapper
            // The SVG is always five rows tall, independent of data.length.
            height={rowHeight * 5}
            width={legendWidth}
            margin={{ left: 0, right: 0, top: 0, bottom: 0 }}>
            <BoxLegendSvg
                anchor="center"
                data={data}
                containerWidth={legendWidth}
                containerHeight={rowHeight * data.length}
                direction="column"
                itemWidth={legendWidth}
                itemHeight={Theme.spacing.sm.getValue()}
                itemsSpacing={Theme.spacing.xs2.getValue()}
                symbolSize={parseFloat(Theme.typography.textStyles.micro.fontSize) * 16}
                itemDirection="left-to-right"
                effects={[
                    {
                        on: 'hover',
                        style: {
                            itemBackground: 'rgba(0, 0, 0, .03)',
                            itemOpacity: 1,
                        },
                    },
                ]}
                symbolShape={SymbolShape}
            />
        </SvgWrapper>
    );
};

/**
 * "Outperforming Other Systems" section: a sub-header, an explanatory
 * paragraph, and a line chart (with a two-column legend) comparing the
 * pre-training methods listed in `benchmarkSeries`.
 *
 * The helper components and static data live at module scope so their
 * identities are stable across renders; declaring components inside the render
 * function gives them a new identity on every render.
 */
export const OutperformingOthers = () => {
    // Memoised so the chart sees the same defaults object for the lifetime of the component.
    const chartDefaults = React.useMemo(buildChartDefaults, []);

    return (
        <React.Fragment>
            <SubHeader>Outperforming Other Systems</SubHeader>
            <p>
                Pre-training on SatlasPretrain substantially improves accuracy on downstream tasks
                like road segmentation and ship detection. We compared the downstream performance of
                pre-training on SatlasPretrain against pre-training on ImageNet and other methods.
                When averaged over 3 datasets with 100 downstream examples, SatlasPretrain improved
                average accuracy by 16% over ImageNet.
            </p>
            <ChartWithResponsiveLegend>
                <ChartWrapper>
                    <ResponsiveLine
                        {...chartDefaults}
                        data={benchmarkSeries}
                        margin={{
                            top: Theme.spacing.xs.getValue(),
                            right: Theme.spacing.xs.getValue(),
                            bottom: Theme.spacing.xl3.getValue(),
                            left: Theme.spacing.xl3.getValue(),
                        }}
                        xScale={{ type: 'point' }}
                        yScale={{
                            type: 'linear',
                            min: 0,
                            max: 'auto',
                        }}
                        axisBottom={{
                            legend: 'Number of Training Samples',
                            legendOffset: Theme.spacing.xl2.getValue(),
                            legendPosition: 'middle',
                        }}
                        axisLeft={{
                            legend: 'Average Performance',
                            legendOffset: -Theme.spacing.xl2.getValue(),
                            legendPosition: 'middle',
                            tickValues: [0, 0.2, 0.4, 0.6, 0.73],
                        }}
                        pointSize={0}
                        enableSlices="x"
                        yFormat={(v) =>
                            v.toLocaleString(undefined, {
                                minimumFractionDigits: 2,
                                maximumFractionDigits: 2,
                            })
                        }
                        sliceTooltip={({ slice }) => {
                            return (
                                <Tooltip>
                                    <div>
                                        <strong>{slice.points[0].data.x} Training Samples</strong>
                                    </div>
                                    <SliceDiv>
                                        {slice.points.map((point) => (
                                            <React.Fragment key={point.id}>
                                                <SvgSymbolShape x={0} y={0} id={point.serieId} />
                                                <span>{point.serieId}</span>
                                                <span>{point.data.yFormatted}</span>
                                            </React.Fragment>
                                        ))}
                                    </SliceDiv>
                                </Tooltip>
                            );
                        }}
                        layers={[
                            // NOTE(review): the entry below is 'line', not 'lines'. Presumably
                            // this keeps nivo from drawing its own lines so that DashedLine is
                            // the only line renderer; confirm before "correcting" the name.
                            'grid',
                            'markers',
                            'axes',
                            'areas',
                            'crosshair',
                            'line',
                            'slices',
                            'points',
                            'mesh',
                            'legends',
                            DashedLine,
                        ]}
                    />
                </ChartWrapper>
                <ThemeProvider>
                    <LegendGrid>
                        <BoxLegendHalf data={legendFirstHalf} />
                        <BoxLegendHalf data={legendSecondHalf} />
                    </LegendGrid>
                </ThemeProvider>
            </ChartWithResponsiveLegend>
        </React.Fragment>
    );
};

/**
 * Container for the line chart: 375px tall and full width.
 *
 * `min-width: 0` presumably lets the container shrink inside its parent grid
 * track (confirm against the layout). Under the media query produced by
 * `belowOrEqualTo(theme.breakpoints.xs)` the element gets `order: 1`.
 */
const ChartWrapper = styled.div`
    height: 375px;
    width: 100%;
    min-width: 0;

    @media ${({ theme }) => belowOrEqualTo(theme.breakpoints.xs)} {
        order: 1;
    }
`;

/**
 * Grid that holds the legend columns.
 *
 * By default it is a single auto-sized column with rows `auto auto 1fr`, so
 * its children stack vertically. Under the media query produced by
 * `belowOrEqualTo(theme.breakpoints.xs)` it switches to two auto-sized columns
 * in a single auto row, placing the children side by side.
 */
const LegendGrid = styled.div`
    display: grid;
    grid-template-columns: auto;
    grid-template-rows: auto auto 1fr;

    @media ${({ theme }) => belowOrEqualTo(theme.breakpoints.xs)} {
        grid-template-columns: auto auto;
        grid-template-rows: auto;
    }
`;

/**
 * Outer layout for the chart and its legend.
 *
 * By default it is a two-column grid (`1fr auto`) whose gap and padding both
 * come from `theme.spacing.md.px`. Under the media query produced by
 * `belowOrEqualTo(theme.breakpoints.xs)` it collapses to a single column with
 * its items centred horizontally.
 */
const ChartWithResponsiveLegend = styled.div`
    display: grid;
    grid-template-columns: 1fr auto;
    gap: ${({ theme }) => theme.spacing.md.px};
    padding: ${({ theme }) => theme.spacing.md.px};

    @media ${({ theme }) => belowOrEqualTo(theme.breakpoints.xs)} {
        grid-template-columns: 1fr;
        justify-items: center;
    }
`;

/**
 * Box used for the chart's slice tooltip: themed background (`theme.color.N1`),
 * themed padding, and a 1px solid border in `theme.palette.border.dark`.
 *
 * NOTE(review): `padding` interpolates `theme.spacing.sm` directly, whereas
 * other rules in this file use the `.px` accessor (e.g. `theme.spacing.md.px`).
 * Confirm the spacing value stringifies to a valid CSS length.
 */
const Tooltip = styled.div`
    background: ${({ theme }) => theme.color.N1};
    padding: ${({ theme }) => theme.spacing.sm};
    border: 1px solid ${({ theme }) => theme.palette.border.dark};
`;

/**
 * Three-column grid (`auto 1fr auto`) for the tooltip rows, vertically centred,
 * with no row gap and a column gap of `theme.spacing.xs2.px`.
 *
 * NOTE(review): `margin-top` interpolates `theme.spacing.xs` directly while
 * `gap` uses the `.px` accessor. Confirm the spacing value stringifies to a
 * valid CSS length.
 */
const SliceDiv = styled.div`
    display: grid;
    grid-template-columns: auto 1fr auto;
    align-items: center;
    gap: 0 ${({ theme }) => theme.spacing.xs2.px};
    margin-top: ${({ theme }) => theme.spacing.xs};
`;

/** Dash pattern for an unbroken stroke. */
const solid = '1,0';
/** Dash pattern for a finely dashed stroke. */
const dashed0 = '1,1';

// Colours of the eight series that share the thin dashed style, in series order.
const dashedSeriesColors = [
    Theme.lightCategoricalColor.Teal,
    Theme.lightCategoricalColor.Magenta,
    Theme.lightCategoricalColor.Orange,
    Theme.lightCategoricalColor.Purple,
    Theme.lightCategoricalColor.Green,
    Theme.lightCategoricalColor.Red,
    Theme.lightCategoricalColor.Blue,
    Theme.lightCategoricalColor.Aqua,
];

/**
 * Stroke styling per series, indexed in parallel with `keys`: eight dashed,
 * 3-wide entries followed by one solid, 5-wide entry for the final series.
 */
const seriesStyle = [
    ...dashedSeriesColors.map((swatch) => ({
        color: swatch.hex,
        strokeDasharray: dashed0,
        strokeWidth: 3,
    })),
    { color: Theme.darkCategoricalColor.Blue.hex, strokeDasharray: solid, strokeWidth: 5 },
];

/**
 * Custom nivo layer that draws one `<path>` per series using the styling
 * returned by `getFromKey`. Series whose stroke width is 5 are drawn fully
 * opaque; every other series is drawn at 0.6 opacity.
 */
const DashedLine = ({ series, lineGenerator, xScale, yScale }: any) => {
    return series.map((serie: any) => {
        const look = getFromKey(serie.id);
        // Project each datum through the chart scales to get pixel coordinates.
        const pathPoints = serie.data.map((datum: any) => ({
            x: xScale(datum.data.x),
            y: yScale(datum.data.y),
        }));

        return (
            <path
                key={serie.id}
                d={lineGenerator(pathPoints)}
                fill="none"
                stroke={look.color}
                opacity={look.strokeWidth === 5 ? 1 : 0.6}
                style={look}
            />
        );
    });
};

/**
 * Series ids in the order that matches `seriesStyle`: `getFromKey` finds an
 * id's position in this list and uses it as the index into `seriesStyle`, so
 * the two arrays must be kept in step.
 */
const keys = [
    'Random Init',
    'ImageNet',
    'BigEarthNet',
    'Million-AID',
    'DOTA',
    'iSaid',
    'MoCo-v2',
    'SeCo',
    'SatlasPretrain',
];

/**
 * Returns the stroke styling for a series.
 *
 * The series id is located in `keys` and the entry at the same position in
 * `seriesStyle` is returned.
 *
 * @param key Series id to look up (compared with strict equality).
 * @returns The matching entry of `seriesStyle`.
 * @throws Error if `key` is not listed in `keys`, or `seriesStyle` has no entry
 *   at that position. Previously this case returned `undefined` typed as a
 *   style, and callers failed later with an opaque property-access TypeError.
 */
const getFromKey = (key: string | number) => {
    const index = keys.findIndex((k) => k === key);
    const style = index === -1 ? undefined : seriesStyle[index];
    if (style === undefined) {
        throw new Error(`No series style registered for key "${key}"`);
    }
    return style;
};
