

import os
import shutil
from paraview.simple import *
import glob
import math

def fit_view(camera, view, eye_index, up_index, bnds,
             padding_factor=1, parallel_projection=True):
    """Aim *camera* at the centre of the box *bnds* so the box fills *view*.

    The camera looks along axis ``eye_index`` (0 = x, 1 = y, 2 = z) with axis
    ``up_index`` pointing up on screen; the remaining axis runs across it.

    Args:
        camera: camera to configure; must provide ``GetViewAngle``,
            ``SetParallelProjection``, ``SetParallelScale``, ``SetViewUp``,
            ``SetPosition`` and ``SetFocalPoint``.
        view: object whose ``ViewSize`` is ``(width, height)``; its aspect
            ratio decides whether the across-screen or the up extent limits
            the fit.
        eye_index: axis the camera looks along.
        up_index: axis shown as "up"; must differ from ``eye_index``.
        bnds: ``(xmin, xmax, ymin, ymax, zmin, zmax)`` of the box to frame.
        padding_factor: multiplier applied to both the camera distance and
            the parallel scale; values above 1 leave a margin around the box.
            The default of 1 gives a tight fit.
        parallel_projection: use an orthographic camera when True.

    Raises:
        ValueError: if ``eye_index`` and ``up_index`` are not two distinct
            values from {0, 1, 2}.
    """
    axes = {0, 1, 2}
    if eye_index not in axes or up_index not in axes or eye_index == up_index:
        raise ValueError(
            "eye_index and up_index must be distinct axes in {0, 1, 2}, "
            "got %r and %r" % (eye_index, up_index))
    # The axis that is neither the viewing nor the up direction runs across
    # the screen.
    (width_index,) = axes - {eye_index, up_index}

    xmn, xmx, ymn, ymx, zmn, zmx = bnds
    lengths = [xmx - xmn, ymx - ymn, zmx - zmn]
    center = [xmn + lengths[0] * 0.5,
              ymn + lengths[1] * 0.5,
              zmn + lengths[2] * 0.5]

    # Distance at which a perspective camera with this view angle sees the
    # largest box extent edge to edge.
    tan_half_angle = math.tan(camera.GetViewAngle() * math.pi / 180.0 / 2.0)
    eye_distance = max(lengths) * padding_factor / (2 * tan_half_angle)

    # The parallel scale is half the visible height. The height must cover
    # the up extent and, once scaled by the view's height/width ratio, the
    # across-screen extent.
    aspect_ratio = float(view.ViewSize[1]) / float(view.ViewSize[0])
    needed_height = aspect_ratio * lengths[width_index]
    half_height = max(needed_height, lengths[up_index]) * 0.5 * padding_factor

    position = list(center)
    position[eye_index] += eye_distance
    up = [0, 0, 0]
    up[up_index] = 1

    camera.SetParallelProjection(parallel_projection)
    if parallel_projection:
        camera.SetParallelScale(half_height)
    camera.SetViewUp(up)
    camera.SetPosition(position)
    camera.SetFocalPoint(*center)

def _find_slices(simulation_output_dir):
    """Return ``(pvd_path, eye_index, up_index)`` for every slice file to render.

    Axis indices are 0 = x, 1 = y, 2 = z. X slices are viewed along x with y
    up, Y slices along y with z up, and Z slices along z with y up. Files are
    grouped X, Y, Z and sorted by name within each group, so the processing
    order is deterministic.
    """
    output_dir = os.path.join(simulation_output_dir, "Output")
    orientations = (
        ("SliceX_*.pvd", 0, 1),
        ("SliceY_*.pvd", 1, 2),
        ("SliceZ_*.pvd", 2, 1),
    )
    slices = []
    for pattern, eye_index, up_index in orientations:
        for slice_fn in sorted(glob.glob(os.path.join(output_dir, pattern))):
            slices.append((slice_fn, eye_index, up_index))
    return slices


def _timestep_values(reader):
    """Return ``reader.TimestepValues`` as a list, which is empty if there are none."""
    values = reader.TimestepValues
    if values is None:
        return []
    if isinstance(values, (int, float)):
        # NOTE(review): defensive; guards against a single time step being
        # reported as a bare number. Confirm against the ParaView version in use.
        return [values]
    return list(values)


def _set_view_time(view, t):
    """Set ``view.ViewTime`` to *t* unless *t* is None (no time steps known)."""
    if t is not None:
        view.ViewTime = t


def _render_slice(slice_fn, eye_index, up_index, solids, render_output_dir,
                  last_only):
    """Render each component of each cell array of one slice file to PNG files.

    Images go to ``<render_output_dir>/<slice name>/``. The ParaView session
    is reset at the end.
    """
    print("Processing", slice_fn)
    nice_name = os.path.splitext(os.path.basename(slice_fn))[0]
    image_dir = os.path.join(render_output_dir, nice_name)

    field_reader = PVDReader(FileName=slice_fn)
    annotime = AnnotateTimeFilter(Input=field_reader)

    # If the reader reports no time steps, render once at whatever time the
    # pipeline is at instead of failing on an empty list.
    time_steps = _timestep_values(field_reader)
    final_time = time_steps[-1] if time_steps else None
    if time_steps and not last_only:
        render_times = time_steps
    else:
        render_times = [final_time]

    view = CreateRenderView()
    view.ViewSize = (600, 400)
    _set_view_time(view, final_time)

    for solid_fn in solids:
        print("Processing file", solid_fn)
        Show(PVDReader(FileName=solid_fn), view)

    annotime_display = Show(annotime, view)
    annotime_display.FontSize = 12

    display = Show(field_reader, view)

    fit_view(camera=view.GetActiveCamera(),
             view=view,
             eye_index=eye_index,
             up_index=up_index,
             bnds=field_reader.GetDataInformation().GetBounds())
    Render(view)

    # Screenshots have a fixed width; the height follows the view's aspect ratio.
    aspect_ratio = float(view.ViewSize[1]) / float(view.ViewSize[0])
    output_width = 1080
    output_height = int(aspect_ratio * output_width)

    print(field_reader.CellArrays)
    for variable_index, variable_name in enumerate(field_reader.CellArrays):
        # The time loop below may have moved the view; start each variable at
        # the final time step.
        _set_view_time(view, final_time)
        print("\tColoring by ", variable_name)

        data = field_reader.CellData[variable_name]
        n_components = data.GetNumberOfComponents()

        ColorBy(rep=display, value=("CELLS", variable_name))
        display.RescaleTransferFunctionToDataRange(True)

        print("Variable range", data.GetRange())
        lt = GetColorTransferFunction(variable_name, display)
        lt.ApplyPreset("Viridis (matplotlib)", True)
        lt.ScalarRangeInitialized = 1.0
        lt.VectorMode = 'Component'

        bar = CreateScalarBar(LookupTable=lt)
        bar.ComponentTitle = ""
        bar.LabelFontSize = 12
        bar.WindowLocation = "AnyLocation"
        bar.Position = (0, 0.25)
        view.Representations.append(bar)

        for component_i in range(n_components):
            field_text = Text(Text=nice_name + "\n" + variable_name + "_" + str(component_i))
            field_text_display = Show(field_text, view)
            field_text_display.FontSize = 12
            field_text_display.WindowLocation = "LowerLeftCorner"

            component_range = data.GetRange(component_i)
            print("Variable range", component_range)
            lt.VectorComponent = component_i
            # The rescale above ran before VectorMode was switched to
            # 'Component', so for multi-component arrays it need not match
            # this component's range (NOTE(review): presumably the magnitude
            # range). Rescale so each image uses the range of the component
            # it shows.
            lt.RescaleTransferFunction(component_range[0], component_range[1])

            os.makedirs(image_dir, exist_ok=True)
            for count, t in enumerate(render_times):
                fn = os.path.join(
                    image_dir,
                    "variable_{0}_{1}_{2}.png".format(variable_index, component_i, count))
                print("Rendering time step", t, "to file", fn)
                _set_view_time(view, t)
                Render(view)
                SaveScreenshot(fn, view, ImageResolution=(output_width, output_height))

            Delete(field_text)

        view.Representations.remove(bar)

    ResetSession()


def process(simulation_output_dir="out", last_only=True):
    """Render PNG images of every slice file in a simulation output folder.

    Slices are read from ``<dir>/Output/Slice{X,Y,Z}_*.pvd``. Every
    ``<dir>/Output/SliceMovingBody_*.pvd`` file is also shown in each slice's
    view. Images are written to ``<dir>/Renders/<slice name>/`` as
    ``variable_<array index>_<component>_<frame>.png``: one per cell-array
    component and rendered time step.

    Args:
        simulation_output_dir: root folder of the simulation output.
        last_only: when True, render only the final time step; otherwise
            render every time step the reader reports.
    """
    render_output_dir = os.path.join(simulation_output_dir, "Renders")
    solids = sorted(glob.glob(
        os.path.join(simulation_output_dir, "Output", "SliceMovingBody_*.pvd")))
    slices = _find_slices(simulation_output_dir)

    os.makedirs(render_output_dir, exist_ok=True)

    for slice_fn, eye_index, up_index in slices:
        _render_slice(slice_fn, eye_index, up_index, solids,
                      render_output_dir, last_only)

if __name__ == "__main__":
    # Script entry point: render every slice found under the ./out folder.
    process(simulation_output_dir="out")
