import { IconProp } from '@fortawesome/fontawesome-svg-core'
import { faBrain } from '@fortawesome/pro-regular-svg-icons'
import {
  faHatWizard,
  faRabbitRunning,
  faScaleBalanced,
} from '@fortawesome/pro-solid-svg-icons'
import { t } from '@lingui/macro'
import { without } from 'lodash'

import { config } from 'config'
import { SavedMedia, SavedMediaContext } from 'modules/api'
import { FeatureFlags, featureFlags } from 'modules/featureFlags'
import { ProductKey } from 'modules/monetization'

import { ImageErrorType } from '../components/AIGeneratedImages/hooks/hooks'
import { StylePresetId } from '../components/AIGeneratedImages/types'
import FluxLogo from '../providers/icons/bfl.png'
import GeminiLogo from '../providers/icons/google-gemini.svg'
import IdeogramLogo from '../providers/icons/ideogram.svg'
import OpenAILogo from '../providers/icons/openai.svg'
import PlaygroundLogo from '../providers/icons/playground.svg'
import StabilityLogo from '../providers/icons/stability.svg'

/** Identifier of each selectable output shape. */
export type AspectRatioKey = 'square' | 'landscape' | 'portrait'

/** One selectable output size for an image model, in pixels. */
export type AspectRatio = {
  key: AspectRatioKey
  /** Localized label; a function so `t` is evaluated when the label is read. */
  name: () => string
  width: number
  height: number
}

/** Default aspect ratio: a 1024×1024 square. */
export const DEFAULT_ASPECT_RATIO: AspectRatio = {
  key: 'square',
  name: () => t`Square`,
  width: 1024,
  height: 1024,
}

/** Every image-generation model id known to the client. */
export type ImageGenerateModel =
  // Stability
  | 'stable-diffusion-xl-v1-0'
  // Playground
  | 'playground-2.5'
  | 'playground-3'
  // OpenAI
  | 'dall-e-3'
  // Google
  | 'imagen-3-flash'
  | 'imagen-3-pro'
  // Ideogram
  | 'ideogram-v1'
  | 'ideogram-v1-turbo'
  | 'ideogram-v2'
  | 'ideogram-v2-turbo'
  // Flux
  | 'flux-1-schnell'
  | 'flux-1-dev'
  | 'flux-1-pro'

// Settings shared by every Ideogram entry in IMAGE_GENERATE_MODELS: it is
// spread in first and then overridden per model. `as const` keeps literal
// types such as `key: 'square'` and `provider: 'ideogram'`, which the model
// record's value type requires (a widened `string` would not type-check).
// NOTE(review): every Ideogram entry sets its own `flag`, so the `flag` here is
// always overridden in this file — confirm whether it is still needed.
const IdeogramModel = {
  icon: faBrain,
  image: IdeogramLogo,
  flag: 'ideogramTurbo',
  // Output size in pixels for each aspect ratio.
  aspectRatios: {
    square: {
      key: 'square',
      name: () => t`Square`,
      width: 1536,
      height: 1536,
    },
    landscape: {
      key: 'landscape',
      name: () => t`Landscape`,
      width: 1792,
      height: 1344,
    },
    portrait: {
      key: 'portrait',
      name: () => t`Portrait`,
      width: 1344,
      height: 1792,
    },
  },
  provider: 'ideogram',
} as const

// Fallback candidates per plan tier. If image generation on a model fails, the
// request names a fallback model from the same tier (sent as `fallbackModel` by
// fetchGenerateImage). Order matters: each list is prioritized by quality/cost,
// and the first available entry is used unless the default model is in the list.
const FREE_FALLBACK_MODELS: ImageGenerateModel[] = [
  'flux-1-schnell',
  'imagen-3-flash',
  'playground-2.5',
]

// Same as above, for models that require a paid plan.
const PREMIUM_FALLBACK_MODELS: ImageGenerateModel[] = [
  'imagen-3-pro',
  'flux-1-pro',
  'dall-e-3',
]

/**
 * Builds a model's aspect-ratio table from `[width, height]` pixel sizes.
 * Each call returns fresh objects, so model entries never share aspect-ratio
 * instances with each other.
 */
const buildAspectRatios = (
  [squareWidth, squareHeight]: [number, number],
  [landscapeWidth, landscapeHeight]: [number, number],
  [portraitWidth, portraitHeight]: [number, number]
): Record<AspectRatioKey, AspectRatio> => ({
  square: {
    key: 'square',
    name: () => t`Square`,
    width: squareWidth,
    height: squareHeight,
  },
  landscape: {
    key: 'landscape',
    name: () => t`Landscape`,
    width: landscapeWidth,
    height: landscapeHeight,
  },
  portrait: {
    key: 'portrait',
    name: () => t`Portrait`,
    width: portraitWidth,
    height: portraitHeight,
  },
})

/**
 * Registry of every image model: display metadata, plan/feature-flag gating,
 * supported output sizes, provider and fallback candidates.
 */
export const IMAGE_GENERATE_MODELS: Record<
  ImageGenerateModel,
  {
    label: () => string
    description: () => string
    icon: IconProp
    image?: StaticImageData
    minProductTier?: ProductKey // Minimum plan required to use this model (free if not specified)
    flag?: keyof FeatureFlags // Model is unavailable unless this flag is on (the default model is exempt)
    freeFlag?: keyof FeatureFlags // When on, every plan gets this model regardless of minProductTier
    disabledFlag?: keyof FeatureFlags // When on, the model is disabled in favor of another version
    aspectRatios: Partial<Record<AspectRatioKey, AspectRatio>>
    provider: ImageGenerateProvider
    fallbackModels?: ImageGenerateModel[] // Candidates to fall back to if generation fails, in priority order
    needsPromptRewrite?: boolean // NOTE(review): not read anywhere in this file
  }
> = {
  'stable-diffusion-xl-v1-0': {
    label: () => 'Stable Diffusion XL',
    description: () => t`Faster images in a variety of styles`,
    icon: faRabbitRunning,
    image: StabilityLogo,
    flag: 'sdxlModel',
    fallbackModels: without(FREE_FALLBACK_MODELS, 'stable-diffusion-xl-v1-0'),
    minProductTier: 'free',
    aspectRatios: buildAspectRatios([1024, 1024], [1216, 832], [896, 1152]),
    provider: 'baseten',
  },
  'playground-2.5': {
    label: () => 'Playground 2.5',
    description: () =>
      t`Faster images with vivid colors, best for illustrations`,
    icon: faRabbitRunning,
    image: PlaygroundLogo,
    flag: 'playgroundModel',
    fallbackModels: without(FREE_FALLBACK_MODELS, 'playground-2.5'),
    minProductTier: 'free',
    aspectRatios: buildAspectRatios([1024, 1024], [1216, 832], [896, 1152]),
    provider: 'playground',
  },
  'playground-3': {
    label: () => 'Playground 3',
    description: () => t`Best for detailed prompts, capable of text and people`,
    icon: faRabbitRunning,
    image: PlaygroundLogo,
    flag: 'playground3',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'playground-3'),
    minProductTier: 'pro',
    aspectRatios: buildAspectRatios([1024, 1024], [1216, 832], [896, 1152]),
    provider: 'playground',
  },
  'flux-1-schnell': {
    label: () => 'Flux Fast',
    description: () => t`Fastest model with good quality`,
    icon: faRabbitRunning,
    image: FluxLogo,
    flag: 'flux1Schnell',
    fallbackModels: without(FREE_FALLBACK_MODELS, 'flux-1-schnell'),
    aspectRatios: buildAspectRatios([1024, 1024], [1216, 832], [896, 1152]),
    provider: 'baseten',
  },
  'flux-1-dev': {
    label: () => 'Flux Dev',
    description: () => t`Good quality for people and text`,
    icon: faRabbitRunning,
    image: FluxLogo,
    flag: 'flux1Dev',
    fallbackModels: without(FREE_FALLBACK_MODELS, 'flux-1-dev'),
    minProductTier: 'free',
    // https://www.reddit.com/r/StableDiffusion/comments/1enxdga/flux_recommended_resolutions_from_01_to_20/
    aspectRatios: buildAspectRatios([1024, 1024], [1216, 832], [896, 1152]),
    provider: 'flux',
  },
  'flux-1-pro': {
    label: () => 'Flux Pro 1.1',
    description: () => t`Professional quality people, faces, and text`,
    icon: faRabbitRunning,
    image: FluxLogo,
    flag: 'flux1Pro',
    freeFlag: 'fluxProFree',
    minProductTier: 'plus',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'flux-1-pro'),
    aspectRatios: buildAspectRatios([1440, 1440], [1440, 960], [960, 1440]),
    provider: 'flux',
  },
  'imagen-3-flash': {
    label: () => 'Imagen 3 Fast',
    description: () => t`Google's faster model, good for detailed instructions`,
    icon: faScaleBalanced,
    image: GeminiLogo,
    minProductTier: 'free',
    flag: 'imagenFlash',
    fallbackModels: without(FREE_FALLBACK_MODELS, 'imagen-3-flash'),
    aspectRatios: buildAspectRatios([1536, 1536], [1792, 1344], [1344, 1792]),
    provider: 'google',
  },
  'imagen-3-pro': {
    label: () => 'Imagen 3',
    description: () =>
      t`Google's most advanced model, good for text and people`,
    icon: faScaleBalanced,
    image: GeminiLogo,
    minProductTier: 'plus',
    flag: 'imagen3',
    freeFlag: 'imagenFree',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'imagen-3-pro'),
    aspectRatios: buildAspectRatios([1536, 1536], [1792, 1344], [1344, 1792]),
    provider: 'google',
  },
  'ideogram-v1-turbo': {
    ...IdeogramModel,
    label: () => 'Ideogram 1.0 Turbo',
    description: () => t`Fast and good for text`,
    flag: 'ideogramTurbo',
    freeFlag: 'ideogramTurboFree',
    disabledFlag: 'ideogram2',
    minProductTier: 'pro',
    fallbackModels: [
      'ideogram-v2-turbo',
      ...without(PREMIUM_FALLBACK_MODELS, 'ideogram-v1-turbo'),
    ],
  },
  'ideogram-v1': {
    ...IdeogramModel,
    label: () => 'Ideogram 1.0',
    description: () => t`Best for text, high quality overall`,
    flag: 'ideogram',
    disabledFlag: 'ideogram2',
    minProductTier: 'pro',
    fallbackModels: [
      'ideogram-v2-turbo',
      // Exclude this model itself (was copy-pasted as 'ideogram-v1-turbo')
      ...without(PREMIUM_FALLBACK_MODELS, 'ideogram-v1'),
    ],
  },
  'ideogram-v2': {
    ...IdeogramModel,
    label: () => 'Ideogram 2.0',
    description: () => t`Best for text, high quality overall`,
    flag: 'ideogram2',
    minProductTier: 'pro',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'ideogram-v2'),
  },
  'ideogram-v2-turbo': {
    ...IdeogramModel,
    label: () => 'Ideogram 2.0 Turbo',
    description: () => t`Fast and good for text`,
    flag: 'ideogram2',
    freeFlag: 'ideogramTurboFree',
    minProductTier: 'pro',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'ideogram-v2-turbo'),
  },
  'dall-e-3': {
    label: () => 'DALL·E 3',
    description: () => t`OpenAI's most advanced model, high quality but slower`,
    icon: faHatWizard,
    image: OpenAILogo,
    minProductTier: 'pro',
    flag: 'dalle3',
    freeFlag: 'dalle3Free',
    fallbackModels: without(PREMIUM_FALLBACK_MODELS, 'dall-e-3'),
    aspectRatios: buildAspectRatios([1024, 1024], [1792, 1024], [1024, 1792]),
    provider: 'azure',
  },
}

/**
 * Returns the registry entry for `model`. When `model` has no entry (e.g. an
 * id that is no longer in the table), the entry for the
 * `aiGeneratedImagesDefaultModel` feature flag is returned instead.
 */
export const getImageModelInfo = (model: ImageGenerateModel) => {
  const info = IMAGE_GENERATE_MODELS[model]
  if (info) {
    return info
  }
  // Only consult the flag on a miss.
  const defaultModelId = featureFlags.get('aiGeneratedImagesDefaultModel')
  return IMAGE_GENERATE_MODELS[defaultModelId]
}

/** Backend service that serves a given image model. */
export type ImageGenerateProvider =
  | 'azure'
  | 'baseten'
  | 'flux'
  | 'google'
  | 'ideogram'
  | 'openai'
  | 'playground'

/**
 * Parameters for an image-generation request. `fetchGenerateImage` takes a
 * `Partial` of this and serializes the fields into the request body.
 */
export type GenerateImageOptions = {
  // Request identity
  interactionId: string
  workspaceId: string

  // What to generate
  prompt: string
  model: ImageGenerateModel
  // NOTE(review): fetchGenerateImage replaces this with its own pick, so a
  // caller-supplied value is not sent.
  fallbackModel?: ImageGenerateModel
  // NOTE(review): an earlier comment here read "dont export this year"
  // (probably "yet"); its intent is unclear — confirm with the author.
  count: number
  width?: number
  height?: number
  upscaleFactor?: number
  rewrite?: boolean

  // Saved-media context. fetchGenerateImage derives `context` from `themeId`
  // (checked first) or `docId`; a passed `context` is only used when neither
  // id is given.
  themeId?: string
  docId?: string
  context?: SavedMediaContext

  // AI image style
  stylePreset?: StylePresetId
  stylePrompt?: string

  // Lower-level controls for theme generation. snake_case because the
  // options are sent to the API unchanged.
  negative_prompt?: string
  steps?: number
  cfg_scale?: number
}

/**
 * Requests image generation from the API's `/media/images/generate` endpoint.
 *
 * - Saved-media context: a `themeId` wins over a `docId`; with neither, the
 *   caller's `context` is sent as-is.
 * - `model` defaults to the `aiGeneratedImagesDefaultModel` feature flag (also
 *   when `options.model` is explicitly `undefined`).
 * - One `fallbackModel` is picked from the model's configured fallbacks that
 *   are currently available, preferring the default model when it is among
 *   them. It replaces any `fallbackModel` given in `options`.
 *
 * @returns the saved media records parsed from the response body
 * @throws ProhibitedInputError when the API responds with
 *   `code === 'prohibited_input'` and `categories`
 * @throws GenerateImageError for any other non-2xx response
 */
export const fetchGenerateImage = async (
  options: Partial<GenerateImageOptions>
): Promise<SavedMedia[]> => {
  const { themeId, docId, ...rest } = options
  // Set the context based on themeId or docId
  const contextObj: Pick<
    GenerateImageOptions,
    'docId' | 'themeId' | 'context'
  > = themeId
    ? {
        context: SavedMediaContext.Theme,
        themeId,
      }
    : docId
    ? {
        context: SavedMediaContext.Doc,
        docId,
      }
    : {
        // We have neither docId nor themeId, so respect the context passed in
        context: options.context,
      }

  const defaultModel = featureFlags.get('aiGeneratedImagesDefaultModel')
  const model = options.model ?? defaultModel
  const url = `${
    config.API_HOST || 'https://api.gamma.app'
  }/media/images/generate`

  const fallbackModels =
    getImageModelInfo(model).fallbackModels?.filter(isImageModelAvailable) || []
  const fallbackModel: ImageGenerateModel | undefined = fallbackModels.includes(
    defaultModel
  )
    ? defaultModel
    : fallbackModels[0]

  const req = await fetch(url, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      ...rest,
      // After `...rest`, so an explicit `model: undefined` in options cannot
      // erase the resolved model (the fallback above was computed for it).
      model,
      fallbackModel,
      ...contextObj,
    }),
    credentials: 'include',
  })
  if (!req.ok) {
    // Error bodies may not be JSON (e.g. an HTML error page or an empty body);
    // don't let a SyntaxError from parsing mask the HTTP failure.
    const json = await req.json().catch(() => ({
      error: req.statusText || `HTTP ${req.status}`,
      message: 'Response body was not valid JSON',
      status: req.status,
    }))
    const message = `${json.error}: ${json.message}`
    if (json.code === 'prohibited_input' && json.categories) {
      throw new ProhibitedInputError({ categories: json.categories })
    }
    throw new GenerateImageError(message, json)
  }
  return req.json()
}

/**
 * Thrown by `fetchGenerateImage` for a failed (non-2xx) generate request that
 * is not a prohibited-input rejection.
 */
class GenerateImageError extends Error {
  /** Error payload describing the failed request, as passed by the thrower. */
  response: any

  constructor(message: string, response: any) {
    // Error's second constructor argument is `ErrorOptions` ({ cause }), not
    // arbitrary data; forwarding the payload there could set `error.cause`
    // from a response field. Keep it only on `response`.
    super(message)
    this.response = response
  }
}
GenerateImageError.prototype.name = 'GenerateImageError'

// Localized, user-facing explanation for each prohibited-input category
// (selected by ProhibitedInputError). Values are functions so `t` runs when a
// message is read rather than at module load.
export const ImageErrorMessages: Record<ImageErrorType, () => string> = {
  sexual: () =>
    t`This prompt was blocked because it could generate sexual imagery.`,
  violence: () =>
    t`This prompt was blocked because it could generate violent imagery.`,
  // Catch-all for any other flagged category.
  prohibited: () =>
    t`This prompt was blocked because it could generate inappropriate content.`,
}

/**
 * Thrown by `fetchGenerateImage` when the API rejects a prompt as prohibited
 * (`code: 'prohibited_input'`). Reduces the flagged categories to a single
 * `ImageErrorType` and carries a localized message suitable for display.
 */
export class ProhibitedInputError extends Error {
  /** Localized, user-facing explanation matching `category`. */
  messageTranslated: string
  code = 'prohibited_input'
  /** Single category chosen by `parseCategories`. */
  category: ImageErrorType

  /**
   * @param categories category map from the API error response; keys such as
   *   'sexual', 'sexual/minors', 'violence' and 'violence/graphic' with truthy
   *   values are treated as flagged.
   */
  constructor({ categories }: { categories: Record<string, unknown> }) {
    const message = `Cannot generate image, prohibited (reasons=${JSON.stringify(
      Object.keys(categories)
    )})`
    super(message)
    this.category = this.parseCategories(categories)
    // Optional call so a category without a message entry falls back to the
    // generic text instead of throwing while the error is being built.
    this.messageTranslated =
      ImageErrorMessages[this.category]?.() || t`This prompt was blocked`
  }

  /**
   * Maps flagged categories to one error type. Sexual content takes
   * precedence over violence; anything else is reported as 'prohibited'.
   */
  parseCategories(categories: Record<string, unknown>): ImageErrorType {
    if (categories.sexual || categories['sexual/minors']) {
      return 'sexual'
    }
    if (categories.violence || categories['violence/graphic']) {
      return 'violence'
    }
    // Was misspelled 'prohbited', which has no ImageErrorMessages entry.
    return 'prohibited'
  }
}
ProhibitedInputError.prototype.name = 'ProhibitedInputError'

/**
 * Whether `model` may be offered, based on the registry and feature flags.
 *
 * - Ids missing from IMAGE_GENERATE_MODELS are never available.
 * - A model with a `flag` needs that flag on, unless it is the configured
 *   `aiGeneratedImagesDefaultModel`.
 * - A model whose `disabledFlag` is on is unavailable, even if it is the default.
 */
export const isImageModelAvailable = (model: ImageGenerateModel | string) => {
  const info = IMAGE_GENERATE_MODELS[model]
  if (!info) {
    return false
  }

  // Short-circuit order matters: flags are only read when needed.
  const gatedOff =
    info.flag &&
    !featureFlags.get(info.flag) &&
    model !== featureFlags.get('aiGeneratedImagesDefaultModel')
  if (gatedOff) {
    return false
  }

  const superseded = info.disabledFlag && featureFlags.get(info.disabledFlag)
  return !superseded
}

/**
 * Minimum plan needed to use `model`. Unknown ids, models without a
 * `minProductTier`, and models whose `freeFlag` is on all resolve to 'free'.
 */
export const getRequiredPlanForImageModel = (
  model: ImageGenerateModel | string
): ProductKey => {
  const info = IMAGE_GENERATE_MODELS[model]
  if (!info) return 'free'

  // The free flag is only read when the model declares one.
  const freeForEveryone = info.freeFlag
    ? featureFlags.get(info.freeFlag)
    : false
  return freeForEveryone ? 'free' : info.minProductTier || 'free'
}
